hannahcyberey's picture
Change to local inference
40a29d6 verified
import os, warnings
from operator import attrgetter
from typing import List, Dict
import torch
import torch.nn.functional as F
from torchtyping import TensorType
from transformers import TextIteratorStreamer
from transformers import AutoTokenizer, BatchEncoding
import nnsight
from nnsight import LanguageModel
from nnsight.intervention import Envoy
# Turn off Hugging Face tokenizers' internal parallelism and silence all warnings.
os.environ.update(TOKENIZERS_PARALLELISM="false")
warnings.filterwarnings(action="ignore")
# Disable nnsight global tracing so the model can be driven from multiple threads
# (see https://github.com/ndif-team/nnsight/issues/280).
nnsight.CONFIG.APP.GLOBAL_TRACING = False
# Model id and tensor file paths read by load_model().
config = dict(
    model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    steering_vec="activations/candidate_vectors.pt",
    offset="activations/offsets.pt",
)
def detect_module_attrs(model: "LanguageModel") -> str:
    """Return the dotted attribute path of the model's list of decoder blocks.

    Each candidate is a (container, block-list) pair of child-module names. The first
    pair whose container is a direct child of ``model`` and whose block list is a
    direct child of that container wins.

    Raises:
        Exception: if none of the known layouts matches.
    """
    candidates = (
        ("model", "layers"),        # e.g. Qwen2 / Llama-style causal LMs
        ("transformer", "h"),       # e.g. GPT-2 / GPT-J
        ("transformers", "h"),      # legacy spelling, kept for backward compatibility
        ("gpt_neox", "layers"),     # e.g. GPT-NeoX / Pythia
    )
    for parent, child in candidates:
        # `_modules` holds a module's registered children; check the parent first so
        # getattr is only evaluated for a container that actually exists.
        if parent in model._modules and child in getattr(model, parent)._modules:
            return f"{parent}.{child}"
    raise Exception("Failed to detect module attributes.")
class ModelBase:
    """Wrap an nnsight ``LanguageModel`` with helpers for plain and steered generation.

    At construction, the steering vectors are L2-normalised along their last dimension.
    Both the vectors and the offsets are then moved to the model's device and cast to the
    model's dtype, so they can be combined with block activations during generation.
    """

    def __init__(
        self, model_name: str,
        steering_vecs: "TensorType", offsets: "TensorType",
        tokenizer: "AutoTokenizer | None" = None, block_module_attr=None
    ):
        """Load the tokenizer and model, then prepare the per-layer steering tensors.

        Args:
            model_name: Model id or path passed to the tokenizer and model loaders.
            steering_vecs: Steering directions, indexed by layer in ``steer_generation``.
            offsets: Offsets subtracted from activations before projecting, indexed by
                layer in ``steer_generation``.
            tokenizer: Pre-built tokenizer; loaded from ``model_name`` when ``None``.
            block_module_attr: Dotted path to the decoder-block list
                (e.g. ``"model.layers"``); auto-detected when ``None``.
        """
        if tokenizer is None:
            self.tokenizer = self._load_tokenizer(model_name)
        else:
            self.tokenizer = tokenizer
        self.model = self._load_model(model_name, self.tokenizer)
        self.device = self.model.device
        self.hidden_size = self.model.config.hidden_size
        if block_module_attr is None:
            block_module_attr = detect_module_attrs(self.model)
        self.block_modules = self.get_module(block_module_attr)
        # Normalise in the incoming precision, before casting to the model dtype.
        unit_vecs = F.normalize(steering_vecs, dim=-1)
        # The loaded tensors may live on a different device (e.g. CPU) than the model.
        # Co-locate them so the arithmetic in steer_generation does not mix devices.
        # NOTE(review): with device_map="auto" sharding across GPUs, a layer may sit on
        # a device other than self.device — confirm for multi-GPU setups.
        self.steering_vecs, self.offsets = self.set_dtype(
            unit_vecs.to(self.device), offsets.to(self.device)
        )

    def _load_model(self, model_name: str, tokenizer: "AutoTokenizer") -> "LanguageModel":
        """Load and dispatch the model in bfloat16 with automatic device placement."""
        return LanguageModel(model_name, tokenizer=tokenizer, dispatch=True, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16)

    def _load_tokenizer(self, model_name) -> "AutoTokenizer":
        """Load the tokenizer with left padding and a guaranteed pad token."""
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Left padding keeps every prompt's last token at the end of the batch row,
        # which decoder-only generation expects.
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            # Fall back to EOS when the tokenizer defines no pad token.
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
        if tokenizer.chat_template:
            # Remove the newline after "<think>" in the generation prompt. The pattern
            # matches a backslash-n escape written inside the template source.
            tokenizer.chat_template = tokenizer.chat_template.replace("<|Assistant|><think>\\n", "<|Assistant|><think>")
        return tokenizer

    def tokenize(self, prompt: str) -> "BatchEncoding":
        """Tokenize ``prompt`` into padded, untruncated PyTorch tensors."""
        return self.tokenizer(prompt, padding=True, truncation=False, return_tensors="pt")

    def get_module(self, attr: str) -> "Envoy":
        """Resolve a dotted attribute path (e.g. ``"model.layers"``) on the model."""
        return attrgetter(attr)(self.model)

    def set_dtype(self, *tensors):
        """Cast tensors to the model's dtype.

        Returns the cast tensor when given exactly one argument. Otherwise returns a
        tuple of cast tensors in argument order; a tuple, unlike a generator, can be
        indexed and iterated more than once.
        """
        cast = tuple(tensor.to(self.model.dtype) for tensor in tensors)
        if len(cast) == 1:
            return cast[0]
        return cast

    def apply_chat_template(self, instruction: str) -> str:
        """Render ``instruction`` as a single user turn followed by the generation prompt."""
        messages = [{"role": "user", "content": instruction}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def run_generation(self, inputs, streamer: "TextIteratorStreamer", generation_config: Dict):
        """Sample from the underlying model without any intervention, streaming the tokens."""
        inputs = inputs.to(self.device)
        _ = self.model._model.generate(**inputs, do_sample=True, streamer=streamer, **generation_config)

    def steer_generation(
        self, inputs, streamer: "TextIteratorStreamer", k: float,
        layer: int, coeff: float, generation_config: Dict
    ):
        """Sample while steering the output of decoder block ``layer``.

        Let ``u`` be ``steering_vecs[layer]``. The block output is rewritten on every
        generation step, via ``block_modules.all()``. The component of
        ``(acts - offsets[layer])`` along ``u`` is removed and replaced by ``coeff * k``.
        All components orthogonal to ``u`` are left unchanged.
        """
        layer_block = self.block_modules[layer]
        unit_vec = self.steering_vecs[layer]
        offset = self.offsets[layer]
        with self.model.generate(inputs, do_sample=True, streamer=streamer, **generation_config):
            with self.block_modules.all():
                # Assumes output[0] is the block's hidden states with the feature dim
                # last, e.g. (batch, seq, hidden) — confirm for the installed transformers.
                acts = layer_block.output[0].clone()
                proj = (acts - offset) @ unit_vec.unsqueeze(-1) * unit_vec
                layer_block.output[0][:] = acts - proj + coeff * k * unit_vec
def load_model() -> ModelBase:
    """Build a ModelBase from the module-level ``config`` paths and model id."""
    vec_path, offset_path = config["steering_vec"], config["offset"]
    vecs = torch.load(vec_path, weights_only=True)
    offs = torch.load(offset_path, weights_only=True)
    return ModelBase(config["model_name"], steering_vecs=vecs, offsets=offs)