import os
import warnings
from operator import attrgetter
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from torchtyping import TensorType
from transformers import AutoTokenizer, BatchEncoding, TextIteratorStreamer

import nnsight
from nnsight import LanguageModel
from nnsight.intervention import Envoy

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

nnsight.CONFIG.APP.GLOBAL_TRACING = False

config = {
    "model_name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "steering_vec": "activations/candidate_vectors.pt",
    "offset": "activations/offsets.pt",
}


def detect_module_attrs(model: LanguageModel) -> str:
    """Return the attribute path to the model's decoder block list."""
    if "model" in model._modules and "layers" in model.model._modules:
        return "model.layers"
    elif "transformer" in model._modules and "h" in model.transformer._modules:
        return "transformer.h"
    else:
        raise ValueError("Failed to detect module attributes.")
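
# Example resolution (assumed layouts): the Qwen2-based DeepSeek-R1 distill in
# `config` above resolves to "model.layers"; a GPT-2-style model would resolve
# to "transformer.h".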


class ModelBase:
    def __init__(
        self,
        model_name: str,
        steering_vecs: TensorType,
        offsets: TensorType,
        tokenizer: Optional[AutoTokenizer] = None,
        block_module_attr: Optional[str] = None,
    ):
        if tokenizer is None:
            self.tokenizer = self._load_tokenizer(model_name)
        else:
            self.tokenizer = tokenizer
        self.model = self._load_model(model_name, self.tokenizer)

        self.device = self.model.device
        self.hidden_size = self.model.config.hidden_size
        if block_module_attr is None:
            self.block_modules = self.get_module(detect_module_attrs(self.model))
        else:
            self.block_modules = self.get_module(block_module_attr)

        # Steering directions are unit-normalized; the offsets shift the origin
        # from which the directional component is measured.
        self.steering_vecs = F.normalize(steering_vecs, dim=-1)
        self.steering_vecs, self.offsets = self.set_dtype(self.steering_vecs, offsets)

    def _load_model(self, model_name: str, tokenizer: AutoTokenizer) -> LanguageModel:
        return LanguageModel(
            model_name,
            tokenizer=tokenizer,
            dispatch=True,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

    def _load_tokenizer(self, model_name: str) -> AutoTokenizer:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token

        # Strip the newline the DeepSeek-R1 chat template appends after the
        # opening <think> tag; the template's special tokens use fullwidth bars.
        tokenizer.chat_template = tokenizer.chat_template.replace(
            "<｜Assistant｜><think>\\n", "<｜Assistant｜><think>"
        )
        return tokenizer

    def tokenize(self, prompt: str) -> BatchEncoding:
        return self.tokenizer(prompt, padding=True, truncation=False, return_tensors="pt")

    def get_module(self, attr: str) -> Envoy:
        return attrgetter(attr)(self.model)

    def set_dtype(self, *tensors):
        """Cast the given tensors to the model dtype; returns a single tensor
        or a tuple, matching the number of arguments."""
        if len(tensors) == 1:
            return tensors[0].to(self.model.dtype)
        return tuple(tensor.to(self.model.dtype) for tensor in tensors)

    def apply_chat_template(self, instruction: str) -> str:
        messages = [{"role": "user", "content": instruction}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def run_generation(self, inputs, streamer: TextIteratorStreamer, generation_config: Dict):
        inputs = inputs.to(self.device)
        # Unsteered baseline: call the underlying HF model's generate directly,
        # bypassing nnsight tracing.
        _ = self.model._model.generate(**inputs, do_sample=True, streamer=streamer, **generation_config)

    def steer_generation(
        self,
        inputs,
        streamer: TextIteratorStreamer,
        k: float,
        layer: int,
        coeff: float,
        generation_config: Dict,
    ):
        layer_block = self.block_modules[layer]
        unit_vec = self.steering_vecs[layer]
        offset = self.offsets[layer]

        with self.model.generate(inputs, do_sample=True, streamer=streamer, **generation_config):
            # .all() keeps the intervention active at every generated token,
            # not just during prefill.
            with self.block_modules.all():
                acts = layer_block.output[0].clone()
                # Project (acts - offset) onto the steering direction, remove
                # that component, and reinsert it with fixed magnitude coeff * k.
                proj = (acts - offset) @ unit_vec.unsqueeze(-1) * unit_vec
                layer_block.output[0][:] = acts - proj + coeff * k * unit_vec
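

# Not part of the original module: a self-contained check of the projection
# arithmetic used in steer_generation, assuming activations of shape
# [batch, seq, hidden]. After the edit, the component of the activations along
# unit_vec (measured from the offset) should equal exactly coeff * k.
def _projection_sanity_check(hidden: int = 8, coeff: float = 1.0, k: float = 4.0) -> None:
    acts = torch.randn(2, 3, hidden)
    offset = torch.randn(hidden)
    unit_vec = F.normalize(torch.randn(hidden), dim=-1)
    proj = (acts - offset) @ unit_vec.unsqueeze(-1) * unit_vec
    steered = acts - proj + coeff * k * unit_vec
    # Component of the steered activations along unit_vec, relative to offset.
    component = (steered - offset) @ unit_vec
    assert torch.allclose(component, torch.full_like(component, coeff * k), atol=1e-5)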


def load_model() -> ModelBase:
    steering_vecs = torch.load(config["steering_vec"], weights_only=True)
    offsets = torch.load(config["offset"], weights_only=True)
    model = ModelBase(config["model_name"], steering_vecs=steering_vecs, offsets=offsets)
    return model
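

# Illustrative only (not in the original file): one way to drive
# steer_generation with a TextIteratorStreamer, which must be consumed from a
# separate thread. The prompt, layer index, and coefficients are made-up values.
if __name__ == "__main__":
    from threading import Thread

    mb = load_model()
    prompt = mb.apply_chat_template("What is 17 * 23?")
    inputs = mb.tokenize(prompt)
    streamer = TextIteratorStreamer(mb.tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=mb.steer_generation,
        kwargs=dict(
            inputs=inputs,
            streamer=streamer,
            k=4.0,
            layer=15,
            coeff=1.0,
            generation_config={"max_new_tokens": 256, "temperature": 0.6},
        ),
    )
    thread.start()
    for text in streamer:
        print(text, end="", flush=True)
    thread.join()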