Easily evaluate models steered by SAEs #2641

Open
wants to merge 8 commits into main
2 changes: 2 additions & 0 deletions examples/dog_steer.csv
@@ -0,0 +1,2 @@
latent_idx,steering_coefficient,sae_release,sae_id,description
12082,240.0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,this feature has been found on neuronpedia to make the model talk about dogs and obedience
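
For context, a minimal usage sketch (not part of this diff) of loading this CSV; "google/gemma-2-2b" is an assumed base model chosen to match the gemma-scope-2b SAE release:

from lm_eval.models.sae_steered_beta import InterventionModel

# Hypothetical invocation; the base_name is an assumption, not fixed by the CSV.
model = InterventionModel.from_csv(
    csv_path="examples/dog_steer.csv",
    base_name="google/gemma-2-2b",
    device="cuda:0",
)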
1 change: 1 addition & 0 deletions lm_eval/models/__init__.py
@@ -13,6 +13,7 @@
openai_completions,
optimum_ipex,
optimum_lm,
sae_steered_beta,
textsynth,
vllm_causallms,
vllm_vlms,
27 changes: 27 additions & 0 deletions lm_eval/models/add_to_sae_steered_beta_then_delete.py
@@ -0,0 +1,27 @@
import einops
from sae_lens import SAE
from torch import Tensor
from transformer_lens.hook_points import HookPoint


# Andre was working in Matthew's folders, and Matthew didn't want two people editing the same file at the same time.
def steering_hook_projection(
    activations,  # Float[Tensor, "batch pos d_in"]; annotation omitted because it trips either jaxtyping or the repo's pre-commit hooks.
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> Tensor:
    """
    Steers the model by projecting each activation onto the specified feature
    direction and adding a scaled multiple of that projection back to the activations.
    """
    bad_feature = sae.W_dec[latent_idx]  # (d_embedding,) decoder direction of the target latent
    dot_products = einops.einsum(
        activations, bad_feature, "batch pos d_embedding, d_embedding -> batch pos"
    )
    # Divide by the squared norm so that dot_products * bad_feature is a true projection.
    dot_products /= bad_feature.norm() ** 2

    # Calculate the projection of activations onto the feature direction
    projection = einops.einsum(
        dot_products,
        bad_feature,
        "batch pos, d_embedding -> batch pos d_embedding",
    )

    # Add scaled projection to original activations
    return activations + steering_coefficient * projection
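
A quick numerical check (an illustrative sketch, separate from the hook above): for a unit feature direction, this update scales the component of each activation along that direction by (1 + steering_coefficient).

import torch

a = torch.randn(2, 3, 8)  # batch, pos, d_embedding
v = torch.nn.functional.normalize(torch.randn(8), dim=0)  # unit feature direction
proj = (a @ v).unsqueeze(-1) * v  # projection of each activation onto v
steered = a + 0.5 * proj  # steering_coefficient = 0.5
# Along v, every component grows by exactly the factor 1.5.
assert torch.allclose(steered @ v, 1.5 * (a @ v), atol=1e-5)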

57 changes: 57 additions & 0 deletions lm_eval/models/projection_deleteme.py
@@ -0,0 +1,57 @@
import torch

def batch_vector_projection(
    vectors: torch.Tensor, target: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Projects each vector in a batch onto a target vector.

    Args:
        vectors: Tensor of shape (b, p, d) where:
            b is the batch size
            p is the number of vectors per batch
            d is the dimension of each vector
        target: Tensor of shape (d,) - the vector to project onto

    Returns:
        Tuple of (projections, dot_products): projections has shape (b, p, d) and
        contains the projected vectors; dot_products has shape (b, p, 1)

    Example:
        b, p, d = 32, 10, 3  # batch of 32, 10 vectors each, in 3D
        vectors = torch.randn(b, p, d)
        target = torch.randn(d)
        projections, dot_products = batch_vector_projection(vectors, target)
    """
    # Ensure target is a unit vector
    target = torch.nn.functional.normalize(target, dim=0)

    # Reshape target to (1, 1, d) for broadcasting
    target_reshaped = target.view(1, 1, -1)

    # Compute dot product between each vector and target
    # Result shape: (b, p, 1)
    dot_products = torch.sum(vectors * target_reshaped, dim=-1, keepdim=True)

    # Project each vector onto target
    # Result shape: (b, p, d)
    projections = dot_products * target_reshaped

    return projections, dot_products


# Test function
if __name__ == "__main__":
    # Create sample data
    batch_size, vectors_per_batch, dim = 2, 3, 4
    vectors = torch.randn(batch_size, vectors_per_batch, dim)
    target = torch.randn(dim)

    # Compute projections
    projected, dot_products = batch_vector_projection(vectors, target)

    # Removing the projection should leave no component along the target.
    _, zero_dot_products = batch_vector_projection(vectors - projected, target)
    assert torch.allclose(zero_dot_products, torch.zeros_like(zero_dot_products), atol=1e-6)
    print("After removing the projection, dot products with the target are ~zero")

    # Verify shapes
    print(f"Input shape: {vectors.shape}")
    print(f"Target shape: {target.shape}")
    print(f"Output shape: {projected.shape}")

212 changes: 212 additions & 0 deletions lm_eval/models/sae_steered_beta.py
@@ -0,0 +1,212 @@
"""Credit: contributed by https://github.com/AMindToThink aka Matthew Khoriaty of Northwestern University."""

from functools import partial

import torch
from jaxtyping import Float
from sae_lens import SAE, HookedSAETransformer
from torch import Tensor
from transformer_lens import loading_from_pretrained
from transformer_lens.hook_points import HookPoint

from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM

def steering_hook_add_scaled_one_hot(
    activations,  # Float[Tensor, "batch pos d_in"]; annotation omitted because it trips either jaxtyping or the repo's pre-commit hooks.
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> Tensor:
    """
    Steers the model by returning a modified activations tensor, with some multiple
    of the steering vector added to all sequence positions.
    """
    return activations + steering_coefficient * sae.W_dec[latent_idx]
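
# Shape note (illustrative, not part of the hook): activations is (batch, pos, d_in)
# and sae.W_dec[latent_idx] is (d_in,), so the addition broadcasts the scaled decoder
# row onto every sequence position, e.g.:
#     acts = torch.zeros(1, 4, 16)
#     steered = acts + 240.0 * torch.ones(16)  # -> shape (1, 4, 16), every entry 240.0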

# def steering_hook_clamp(
#     activations,  # Float[Tensor, "batch pos d_in"]
#     hook: HookPoint,
#     sae: SAE,
#     latent_idx: int,
#     steering_coefficient: float,
# ) -> Tensor:
#     """
#     Steers the model by encoding the activations with the SAE, clamping the chosen
#     latent to the steering coefficient, and decoding back.
#     """
#     raise NotImplementedError
#     z = sae.encode(activations)
#     z[latent_idx] = steering_coefficient
#     return sae.decode(z)


def clamp_sae_feature(sae_acts: Tensor, hook: HookPoint, latent_idx: int, value: float) -> Tensor:
    """Clamps a specific latent feature in the SAE activations to a fixed value.

    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point
        latent_idx (int): Index of the latent feature to clamp
        value (float): Value to clamp the feature to

    Returns:
        Tensor: The modified SAE activations with the specified feature clamped
    """
    sae_acts[:, :, latent_idx] = value

    return sae_acts

def clamp_original(sae_acts: Tensor, hook: HookPoint, latent_idx: int, value: float) -> Tensor:
    """Clamps a specific latent feature to a fixed value, but only at positions where it is already active (> 0).

    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point
        latent_idx (int): Index of the latent feature to clamp
        value (float): Value to clamp the feature to

    Returns:
        Tensor: The modified SAE activations with the specified feature clamped where active
    """
    mask = sae_acts[:, :, latent_idx] > 0  # boolean mask of positions where the latent fires
    sae_acts[:, :, latent_idx][mask] = value  # overwrite only those positions

    return sae_acts
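
# Illustration (a sketch, not used by the hooks above) of the difference:
#     acts = torch.zeros(1, 2, 4)
#     acts[0, 1, 0] = 2.0                            # latent 0 fires only at pos 1
#     clamp_sae_feature(acts.clone(), None, 0, 5.0)  # -> 5.0 at both positions
#     clamp_original(acts.clone(), None, 0, 5.0)     # -> 5.0 only at pos 1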

def print_sae_acts(sae_acts: Tensor, hook: HookPoint) -> Tensor:
    """Debugging hook: prints the shape of the SAE latent activations at this hook
    point and whether they are all positive, then returns them unchanged.

    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point

    Returns:
        Tensor: The unmodified SAE activations
    """
    print(40 * "----")
    print(f"These are the latent activations at {hook.name}")
    print(sae_acts.shape)
    print(torch.all(sae_acts > 0))
    return sae_acts

string_to_steering_function_dict: dict = {"add": steering_hook_add_scaled_one_hot, "clamp": clamp_sae_feature}

class InterventionModel(HookedSAETransformer):
    def __init__(self, base_name: str, device: str = "cuda:0", model=None):
        trueconfig = loading_from_pretrained.get_pretrained_model_config(
            base_name, device=device
        )
        super().__init__(trueconfig)
        self.model = model or HookedSAETransformer.from_pretrained(base_name, device=device)
        self.model.use_error_term = True
        self.model.eval()
        self.device = device  # Keep the device available as an attribute
        self.to(device)  # Ensure the model is on the correct device

    @classmethod
    def from_csv(
        cls, csv_path: str, base_name: str, device: str = "cuda:0"
    ) -> "InterventionModel":
        """
        Create an InterventionModel from a CSV file containing steering configurations.

        Expected CSV format (an optional hook_action column selects "add", "clamp", or "print"; default "add"):
        latent_idx,steering_coefficient,sae_release,sae_id,description
        12082,240.0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,increase dogs
        ...

        Args:
            csv_path: Path to the CSV file containing steering configurations
            base_name: Name of the base model to load
            device: Device to place the model on

        Returns:
            InterventionModel with configured steering hooks
        """
        import pandas as pd

        model = HookedSAETransformer.from_pretrained(base_name, device=device)
        # Read steering configurations
        df = pd.read_csv(csv_path)
        # Cache SAEs so several rows can share one
        sae_cache = {}

        def get_sae(sae_release, sae_id):
            cache_key = (sae_release, sae_id)
            if cache_key not in sae_cache:
                sae_cache[cache_key] = SAE.from_pretrained(
                    sae_release, sae_id, device=str(device)
                )[0]
            return sae_cache[cache_key]

        # Create hooks for each row in the CSV
        for _, row in df.iterrows():
            sae_release = row["sae_release"]
            sae_id = row["sae_id"]
            latent_idx = int(row["latent_idx"])
            steering_coefficient = float(row["steering_coefficient"])
            sae = get_sae(sae_release=sae_release, sae_id=sae_id)
            sae.use_error_term = True
            sae.eval()
            model.add_sae(sae)
            hook_action = row.get("hook_action", "add")
            if hook_action == "add":
                # Hook the SAE's input: we only read the residual stream there; the input is not routed through the SAE.
                hook_name = f"{sae.cfg.hook_name}.hook_sae_input"
                hook = partial(
                    steering_hook_add_scaled_one_hot,
                    sae=sae,
                    latent_idx=latent_idx,
                    steering_coefficient=steering_coefficient,
                )
                model.add_hook(hook_name, hook)
            elif hook_action == "clamp":
                sae.add_hook("hook_sae_acts_post", partial(clamp_original, latent_idx=latent_idx, value=steering_coefficient))
            elif hook_action == "print":
                sae.add_hook("hook_sae_acts_post", print_sae_acts)
            else:
                raise ValueError(f"Unknown hook action: {hook_action}")

        # Create and return the model
        return cls(base_name=base_name, device=device, model=model)

    def forward(self, *args, **kwargs):
        # Handle both input_ids and direct tensor inputs
        if "input_ids" in kwargs:
            input_tensor = kwargs.pop("input_ids")  # pop so it isn't passed twice
        elif args:
            input_tensor = args[0]
            args = args[1:]  # Remove the first argument
        else:
            input_tensor = None
        # no_grad is required here: eval mode alone does not disable autograd, and
        # without it evaluation runs out of CUDA memory.
        with torch.no_grad():
            output = self.model.forward(input_tensor, *args, **kwargs)
        return output


@register_model("sae_steered_beta")
class InterventionModelLM(HFLM):
    def __init__(self, base_name, csv_path, **kwargs):
        self.swap_in_model = InterventionModel.from_csv(
            csv_path=csv_path, base_name=base_name, device=kwargs.get("device", "cuda")
        )
        self.swap_in_model.eval()
        # Initialize other necessary attributes
        super().__init__(pretrained=base_name, **kwargs)
        if hasattr(self, "_model"):
            # Zero out the HF model's parameters but keep the object
            for param in self._model.parameters():
                param.data.zero_()
                param.requires_grad = False
            # Remove all model modules while keeping the base object
            for name, module in list(self._model.named_children()):
                delattr(self._model, name)
            torch.cuda.empty_cache()

    def _model_call(self, inputs):
        return self.swap_in_model.forward(inputs)
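
A minimal end-to-end sketch (the base model id and task are assumptions, not part of this diff) of running the new model type through the harness's Python API:

import lm_eval

results = lm_eval.simple_evaluate(
    model="sae_steered_beta",
    model_args="base_name=google/gemma-2-2b,csv_path=examples/dog_steer.csv",
    tasks=["hellaswag"],
)
print(results["results"])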
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -23,18 +23,21 @@ dependencies = [
"evaluate",
"datasets>=2.16.0",
"evaluate>=0.4.0",
"jaxtyping",
"jsonlines",
"numexpr",
"peft>=0.2.0",
"pybind11>=2.6.2",
"pytablewriter",
"rouge-score>=0.0.4",
"sacrebleu>=1.5.0",
"sae_lens",
"scikit-learn>=0.24.1",
"sqlitedict",
"torch>=1.8",
"tqdm-multiprocess",
"transformers>=4.1",
"transformer-lens>=2.7.0",
"zstandard",
"dill",
"word2number",