openai/sparse_autoencoder

Sparse autoencoders

This repository hosts:

  • sparse autoencoders trained on the activations of the GPT-2 small model
  • a visualizer for the autoencoders' features

Install

pip install git+https://github.com/openai/sparse_autoencoder.git

Code structure

See sae-viewer for the visualizer code; a hosted version is publicly available.

See model.py for details on the autoencoder model architecture. See train.py for autoencoder training code. See paths.py for more details on the available autoencoders.

Example usage

import torch
import blobfile as bf
import transformer_lens
import sparse_autoencoder

# Extract neuron activations with transformer_lens
model = transformer_lens.HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)
device = next(model.parameters()).device

prompt = "This is an example of a prompt that"
tokens = model.to_tokens(prompt)  # (1, n_tokens)
with torch.no_grad():
    logits, activation_cache = model.run_with_cache(tokens, remove_batch_dim=True)

# Choose which layer and which activation location to analyze
layer_index = 6
location = "resid_post_mlp"

# Map the autoencoder location name to the matching transformer_lens hook name
transformer_lens_loc = {
    "mlp_post_act": f"blocks.{layer_index}.mlp.hook_post",
    "resid_delta_attn": f"blocks.{layer_index}.hook_attn_out",
    "resid_post_attn": f"blocks.{layer_index}.hook_resid_mid",
    "resid_delta_mlp": f"blocks.{layer_index}.hook_mlp_out",
    "resid_post_mlp": f"blocks.{layer_index}.hook_resid_post",
}[location]

# Load the autoencoder weights for this location and layer
with bf.BlobFile(sparse_autoencoder.paths.v5_32k(location, layer_index), mode="rb") as f:
    state_dict = torch.load(f)
    autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
    autoencoder.to(device)

# Cached activations at the chosen hook; the batch dimension was removed above
input_tensor = activation_cache[transformer_lens_loc]

input_tensor_ln = input_tensor

# Encode the activations into latents, then decode them back
with torch.no_grad():
    latent_activations, info = autoencoder.encode(input_tensor_ln)
    reconstructed_activations = autoencoder.decode(latent_activations, info)

# Per-token squared reconstruction error, relative to the squared norm of the input
normalized_mse = (reconstructed_activations - input_tensor).pow(2).sum(dim=1) / (input_tensor).pow(2).sum(dim=1)
print(location, normalized_mse)
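
The normalized MSE printed above can be tried out without downloading any model. The following is a minimal sketch on hand-written toy tensors, not code from this repository: the helper names normalized_mse and active_latents_per_token are hypothetical, and the toy values stand in for input_tensor, reconstructed_activations, and latent_activations. It sums over the last axis, which is the same as dim=1 in the example because the batch dimension has been removed there. Counting non-zero latents per token is an extra illustration of sparsity and is an assumption about what one might want to inspect, not something the example above computes.

```python
import torch


def normalized_mse(reconstruction: torch.Tensor, original: torch.Tensor) -> torch.Tensor:
    """Per-token squared error divided by the squared norm of the original."""
    squared_error = (reconstruction - original).pow(2).sum(dim=-1)
    squared_norm = original.pow(2).sum(dim=-1)
    return squared_error / squared_norm


def active_latents_per_token(latents: torch.Tensor) -> torch.Tensor:
    """Number of non-zero latent activations for each token (hypothetical helper)."""
    return (latents != 0).sum(dim=-1)


# Toy stand-ins with shape (n_tokens, n_features); the values are made up.
original = torch.tensor([[3.0, 4.0], [1.0, 0.0]])
reconstruction = torch.tensor([[3.0, 3.0], [1.0, 0.0]])
latents = torch.tensor([[0.0, 2.0, 0.0, 1.5], [0.0, 0.0, 0.7, 0.0]])

print(normalized_mse(reconstruction, original))
print(active_latents_per_token(latents))
```

A value of 0 means the token was reconstructed exactly, and larger values mean a larger share of the input's squared norm was lost.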
