인공지능/자연어 처리

SAE tutorials - SAE basic

이게될까 2024. 9. 22. 23:57
728x90
728x90

내일 완성형으로 작성하겠지만 여기서 pandas 버전 문제가 있더라고여

그러므로 pandas를 아래와 같이 버전 다운 시키면 됩니다.

pip install pandas==2.2.0

https://stackoverflow.com/questions/78524556/typeerror-cannot-convert-numpy-ndarray-to-numpy-ndarray

SAE Lens + Neuronpedia Tutorial

이 튜토리얼은 기계적 해석 가능성에서 인기 있는 새로운 기법인 희소 오토인코더(Sparse Autoencoders, SAEs)를 사용하여 신경망을 분석하는 방법에 대한 입문서입니다. 더 자세한 내용은 이 게시물을 참고하세요.

하지만 여기서는 SAE 특징이 무엇인지, SAELens에 SAEs를 로드하고 특징을 찾거나 식별하는 방법, 그리고 이를 사용하여 방향 조정(steering), 절제(ablation), 귀속(attribution)을 수행하는 방법을 설명할 것입니다.

이 튜토리얼에서는 다음 내용을 다룹니다:

  • SAEs에 대한 기본적인 소개
    • SAE Lens란 무엇인가?
    • 분석할 SAE 선택 및 SAE Lens로 로드하기.
    • SAE 클래스와 그 구성(config).
  • SAE 특징
    • 특징 대시보드란 무엇인가?
    • Neuronpedia에서 특징 대시보드 로드하기.
    • Autointerp 다운로드 및 설명을 통해 검색하기.
  • 특징 추론
    • HookedSAE Transformer 클래스를 사용하여 활성화를 특징으로 분해하기.
    • 관련된 프롬프트들 간에 특징 비교하기.
  • 특징 대시보드 만들기
    • 최대 활성화 예제
    • 특징 활성화 히스토그램
    • 로짓 가중치 분포
    • 확장: Not all language model features are linear 재현하기
  • SAE 기반 분석 방법 (고급)
    • SAE 특징을 사용한 모델 방향 조정
    • SAE 특징 절제
    • 회로 탐지를 위한 기울기 기반 귀속 방법

주의: 이 튜토리얼은 짧은 시간 내에 준비된 초기 초안으로, 앞으로 많은 개선이 이루어질 예정입니다. 그럼에도 불구하고, 이 초기 버전이 시작하려는 분들에게 유용하기를 바랍니다.

import os
from setproctitle import setproctitle

setproctitle("")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens sae-dashboard
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd

# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

사전 훈련된 희소 오토인코더 로드하기

첫 번째 단계로, 실제로 SAE를 로드해 보겠습니다! 하지만 그 전에, 어떤 SAE가 사용 가능한지 확인하는 것이 유용할 수 있습니다. 다음 코드는 SAELens에서 현재 사용 가능한 SAE 릴리스를 보여주며, 앞으로 더 많은 SAE를 추가할 때마다 계속 업데이트될 것입니다.

from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# TODO: Make this nicer.
df = pd.DataFrame.from_records({k:v.__dict__ for k,v in get_pretrained_saes_directory().items()}).T
df.drop(columns=["expected_var_explained", "expected_l0", "config_overrides", "conversion_func"], inplace=True)

df # Each row is a "release" which has multiple SAEs which may have different configs / match different hook points in a model. 

실제로 SAEs는 일반적인 사용 사례에서 유용성이 다를 수 있습니다. 처음 시작할 때는 다음을 추천합니다:

  • Joseph의 오픈 소스 GPT2 Small Residual (gpt2-small-res-jb)
  • Joseph의 Feature Splitting (gpt2-small-res-jb-feature-splitting)
  • Gemma SAEs (gemma-2b-res-jb) (0,6) <- Neuronpedia에서 제공되며 좋습니다. (현재 12 / 17은 그렇게 좋지 않습니다).

다른 SAEs는 다양한 문제를 가지고 있습니다. 예를 들어, 너무 밀집되어 있거나 너무 밀집되지 않았거나, 특정 사용 사례를 위해 설계되었거나, 더 나은 버전이 될 것으로 기대되는 초기 초안인 경우가 있습니다. Decode Research와 Neuronpedia는 Neuronpedia에 있는 모든 SAEs를 SAE Lens에서 로드 가능하게 하고, 그 반대로도 가능하게 만드는 작업을 진행 중이며, 사람들이 작업할 SAEs를 선택할 수 있도록 공개 벤치마킹 통계도 제공할 예정입니다.

특정 릴리스에 포함된 모든 SAEs를 보려면 (모델의 어느 부분에 적용되는지에 따라 이름이 지정됨), 아래 코드를 실행하면 됩니다. 각 훅 포인트는 모델의 레이어 또는 모듈에 해당합니다.

# show the contents of the saes_map column for a specific row
print("SAEs in the GTP2 Small Resid Pre release")
for k,v in df.loc[df.release == "gpt2-small-res-jb", "saes_map"].values[0].items():
    print(f"SAE id: {k} for hook point: {v}")

print("-"*50)
print("SAEs in the feature splitting release")
for k,v in df.loc[df.release == "gpt2-small-res-jb-feature-splitting", "saes_map"].values[0].items():
    print(f"SAE id: {k} for hook point: {v}")

print("-"*50)
print("SAEs in the Gemma base model release")
for k,v in df.loc[df.release == "gemma-2b-res-jb", "saes_map"].values[0].items():
    print(f"SAE id: {k} for hook point: {v}")

다음으로, 특정 SAE와 이를 연결할 GPT-2 Small 사본을 로드할 것입니다. 모델을 로드하기 위해, 우리는 TransformerLens의 HookedTransformer에서 변형된 HookedSAETransformer 클래스를 사용할 것입니다.

from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gpt2-small", device = device)

#`cfg` 딕셔너리는 SAE와 함께 반환되며, 이는 SAE를 분석하는 데 유용한 정보를 포함할 수 있기 때문입니다 (예: 활성화 저장소 인스턴스화).  
#이는 SAE의 구성 딕셔너리와 동일하지 않으며, HF(Hugging Face) 저장소에 있던 내용을 의미하며, 여기서 SAE 구성 딕셔너리를 추출할 수 있습니다. 
# 또한 편의상 HF에 저장된 특징 희소성도 반환됩니다.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gpt2-small-res-jb", # <- Release name 
    sae_id = "blocks.7.hook_resid_pre", # <- SAE id (not always a hook point!)
    device = device
)

"sae" 객체는 SAE(희소 오토인코더 클래스)의 인스턴스입니다. SAE에는 다양한 구조가 있으며, 가중치나 활성화 함수가 다를 수 있습니다. SAE와 작업하는 과정을 간소화하기 위해 SAE Lens가 이러한 복잡성의 대부분을 처리해 줍니다.

이제 SAE 구성(config)을 살펴보고 각 매개변수를 이해해 보겠습니다:

  1. architecture: 사용 중인 SAE 구조의 유형을 지정합니다. 여기서는 표준 구조(히든 활성화가 있는 인코더와 디코더, 게이트가 없는 SAE)를 사용합니다.
  2. d_in: SAE의 입력 차원을 정의합니다. 이 구성에서는 768입니다.
  3. d_sae: SAE의 히든 레이어 차원을 설정합니다. 여기서는 24576으로, 이는 가능한 특징 활성화의 수를 나타냅니다.
  4. activation_fn_str: SAE에서 사용되는 활성화 함수를 지정합니다. 여기서는 ReLU를 사용합니다. 다른 선택지로는 TopK가 있지만 여기서는 다루지 않습니다.
  5. apply_b_dec_to_input: 디코더 바이어스를 입력에 적용할지 여부를 결정합니다. 여기서는 True로 설정되어 있습니다.
  6. finetuning_scaling_factor: 가중치 초기화 및 순방향 패스에 스케일링 팩터를 사용할지 여부를 나타냅니다. 일반적으로 사용되지 않으며, 이는 수축 문제 해결을 지원하기 위해 도입되었습니다.
  7. context_size: 컨텍스트 윈도우의 크기를 정의합니다. 이 경우 128개의 토큰입니다. 작은 프롬프트에서 훈련된 SAE는 긴 프롬프트에서는 성능이 좋지 않은 경우가 많습니다.
  8. model_name: 사용 중인 모델의 이름을 지정합니다. 여기서는 'gpt2-small'입니다. 이는 TransformerLens에서 유효한 모델 이름입니다.
  9. hook_name: SAE가 적용되는 모델의 특정 훅을 나타냅니다.
  10. hook_layer: 훅이 적용되는 레이어 번호를 지정합니다. 여기서는 7번 레이어입니다.
  11. hook_head_index: 어느 어텐션 헤드에 훅을 연결할지를 정의합니다. 여기서는 레지듀얼 스트림 SAE를 보고 있으므로 관련이 없습니다.
  12. prepend_bos: 시퀀스의 시작 토큰을 추가할지 여부를 결정합니다. 여기서는 True로 설정되어 있습니다.
  13. dataset_path: 훈련 또는 평가에 사용된 데이터셋의 경로를 지정합니다. (로컬 또는 Hugging Face 데이터셋일 수 있습니다.)
  14. dataset_trust_remote_code: 데이터셋을 로드할 때 원격 코드(Hugging Face에서 제공된 코드)를 신뢰할지 여부를 나타냅니다. 여기서는 True로 설정되어 있습니다.
  15. normalize_activations: 활성화를 어떻게 정규화할지 지정합니다. 이 구성에서는 'none'으로 설정되어 있습니다.
  16. dtype: 텐서 연산에 사용되는 데이터 유형을 정의합니다. 여기서는 32비트 부동소수점으로 설정되어 있습니다.
  17. device: 사용할 계산 장치를 지정합니다.
  18. sae_lens_training_version: SAE Lens의 훈련 버전을 나타냅니다. 여기서는 None으로 설정되어 있습니다.
  19. activation_fn_kwargs: 활성화 함수에 추가 키워드 인수를 허용합니다. 예를 들어, activation_fn_strtopk로 설정된 경우 k 값을 지정하는 데 사용할 수 있습니다.
print(sae.cfg.__dict__)

{'architecture': 'standard', 'd_in': 768, 'd_sae': 24576, 'activation_fn_str': 'relu', 'apply_b_dec_to_input': True, 'finetuning_scaling_factor': False, 'context_size': 128, 'model_name': 'gpt2-small', 'hook_name': 'blocks.7.hook_resid_pre', 'hook_layer': 7, 'hook_head_index': None, 'prepend_bos': True, 'dataset_path': 'Skylion007/openwebtext', 'dataset_trust_remote_code': True, 'normalize_activations': 'none', 'dtype': 'torch.float32', 'device': 'cuda', 'sae_lens_training_version': None, 'activation_fn_kwargs': {}, 'neuronpedia_id': 'gpt2-small/7-res-jb', 'model_from_pretrained_kwargs': {'center_writing_weights': True}}

더보기

1. architecture: 'standard'

  • 구조: SAE의 아키텍처는 "standard"입니다. 즉, 이 모델은 기본적인 인코더와 디코더 구조를 사용합니다. 인코더는 입력을 압축하고, 디코더는 이를 다시 복원하는 방식입니다. 게이트를 사용하는 SAE는 아닌 것으로 보입니다.

2. d_in: 768

  • 입력 차원: 입력 벡터의 차원이 768입니다. 이는 GPT-2와 같은 트랜스포머 모델에서 자주 사용되는 차원 크기입니다. 이 모델은 768개의 입력 노드(뉴런)를 가지고 있음을 의미합니다.

3. d_sae: 24576

  • 숨겨진 차원: SAE의 숨겨진 층(hidden layer)은 24,576개의 뉴런으로 구성되어 있습니다. 이 값은 인코더가 압축한 후 활성화될 수 있는 잠재적 특징들의 수를 의미합니다. 일반적인 트랜스포머 모델의 인코딩 단계에서 표현되는 차원보다 훨씬 큰 차원이므로, 더 많은 정보를 인코딩하는 것을 목표로 합니다.

4. activation_fn_str: 'relu'

  • 활성화 함수: ReLU (Rectified Linear Unit) 활성화 함수가 사용됩니다. 이는 음수 값을 0으로 바꾸고, 양수 값은 그대로 유지하는 함수입니다. 신경망의 비선형성을 제공하며, sparse한 특성을 도와줍니다.

5. apply_b_dec_to_input: True

  • 디코더 바이어스 적용: 디코더 바이어스를 입력에 적용합니다. 이는 인코더와 디코더의 바이어스가 동일한 방향으로 작동하는지 확인하는 역할을 할 수 있습니다.

6. finetuning_scaling_factor: False

  • 스케일링 팩터: 파인튜닝 시 스케일링 팩터를 사용하지 않습니다. 이는 특정 상황에서 가중치를 초기화하거나 순방향 전파에 적용할 수 있는 팩터입니다.

7. context_size: 128

  • 컨텍스트 창 크기: 128개의 토큰을 사용하는 컨텍스트 창입니다. 즉, 모델은 한 번에 128개의 토큰까지 문맥을 고려하여 처리합니다. 이 창의 크기는 모델이 얼마나 넓은 범위의 문맥을 고려하는지 나타냅니다.

8. model_name: 'gpt2-small'

  • 모델 이름: 이 SAE는 'gpt2-small' 모델을 기반으로 하고 있습니다. 'gpt2-small'은 약 1억 2천만 개의 파라미터를 가진 트랜스포머 기반 언어 모델입니다.

9. hook_name: 'blocks.7.hook_resid_pre'

  • 후크 위치: SAE는 GPT-2 모델의 7번째 블록의 resid_pre 부분에 후크를 적용합니다. 이는 GPT-2의 잔차 연결 이전에 활성화 값을 추출하여 SAE가 학습하거나 변형을 가할 수 있음을 의미합니다.

10. hook_layer: 7

  • 후크 레이어: 모델의 7번째 레이어에 후크가 적용됩니다. 트랜스포머의 여러 레이어 중 7번째 레이어에서 데이터를 추출하거나 변환합니다.

11. hook_head_index: None

  • 어텐션 헤드 인덱스: 어텐션 헤드 인덱스는 사용되지 않습니다. 이는 어텐션 메커니즘이 아닌, 잔차 스트림(residual stream)에 적용되는 SAE임을 나타냅니다.

12. prepend_bos: True

  • BOS 토큰 추가: 입력 시퀀스에 시작 토큰(Beginning-of-Sequence token)을 추가합니다. 이는 시퀀스의 시작을 명시해 주며, GPT 모델이 컨텍스트를 올바르게 이해하도록 돕습니다.

13. dataset_path: 'Skylion007/openwebtext'

  • 데이터셋 경로: 'Skylion007/openwebtext'라는 데이터셋을 사용합니다. 이는 공개된 웹 텍스트 데이터를 기반으로 모델을 학습하거나 평가할 때 사용되는 경로입니다.

14. dataset_trust_remote_code: True

  • 원격 코드 신뢰: 데이터셋을 로드할 때 원격 코드(HuggingFace와 같은 리포지토리)를 신뢰하도록 설정되어 있습니다.

15. normalize_activations: 'none'

  • 활성화 정규화: 활성화 값은 정규화되지 않습니다. 이는 각 뉴런의 활성화 값이 조정되지 않고 그대로 사용된다는 것을 의미합니다.

16. dtype: 'torch.float32'

  • 데이터 타입: 32비트 부동 소수점(floating point) 형식이 사용됩니다. 이는 모델이 연산할 때 사용하는 기본 데이터 타입입니다.

17. device: 'cuda'

  • 연산 장치: 이 SAE는 GPU (CUDA)를 사용해 연산을 수행합니다. GPU를 사용하면 모델이 훨씬 더 빠르게 학습하거나 예측을 수행할 수 있습니다.

18. sae_lens_training_version: None

  • SAE Lens 훈련 버전: 훈련 시 사용된 SAE Lens의 버전이 명시되어 있지 않으며, None으로 설정되어 있습니다.

19. activation_fn_kwargs: {}

  • 활성화 함수 추가 인자: 활성화 함수에 추가적인 파라미터는 사용되지 않았습니다. 만약 'topk'와 같은 활성화 함수가 사용되었다면, 추가 인자로 k 값을 설정했을 것입니다.

20. neuronpedia_id: 'gpt2-small/7-res-jb'

  • Neuronpedia ID: 이 ID는 이 SAE의 특정 구성 요소나 레이어를 설명하는 데 사용되는 고유한 ID입니다. 이를 통해 이 모델에 대한 특정 참조를 Neuronpedia에서 찾을 수 있습니다.

21. model_from_pretrained_kwargs: {'center_writing_weights': True}

  • 사전 훈련된 모델 파라미터: 'center_writing_weights'가 True로 설정되어 있습니다. 이는 특정 가중치 설정 방식에 대한 선택입니다.

차원 관련 요약

  • 입력 차원 (d_in): 768
  • 숨겨진 차원 (d_sae): 24,576
  • 컨텍스트 크기 (context_size): 128

이 구조는 입력 차원이 768이고, 숨겨진 레이어에서 24,576개의 뉴런을 통해 더욱 고차원적인 특징을 학습하는 SAE를 기반으로 합니다.

 

다음으로, 작업할 데이터셋을 불러와야 합니다. 우리는 Pile의 샘플을 사용할 것입니다.

from datasets import load_dataset  
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path = "NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

token_dataset = tokenize_and_concatenate(
    dataset= dataset,# type: ignore
    tokenizer = model.tokenizer, # type: ignore
    streaming=False,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

만약 오류가 발생한다면 이렇게 pandas를 설치해주세요

 

pip install pandas

 

Basics: What are SAE Features?

Opening a feature dashboard on Neuronpedia

다양한 SAE 기능들을 살펴보기 전에, 기본적인 질문부터 해결해봅시다: SAE 기능이란 무엇일까요?

SAE 기능은 오토인코더가 입력 데이터에서 감지하도록 학습한 패턴이나 개념을 나타냅니다. 이러한 기능은 종종 의미론적, 구문적이거나 텍스트에서 해석 가능한 요소들에 해당하며, 활성화 공간에서 선형 방향을 나타냅니다. SAE는 모델의 특정 부분의 활성화에 대해 훈련되며, 훈련 후에는 이러한 기능들이 SAE의 숨겨진 레이어에서 활성화로 나타납니다(이는 원래의 활성화 벡터보다 훨씬 넓으며, 각 기능에 대해 하나의 숨겨진 활성화를 생성합니다). 따라서 숨겨진 활성화는 원래 모델의 활성화에서 얽히거나 중첩된 기능을 분해한 결과를 나타냅니다. 이상적으로, 이러한 활성화는 희소성을 가집니다. 즉, 주어진 입력에 대해 많은 가능성 중 일부만 실제로 활성화됩니다. 이 희소성은 해석 가능성의 용이성과 관련이 있습니다.

여기 표시된 대시보드는 단일 SAE 기능에 대한 자세한 정보를 제공합니다. (셀을 새로 고치면 더 많은 예시를 볼 수 있습니다). 그 구성 요소들을 나누어 보겠습니다:

  1. 기능 설명: 맨 위에는 자동 해석 소스에서 추출된 기능 설명이 표시됩니다.
  2. 로짓 플롯: 이 기능의 가장 강한 긍정적 및 부정적 로짓을 보여줍니다. 값은 연관성의 강도를 나타냅니다.
  3. 활성화 밀도 플롯: 이 히스토그램은 무작위로 샘플링된 데이터셋에서 이 기능의 활성화 값 분포를 보여줍니다. x축은 활성화 강도를 나타내고, y축은 빈도를 나타냅니다. 첫 번째 차트는 단순히 0이 아닌 활성화 분포를, 두 번째 플롯은 부정적 및 긍정적 로짓의 밀도를 보여줍니다.
  4. 테스트 활성화: 이 기능을 노트북이나 Neuronpedia 내에서 사용할 수 있습니다. 텍스트를 입력하면 해당 기능이 텍스트 전반에서 어떻게 활성화되는지 확인할 수 있습니다.
  5. 최상위 활성화: 플롯 아래에는 이 기능을 강하게 활성화시키는 텍스트 스니펫의 예시가 나옵니다. 각 스니펫은 활성화가 나타나는 부분이 강조 표시됩니다.

더 자세한 내용은 Towards Monosemanticity의 이 섹션을 참조하세요.

from IPython.display import IFrame

# get a random feature from the SAE
feature_idx = torch.randint(0, sae.cfg.d_sae, (1,)).item()

html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

html = get_dashboard_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=feature_idx)
IFrame(html, width=1200, height=600)

위에서 무작위로 선택된 기능에 대해, 어떤 텍스트가 해당 기능을 활성화시킬지 예측할 수 있나요? 그리고 그 이론을 테스트할 수 있을까요?

예를 들어, 만약 해당 기능이 포켓몬에 반응하는 것처럼 보인다면, Digimon(애완 몬스터가 등장하는 비슷한 게임)에서도 해당 기능이 반응하는지 테스트해보는 것은 다른 이야기를 제시할 수 있습니다.

Autointerp 다운로드 / 검색

특정 주제와 관련된 기능을 검색하고 싶다면 어떻게 할까요? 그럴 때는 설명 검색 API를 사용할 수 있습니다. SAE 기능에 대한 모든 Autointerp 설명을 다운로드한 후 이를 Pandas 데이터프레임으로 로드하면 됩니다. Neuronpedia API 문서가 여기서 유용할 것입니다: Neuronpedia API 문서.

참고: SAE Lens에 있는 모든 SAE가 Neuronpedia에 있는 것은 아니며, Neuronpedia에 있는 모든 SAE가 모든 기능에 대한 Autointerp를 갖추고 있는 것도 아닙니다. 이는 진행 중인 작업입니다.

import requests

url = "https://www.neuronpedia.org/api/explanation/export?modelId=gpt2-small&saeId=7-res-jb"
headers = {"Content-Type": "application/json"}

response = requests.get(url, headers=headers)
# convert to pandas
data = response.json()
explanations_df = pd.DataFrame(data)
# rename index to "feature"
explanations_df.rename(columns={"index": "feature"}, inplace=True)
# explanations_df["feature"] = explanations_df["feature"].astype(int)
explanations_df["description"] = explanations_df["description"].apply(lambda x: x.lower())
explanations_df

성경과 관련된 기능을 검색해봅시다.

bible_features = explanations_df.loc[explanations_df.description.str.contains(" bible")]
bible_features
# Let's get the dashboard for this feature.
html = get_dashboard_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=bible_features.feature.values[0])
IFrame(html, width=1200, height=600)

기본: SAE를 사용하여 기능 찾기

Autointerp는 기능을 찾는 데 있어 매우 비효율적인 방법입니다. 우리는 실제 프롬프트에서 모델 예측을 이해하는 데 더 관심이 있습니다, 특히 SAE를 사용하여. 이제 이 성경 구절을 완성하는 데 사용된 기능들을 확인해봅시다. 성경과 관련된 기능을 볼 수 있을까요?

from transformer_lens.utils import test_prompt

prompt = "In the beginning, God created the heavens and the"
answer = "earth"

# Show that the model can confidently predict the next token.
test_prompt(prompt, answer, model)

Tokenized prompt: ['<|endoftext|>', 'In', ' the', ' beginning', ',', ' God', ' created', ' the', ' heavens', ' and', ' the']
Tokenized answer: [' earth']
Performance on answer token:
Rank: 0 Logit: 27.64 Prob: 99.32% Token: | earth|
Top 0th token. Logit: 27.64 Prob: 99.32% Token: | earth|
Top 1th token. Logit: 22.46 Prob: 0.56% Token: | Earth|
Top 2th token. Logit: 19.20 Prob: 0.02% Token: | planets|
Top 3th token. Logit: 18.80 Prob: 0.01% Token: | moon|
Top 4th token. Logit: 18.07 Prob: 0.01% Token: | heavens|
Top 5th token. Logit: 17.67 Prob: 0.00% Token: | oceans|
Top 6th token. Logit: 17.43 Prob: 0.00% Token: | ten|
Top 7th token. Logit: 17.41 Prob: 0.00% Token: | stars|
Top 8th token. Logit: 17.38 Prob: 0.00% Token: | seas|
Top 9th token. Logit: 17.35 Prob: 0.00% Token: | four|
Ranks of the answer tokens: [(' earth', 0)]

HookedSAETransformer 사용하기

HookedSAE Transformer 클래스를 사용하여 모델을 실행하는 방법에 대한 전체 튜토리얼이 있습니다. ->
Open In Colab

여기에서는 이 클래스를 사용하여 기능을 얻는 방법을 간단히 시연하겠습니다.

#SAE는 활성화를 완벽하게 재구성하지 않기 때문에, SAE를 연결하고 모델 성능을 유지하고 싶다면 **오류 항(term)**을 사용해야 합니다. 이는 SAE가 순전파(forward pass)를 수정하는 데 사용되며, 활성화를 잘 재구성하지 못하면 출력에 영향을 미칠 수 있기 때문입니다.
#좋은 SAE는 작은 오류 항을 가지지만, 이것은 주의해야 할 부분입니다.

sae.use_error_term #만약 오류 항 사용 설정이 false로 되어 있다면, 순전파를 SAE를 사용해 수정하게 됩니다.

False

아래에서는 run_with_cache_with_saes라는 HookedSAETransformer의 함수를 사용합니다. 이 함수는 모든 캐시된 활성화(우리가 인자로 지정한 SAE의 활성화 포함)를 제공합니다. 프롬프트를 모델에 실행하면 다음과 같은 활성화 텐서를 얻게 됩니다.

# hooked SAE Transformer will enable us to get the feature activations from the SAE
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])

print([(k, v.shape) for k,v in cache.items() if "sae" in k])

# note there were 11 tokens in our prompt, the residual stream dimension is 768, and the number of SAE features is 768

[('blocks.7.hook_resid_pre.hook_sae_input', torch.Size([1, 11, 768])), ('blocks.7.hook_resid_pre.hook_sae_acts_pre', torch.Size([1, 11, 24576])), ('blocks.7.hook_resid_pre.hook_sae_acts_post', torch.Size([1, 11, 24576])), ('blocks.7.hook_resid_pre.hook_sae_recons', torch.Size([1, 11, 768])), ('blocks.7.hook_resid_pre.hook_sae_output', torch.Size([1, 11, 768]))]

더보기
 

이 출력은 SAE(Sparse Autoencoder)의 여러 단계에서 입력과 출력의 차원을 보여주고 있어요. 각 단계가 어떤 역할을 하는지 하나씩 차근차근 설명해 볼게요.

1. blocks.7.hook_resid_pre.hook_sae_input

  • 차원: torch.Size([1, 11, 768])
  • 설명: 이 값은 SAE에 들어가기 전 입력의 크기를 나타냅니다.
    • 1: 배치 크기 (한 번에 처리하는 문장의 수, 여기서는 1개 문장을 처리)
    • 11: 문장의 토큰 개수 (이 문장에는 11개의 단어 또는 토큰이 있습니다)
    • 768: 각 토큰을 768차원 벡터로 나타냅니다 (GPT-2 모델에서 일반적으로 사용되는 차원 크기).

무슨 뜻이냐면: 이 SAE에 들어가기 전에 모델은 11개의 토큰으로 이루어진 문장을 받았고, 각각의 토큰은 768차원의 벡터로 표현되고 있어요.

2. blocks.7.hook_resid_pre.hook_sae_acts_pre

  • 차원: torch.Size([1, 11, 24576])
  • 설명: SAE의 인코더를 통과한 후, 즉 인코더가 데이터를 변환한 후의 크기를 보여줍니다.
    • 24576: 인코더가 출력하는 고차원 특징 벡터의 크기입니다.

무슨 뜻이냐면: 인코더는 입력 차원인 768을 24576차원으로 크게 확장해서 더 많은 정보를 담는 벡터를 만들어내요. 이 단계는 중요한 특징들을 많이 추출하는 역할을 합니다.

3. blocks.7.hook_resid_pre.hook_sae_acts_post

  • 차원: torch.Size([1, 11, 24576])
  • 설명: 이 값은 활성화 함수(여기서는 ReLU)가 적용된 후의 벡터 크기입니다. 차원은 인코더 결과와 동일한데, 활성화 함수로 인해 양수 값만 남고 음수는 0으로 바뀝니다.

무슨 뜻이냐면: 활성화 함수(ReLU)를 통해 인코더가 추출한 특징들 중에서 유용한 값들만 남기고, 나머지는 0으로 바꾼 거예요. 이 과정을 통해 결과가 더 sparse(희소)하게 됩니다.

4. blocks.7.hook_resid_pre.hook_sae_recons

  • 차원: torch.Size([1, 11, 768])
  • 설명: 이 단계는 디코더를 통해 원래 입력 차원으로 다시 복원된 결과를 나타냅니다.
    • 24576차원에서 다시 원래 차원인 768차원으로 돌아왔습니다.

무슨 뜻이냐면: 디코더가 인코더의 결과를 바탕으로 입력 데이터를 다시 원래 차원인 768로 복원한 거예요. 이 과정에서 중요한 정보가 보존되도록 최적화됩니다.

5. blocks.7.hook_resid_pre.hook_sae_output

  • 차원: torch.Size([1, 11, 768])
  • 설명: 최종 출력 차원은 복원된 768차원입니다. 다시 입력과 같은 차원으로 돌아왔어요.

무슨 뜻이냐면: SAE를 거친 후, 결국 입력과 같은 크기의 벡터로 다시 변환되었고, 이 벡터는 중요한 정보를 더 많이 담고 있는 복원된 형태입니다.


간단 정리

  • 입력: [1,11,768][1, 11, 768] → 11개의 토큰, 각 토큰은 768차원 벡터.
  • 인코더 출력: [1,11,24576][1, 11, 24576] → 인코더가 768차원을 24576차원으로 확장.
  • 활성화 후: [1,11,24576][1, 11, 24576] → 활성화 함수(ReLU)를 통해 sparse한 특징 벡터 생성.
  • 디코더 복원: [1,11,768][1, 11, 768] → 디코더가 다시 원래 차원으로 복원.
  • 최종 출력: [1,11,768][1, 11, 768] → 입력과 같은 크기의 복원된 벡터.

이 과정에서 SAE는 중요한 정보를 추출해 더 큰 차원으로 확장했다가, 다시 입력과 같은 차원으로 복원하는 작업을 합니다.

다음으로, 프롬프트의 마지막 토큰 위치에서 SAE의 숨겨진 레이어 활성화를 시각화해 보겠습니다. 이 세로선들은 각각의 기능 활성화에 해당합니다. 또한, 활성화된 각 기능에 대한 대시보드를 시각화할 수 있는데, 이는 활성화 캐시에서 해당 기능의 위치를 인덱스로 사용하여 Neuronpedia에서 데이터를 가져옵니다. 우리는 상위 기능들만을 대상으로 이를 수행할 것입니다.

# 이제 레이어 8에서 마지막 토큰 위치에서 어떤 기능들이 활성화되었는지 살펴보겠습니다.

# 선 위에 마우스를 올리면 기능 ID를 확인할 수 있습니다.
px.line(
    cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :].cpu().numpy(),
    title="Feature activations at the final token position",
    labels={"index": "Feature", "value": "Activation"},
).show()

# let's print the top 5 features and how much they fired
vals, inds = torch.topk(cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :], 5)
for val, ind in zip(vals, inds):
    print(f"Feature {ind} fired {val:.2f}")
    html = get_dashboard_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=ind)
    display(IFrame(html, width=1200, height=300))

대비 쌍(Contrast Pairs) 기법

때로는 두 프롬프트 사이에서 어떤 기능이 다르게 활성화되는지 궁금할 수 있습니다. 결과적으로 나타나는 활성화를 비교하여 이 질문을 조사해 봅시다. 아래 프롬프트를 사용하면 로짓 예측이 상당히 달라지는 것을 확인할 수 있습니다.

from transformer_lens.utils import test_prompt

prompt = "In the beginning, God created the cat and the"
answer = "earth"

# here we see that removing the word "Heavens" is very effective at making the model no longer predict "earth".
# instead the model predicts a bunch of different animals.
# Can we work out which features fire differently which might explain this? (This is a toy example not meant to be super interesting)
test_prompt(prompt, answer, model)

Tokenized prompt: ['<|endoftext|>', 'In', ' the', ' beginning', ',', ' God', ' created', ' the', ' cat', ' and', ' the']
Tokenized answer: [' earth']
Performance on answer token:
Rank: 34 Logit: 13.42 Prob: 0.30% Token: | earth|
Top 0th token. Logit: 18.10 Prob: 32.26% Token: | dog|
Top 1th token. Logit: 17.59 Prob: 19.38% Token: | mouse|
Top 2th token. Logit: 16.10 Prob: 4.37% Token: | lamb|
Top 3th token. Logit: 15.72 Prob: 2.98% Token: | woman|
Top 4th token. Logit: 15.34 Prob: 2.03% Token: | bear|
Top 5th token. Logit: 15.22 Prob: 1.80% Token: | rabbit|
Top 6th token. Logit: 15.11 Prob: 1.63% Token: | bird|
Top 7th token. Logit: 15.03 Prob: 1.50% Token: | goat|
Top 8th token. Logit: 14.97 Prob: 1.41% Token: | fox|
Top 9th token. Logit: 14.81 Prob: 1.20% Token: | beast|
Ranks of the answer tokens: [(' earth', 34)]

여기에서 우리는 "Heavens"라는 단어를 제거하면 모델이 더 이상 "earth"를 예측하지 않게 되는 것을 볼 수 있습니다. 대신 모델은 다양한 동물들을 예측합니다.
이 차이를 설명할 수 있는, 다르게 활성화되는 기능들을 찾아낼 수 있을까요? (이것은 흥미로운 예시를 위한 것이 아닌 간단한 예시입니다.)

두 개의 활성화 벡터를 플롯해 봅시다.

prompt = ["In the beginning, God created the heavens and the", "In the beginning, God created the cat and the"]
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
print([(k, v.shape) for k,v in cache.items() if "sae" in k])

feature_activation_df = pd.DataFrame(cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :].cpu().numpy(),
                                     index = [f"feature_{i}" for i in range(sae.cfg.d_sae)],
)
feature_activation_df.columns = ["heavens_and_the"]
feature_activation_df["cat_and_the"] = cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][1, -1, :].cpu().numpy()
feature_activation_df["diff"]= feature_activation_df["heavens_and_the"] - feature_activation_df["cat_and_the"]

fig = px.line(
    feature_activation_df,
    title="Feature activations for the prompt",
    labels={"index": "Feature", "value": "Activation"},
)

# hide the x-ticks
fig.update_xaxes(showticklabels=False)
fig.show()

더보기

이 코드는 GPT-2 모델을 활용하여 특정 문장의 특징(feature)을 분석하고, 두 문장 간의 차이를 시각화하는 작업을 수행합니다. 아래에 단계별로 설명해 드릴게요.

1. 입력 문장 설정

prompt = ["In the beginning, God created the heavens and the", "In the beginning, God created the cat and the"]
  • 두 개의 문장을 prompt 리스트로 정의합니다. 첫 번째 문장은 "In the beginning, God created the heavens and the", 두 번째 문장은 "In the beginning, God created the cat and the"입니다.
  • 이 두 문장을 기반으로 각 문장의 의미를 어떻게 인코딩하는지 비교할 것입니다.

2. 모델 실행 및 캐시 획득

_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
  • model.run_with_cache_with_saes(prompt, saes=[sae]): GPT-2 모델을 **SAE(Sparse Autoencoder)**와 함께 실행하여 prompt의 결과를 계산합니다.
  • 이 함수는 모델의 각 단계에서 데이터를 저장하는 **캐시(cache)**와 함께 실행됩니다.
  • 캐시에는 모델의 특정 레이어에서 SAE가 처리한 활성화 값들이 저장됩니다.

3. 캐시에서 SAE 관련 값 추출 및 출력

print([(k, v.shape) for k,v in cache.items() if "sae" in k])
 
  • 캐시 안에서 SAE와 관련된 값들의 이름(k)과 해당 값의 차원 정보(v.shape)를 출력합니다.
  • if "sae" in k 부분은 SAE 관련 캐시 데이터만 필터링하여 출력하는 조건입니다.

4. 특징 활성화 값을 데이터프레임으로 변환

feature_activation_df = pd.DataFrame(cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :].cpu().numpy(), index = [f"feature_{i}" for i in range(sae.cfg.d_sae)], )
  • 첫 번째 문장SAE 활성화 값을 가져와 데이터프레임으로 변환합니다.
    • cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :]: 첫 번째 문장(0), 마지막 토큰(-1), 모든 특성 벡터(:)의 값을 가져옵니다. 이는 24,576차원(특징 벡터)의 활성화 값을 의미합니다.
    • .cpu().numpy(): GPU에서 처리된 텐서를 CPU로 이동한 후, 넘파이 배열로 변환합니다.
    • index = [f"feature_{i}" for i in range(sae.cfg.d_sae)]: 각 특징 벡터에 대해 특징 번호를 설정합니다. 여기서 sae.cfg.d_sae는 24,576으로 설정되어 있으므로, "feature_0", "feature_1", ..., "feature_24575"까지의 인덱스를 가집니다.

5. 두 번째 문장의 특징 활성화 값 추가

feature_activation_df.columns = ["heavens_and_the"] feature_activation_df["cat_and_the"] = cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][1, -1, :].cpu().numpy()
  • 데이터프레임의 첫 번째 열(heavens_and_the)은 첫 번째 문장의 특징 활성화 값을 가집니다.
  • 두 번째 열(cat_and_the)은 두 번째 문장의 마지막 토큰에 대한 SAE 활성화 값을 추가합니다.

6. 두 문장 간의 차이 계산

feature_activation_df["diff"]= feature_activation_df["heavens_and_the"] - feature_activation_df["cat_and_the"]
  • 첫 번째 문장과 두 번째 문장 사이의 특징 활성화 값 차이를 계산하여 diff 열에 저장합니다.
  • 이 차이를 분석함으로써 두 문장이 어떤 부분에서 차이가 나는지를 알 수 있습니다.

7. 시각화

fig = px.line( feature_activation_df, title="Feature activations for the prompt", labels={"index": "Feature", "value": "Activation"}, )
  • plotly.express.line을 사용해 특징 활성화 값의 변화를 그래프로 시각화합니다.
    • 각 특징 벡터의 활성화 값이 어떻게 변하는지 보여주는 그래프입니다.
    • labels: x축에는 각 특징 번호가, y축에는 각 특징 벡터의 활성화 값이 표시됩니다.

8. X축의 눈금 숨기기

 
fig.update_xaxes(showticklabels=False)
  • x축(특징 번호)에 눈금이 너무 많으므로, x축의 눈금을 숨깁니다. 이렇게 하면 그래프가 더 깔끔하게 보입니다.

9. 그래프 출력

fig.show()
  • 그래프를 표시합니다. 이 그래프는 각 특징의 활성화 값 변화와 두 문장 간의 차이를 보여줍니다.

요약:

  1. 두 문장을 모델에 입력하고, 각 문장의 마지막 단어에 대해 특징 벡터를 추출합니다.
  2. 이 특징 벡터들은 24,576개의 활성화 값으로 이루어져 있고, 각 특징이 어떻게 활성화되었는지 비교할 수 있습니다.
  3. 두 문장의 차이를 계산하고, 이를 그래프로 시각화하여 두 문장이 모델 내부에서 어떻게 다르게 처리되었는지 보여줍니다.

 

이 그래프는 두 문장에 대한 특징 활성화 값을 비교하고, 각 특징 벡터가 어떻게 활성화되었는지 시각화한 것입니다. 각 축과 그래프의 색을 기준으로 어떻게 해석할 수 있는지 단계별로 설명해 드릴게요.

1. X축 (Feature):

  • X축은 특징 벡터(Feature) 번호를 나타냅니다. 각 특징 벡터는 인코더를 통해 추출된 24,576개의 고차원적인 정보를 의미합니다.
  • X축의 각 점은 하나의 특정 특징 벡터를 의미합니다. 하지만 너무 많기 때문에 실제로 눈금이 표시되진 않았습니다.

2. Y축 (Activation):

  • Y축은 각 특징 벡터의 활성화 값(Activation)을 나타냅니다. 활성화 값은 특정 특징 벡터가 얼마나 강하게 활성화되었는지를 나타내며, 값이 클수록 해당 특징이 두 문장에서 더 중요하게 처리되었음을 의미합니다.

3. 세 가지 색깔:

  • 파란색(heavens_and_the): 첫 번째 문장, "In the beginning, God created the heavens and the"에 대한 특징 벡터의 활성화 값을 나타냅니다.
  • 빨간색(cat_and_the): 두 번째 문장, "In the beginning, God created the cat and the"에 대한 특징 벡터의 활성화 값을 나타냅니다.
  • 초록색(diff): 두 문장의 활성화 값 차이(파란색 - 빨간색)를 나타냅니다. 이 값이 양수면 첫 번째 문장이 더 강하게 활성화되었고, 음수면 두 번째 문장이 더 강하게 활성화되었음을 의미합니다.

4. 그래프 해석:

  • 특징 벡터 간 차이:
    • 초록색 선이 크게 위나 아래로 움직일 때, 두 문장 사이의 차이가 큽니다. 즉, 해당 특징 벡터가 두 문장에서 다르게 반응한다는 것을 의미합니다.
    • 예를 들어, 특정 특징 벡터에서 초록색 선이 크게 위로 올라가면 첫 번째 문장(heavens)이 더 중요하게 활성화된 것이고, 크게 아래로 내려가면 두 번째 문장(cat)이 더 중요하게 활성화된 것입니다.
  • 유사성:
    • 초록색 선이 0에 가까울수록 두 문장의 특징 벡터 활성화 값이 거의 같다는 의미입니다. 이는 두 문장이 해당 특징에서 거의 같은 영향을 미친다는 것을 나타냅니다.
  • 특징별 활성화 분포:
    • 활성화 값이 0에서 크게 벗어나지 않는 대부분의 특징 벡터는 두 문장에서 거의 차이가 없거나 중요하지 않게 작용한 것으로 보입니다.
    • 특정 특징에서만 눈에 띄게 활성화 값 차이가 발생하고 있습니다. 이는 문장 말미의 단어(heavens, cat)가 해당 특징에 큰 영향을 미쳤음을 시사합니다.

5. 특정 특징 벡터에서의 활성화 차이:

  • 파란색과 빨간색 막대가 서로 다른 높이를 가지고 있거나 방향이 다를 때, 두 문장은 그 특징 벡터에서 다른 반응을 보였다는 뜻입니다.
  • 예를 들어, 그래프 중앙 부분에 있는 특징 벡터에서 파란색이 크게 활성화된 부분이 있습니다. 이는 "heavens"가 이 특징 벡터에서 더 강하게 반응했음을 의미합니다.
  • 반대로, 그래프 중간과 오른쪽에서 빨간색이 더 크게 활성화된 곳은 "cat"이 해당 특징 벡터에서 더 큰 영향을 미쳤다는 것을 뜻합니다.

요약:

  • 이 그래프는 두 문장("heavens"와 "cat"이 포함된 문장)이 GPT-2 모델 내부에서 어떤 특징 벡터에 더 많이 반응했는지를 시각화한 것입니다.
  • 특정 특징 벡터에서 "heavens"가 더 중요하게 활성화되거나, 반대로 "cat"이 더 중요하게 작용한 특징 벡터를 알 수 있습니다.
  • 초록색 선을 통해 두 문장 사이에서 어떤 특징이 더 중요한지, 그리고 그 차이가 얼마나 큰지를 한눈에 볼 수 있습니다.

그래프는 문장의 특정 단어가 모델 내부에서 어떻게 처리되는지, 특히 SAE가 어떤 특징을 더 중요하게 보는지를 파악하는 데 중요한 정보를 제공합니다.

차이점이 있다는 것을 알 수 있지만, 가장 큰 차이를 보이는 기능들의 대시보드를 플롯하여 그것들이 무엇인지 살펴봅시다. 가장 큰 차이점은 이제 "동물" 기능이 활성화되었다는 것입니다.

치킨과 샐러드를 비교했는데 이것도 확실하게 비교가 되네요

# let's look at the biggest features in terms of absolute difference

diff = cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][1, -1, :].cpu() - cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :].cpu()
vals, inds = torch.topk(torch.abs(diff), 5)
for val, ind in zip(vals, inds):
    print(f"Feature {ind} had a difference of {val:.2f}")
    html = get_dashboard_html(sae_release = "gpt2-small", sae_id="7-res-jb", feature_idx=ind)
    display(IFrame(html, width=1200, height=300))

이제 고양이와 함께, 동물을 예측하는 기능이 강하게 활성화되고 있으며, "and"에 반응하며 "계곡" 및 다른 지질학적 용어를 촉진하는 기능은 더 이상 활성화되지 않는 것을 볼 수 있습니다.

더보기

이 이미지에서는 두 개의 특징(Feature)이 활성화된 상황을 비교하고, 각각의 특징이 긍정적(Positive) 또는 부정적(Negative)으로 얼마나 활성화되었는지에 대한 정보를 시각적으로 제공하고 있습니다. 각 특징 벡터가 어떻게 작동하는지를 더 잘 이해할 수 있도록 차근차근 해석해 보겠습니다.

1. Feature 12952: Difference of 7.91

  • 상위 설명: "terms related to specific animals, particularly mammals like cows, oxen, bulls, calves, tigers, lions, bears, horses, foxes, snakes, and wild boars"
    • 이 특징은 동물, 특히 포유류와 관련된 용어에서 높은 활성화를 보입니다.
  • Negative Logits:
    • 부정적으로 활성화된 단어들입니다. 여기서 quickshipAvailab, FUL, promotions, unification 같은 단어들이 부정적인 방향으로 작용하고 있습니다.
    • 이 단어들은 이 특징 벡터가 잘 활성화되지 않는 경우에 등장할 가능성이 큽니다.
  • Positive Logits:
    • 긍정적으로 활성화된 단어들입니다. goats, reptiles, nests, turtle, mammals 등의 단어들이 긍정적인 방향으로 작용하고 있습니다.
    • 이 단어들은 이 특징이 잘 활성화될 때 나타날 가능성이 큽니다. 주로 동물과 관련된 단어들이 긍정적인 활성화를 보입니다.
  • Activation Density:
    • 위쪽 그래프는 활성화된 값의 분포를 보여줍니다. 노란색 부분은 강하게 활성화된 값을 나타내고, 아래쪽 그래프는 부정적/긍정적 활성화 값의 히스토그램을 나타냅니다.
    • 이 그래프를 보면 양수 활성화 값이 다소 많으며, 특히 동물과 관련된 용어가 등장할 때 활성화가 더 강해지는 경향을 보여줍니다.

2. Feature 10195: Difference of 6.11

  • 상위 설명: "instances of the word 'and' followed by other words, particularly when multiple occurrences are close together with high activation values"
    • 이 특징은 "and"라는 단어가 다른 단어들과 연달아 나오는 경우에 활성화됩니다. 특히 "and"가 자주 등장할 때 이 특징이 강하게 활성화됩니다.
  • Negative Logits:
    • 부정적인 활성화 단어들입니다. 여기서는 advertisement, arro, edia, prosecut 같은 단어들이 부정적으로 작동하고 있습니다.
    • 이 단어들은 이 특징 벡터가 잘 활성화되지 않는 경우에 나타날 가능성이 큽니다.
  • Positive Logits:
    • 긍정적인 활성화 단어들입니다. valleys, uries, oranges, Soviets, necks 같은 단어들이 긍정적으로 작용하고 있습니다.
    • 이 단어들은 이 특징이 활성화될 때 주로 등장하며, 특히 "and"가 여러 단어와 함께 자주 나타날 때 그 영향이 큽니다.
  • Activation Density:
    • 상단 그래프는 이 특징이 강하게 활성화되는 경우와 약하게 활성화되는 경우의 분포를 보여줍니다.
    • 하단 히스토그램을 보면 양수 영역에서 더 많은 활성화가 발생하고 있으며, 이는 "and"가 자주 등장할 때 활성화되는 경향을 잘 보여줍니다.

전체적인 해석:

  1. Feature 12952:
    • 이 특징 벡터는 주로 동물, 특히 포유류와 관련된 단어에 대해 긍정적으로 활성화됩니다.
    • 동물과 관련 없는 단어들이 등장하면 부정적으로 활성화되며, 특정 동물 관련 용어에서 활성화가 강하게 발생합니다.
  2. Feature 10195:
    • 이 특징 벡터는 "and"가 자주 등장하는 경우에 활성화됩니다. 특히, "and" 뒤에 연달아 나오는 단어들에서 강한 활성화가 발생합니다.
    • "and"가 여러 번 등장하는 문장 구조에서 긍정적인 활성화가 크게 발생합니다.

이 이미지에서 보여주는 특징 벡터의 활성화는 모델이 문장 내에서 특정 단어 패턴(동물 관련 용어, "and"와 같은 연결어 구조 등)을 어떻게 인식하고 반응하는지를 시각적으로 설명합니다.

기능 대시보드 만들기 (선택 사항)

관심 있는 분들을 위해, 기능 대시보드의 구성 요소를 생성하는 방법을 보여주는 섹션을 제공합니다.

기능 대시보드가 무엇을 표시하는지 다루었지만, 플롯이 무엇을 의미하는지 완전히 이해하기 위해 더 자세히 살펴보겠습니다. 위 설명을 반복하면서 더 세부적인 정보를 제공하자면, 기본 기능 대시보드는 4개의 주요 구성 요소로 이루어져 있습니다:

  1. 기능 활성화 분포: 해당 기능이 활성화되는 토큰의 비율을 보고하며, 보통 100개에서 10,000개 중 1개의 토큰에서 활성화됩니다. 또한, 긍정적 활성화의 분포를 보여줍니다.
  2. 로짓 가중치 분포: 디코더 가중치를 언임베드에 투영한 값으로, 대략적으로 해당 기능이 어떤 토큰을 촉진하는지 알 수 있습니다. 큰 모델이나 중간 레이어에서는 그 유용성이 줄어듭니다.
  3. 로짓 가중치 분포에서 상위 10개 및 하위 10개 토큰 (긍정/부정 로짓).
  4. 최대 활성화 예시: 기능이 활성화되는 텍스트 예시들로, 보통 해당 기능이 무엇을 의미하는지 이해하는 데 가장 많은 정보를 제공합니다.

보너스 섹션: Not all Language Model Features are Linear에서 원형 부분공간 기하학 재현하기

Neuronpedia는 기능 대시보드를 호스팅하는 웹사이트로, 서버를 실행해 모델을 돌리고 기능 활성화를 확인할 수 있습니다. 이 덕분에, 기능이 실제로 활성화되어야 할 텍스트 분포에서 올바르게 활성화되는지 매우 쉽게 확인할 수 있습니다. 위의 플롯 중 일부에서 Neuronpedia에서 데이터를 다운로드했습니다.

로컬: 최대 활성화 예시 찾기

우리는 먼저 최대 활성화 예시, 즉 특정 기능으로부터 가장 높은 활성화 수준을 보여주는 프롬프트를 찾는 것으로 시작합니다. 먼저, SAE의 원래 훈련 데이터셋에서 텍스트 샘플을 스트리밍하고 이에 대한 활성화를 생성하는 기능 저장소를 준비할 것입니다.

instantiate an object to hold activations from a dataset

from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

def list_flatten(nested_list):
    return [x for y in nested_list for x in y]

# A very handy function Neel wrote to get context around a feature activation
def make_token_df(tokens, len_prefix=5, len_suffix=3, model = model):
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context = []
    prompt = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p-len_prefix):p])
            if p==tokens.shape[1]-1:
                suffix = ""
            else:
                suffix = "".join(str_tokens[b][p+1:min(tokens.shape[1]-1, p+1+len_suffix)])
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            prompt.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    # print(len(batch), len(pos), len(context), len(label))
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        prompt=prompt,
        pos=pos,
        label=label,
    ))

이제 무작위로 선택된 기능 세트에 대한 예시를 생성하겠습니다.

다음 코드는 다음 작업을 수행합니다 (무작위로 선택된 100개의 기능에 대해):

  1. 데이터셋에서 토큰을 샘플링하고, SAE가 bos(시작 토큰)으로 훈련된 경우 이를 미리 추가하고 프롬프트가 SAE에 맞는 올바른 크기인지 확인합니다.
  2. 활성화를 생성하고, 어떤 토큰에서 기능이 활성화되었는지 추적합니다.
  3. (Not all language model features are linear용) 해당 기능들이 생성한 부분 공간을 추적합니다.
  4. 적어도 하나의 기능이 활성화된 모든 프롬프트에서 모든 토큰을 포함하는 데이터프레임을 만듭니다.

참고: 이 코드는 데이터프레임 병합 과정과 캐시된 활성화를 사용하지 않고 실제로 모델을 실행해야 하기 때문에 비교적 느립니다. SAE Lens는 대시보드 생성을 위해 SAE Dashboard를 사용할 것을 공식적으로 권장합니다.

# finding max activating examples is a bit harder. To do this we need to calculate feature activations for a large number of tokens
feature_list = torch.randint(0, sae.cfg.d_sae, (100,))
examples_found = 0
all_fired_tokens = []
all_feature_acts = []
all_reconstructions = []
all_token_dfs = []

total_batches = 100
batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts
pbar = tqdm(range(total_batches))
for i in pbar:
    tokens = activation_store.get_batch_tokens()
    tokens_df = make_token_df(tokens)
    tokens_df["batch"] = i

    flat_tokens = tokens.flatten()

    _, cache = model.run_with_cache(tokens, stop_at_layer = sae.cfg.hook_layer + 1, names_filter = [sae.cfg.hook_name])
    sae_in = cache[sae.cfg.hook_name]
    feature_acts = sae.encode(sae_in).squeeze()

    feature_acts = feature_acts.flatten(0,1)
    fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0
    fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])
    reconstruction = feature_acts[fired_mask][:, feature_list] @ sae.W_dec[feature_list]

    token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
    all_token_dfs.append(token_df)
    all_feature_acts.append(feature_acts[fired_mask][:, feature_list])
    all_fired_tokens.append(fired_tokens)
    all_reconstructions.append(reconstruction)

    examples_found += len(fired_tokens)
    # print(f"Examples found: {examples_found}")
    # update description
    pbar.set_description(f"Examples found: {examples_found}")

# flatten the list of lists
all_token_dfs = pd.concat(all_token_dfs)
all_fired_tokens = list_flatten(all_fired_tokens)
all_reconstructions = torch.cat(all_reconstructions)
all_feature_acts = torch.cat(all_feature_acts)
더보기

sae_in 
torch.Size([8, 128, 768])
feature_acts 
torch.Size([8, 128, 24576])
fired_mask 
torch.Size([1024])
fired_tokens 
[' to', 'ives', '.', 'Kin', ' composed', ' after', ' practice', '\n', '�', ' guys', ' h', 'uddle', ' awhile', ' defensive', ' line', '\n', ' just', ' getting', ' started', ' Quinn', ' to', ' not', ' to', ' attend', ' OT', 'As', ' because', ' the', ' NFL', ' academic', ' rules', ' for', ' on', ' the', ' quarter', ' system', ' defensive', ' scheme', ' Fac', '\n', ' Quinn', ' I', '�', '�', ' frustration', ' (', ' scheme', '…', 'ley', ' said', '�', ' time', ' pads', ' able', ' there', ' and', ' football', ' again', '�', '\n', ' practice', '\n', ' fast', ' McKin', ' that', ' I', ' fast', 'ling', '\n', ' teammates', ' linebacker', '�', ' there', '�', '\n', '-', 'on', '-', 'ones', ' McKin', ' rush', ' and', ' rush', '\n', ' be', ' NFL', ' more', '�', ' quarterbacks', ' quarterback', ' who', ' his', ' time', ' the', ' line', ' (', 'in', ' but', ' in', ' practice', ' you', ' are', ' Matt', ' Ryan', ' NFL', ' Cam', 'ton', ')', ' Brady', ' plays', ' lined', ' up', '.', ' shoulder', '\n', ' March', ' past', 'ley', '�', ' was', ' one', ' no', ' days', ' off', ',', ' and', ' for', '.', '\n', '�', ' so', '�', ',', ' I', '�', 'm', ' just', ' as', ' hard', ' as', ' I', ' can', '\n', ' at', '\n', ' it', ' it', ' over', ' the', ' it', '�', '�', ',', ' being', ',', '�', ' here', ' and', ' go', ' as', ' hard', ' as', ' I', '<|endoftext|>', 'os', ' Police', ' driver', ' District', ' vehicle']
examples_found 
169

기능 활성화 히스토그램 생성

다음으로, 기능 활성화 히스토그램을 생성하고(위의 대시보드에서 본 것처럼), 방금 생성한 최대 활성화 예시 목록을 표시할 수 있습니다. 우리는 무작위로 선택된 기능 세트 중 첫 번째 기능(인덱스 0)에 대해서만 이 작업을 수행할 것입니다.

feature_acts_df = pd.DataFrame(all_feature_acts.detach().cpu().numpy(), columns = [f"feature_{i}" for i in feature_list])
feature_acts_df.shape
feature_idx = 0
# get non-zero activations

all_positive_acts = all_feature_acts[all_feature_acts[:, feature_idx] > 0][:, feature_idx].detach()
prop_positive_activations = 100*len(all_positive_acts) / (total_batches*batch_size_tokens)

px.histogram(
    all_positive_acts.cpu(),
    nbins=50,
    title=f"Histogram of positive activations - {prop_positive_activations:.3f}% of activations were positive",
    labels={"value": "Activation"},
    width=800,)

더보기

이 그래프는 긍정적 활성화 값의 히스토그램을 보여주고 있으며, 전체 활성화 값 중 **0.506%**가 양수라는 정보도 제공하고 있습니다. 각 항목에 대해 설명하겠습니다.

1. Y축 (count):

  • Y축양수로 활성화된 값들의 개수를 나타냅니다. 여기서 count는 활성화 값이 특정 범위 내에서 몇 번 나타났는지 빈도를 세고 있습니다.

2. X축 (Activation):

  • X축활성화 값을 나타냅니다. 각 활성화 값은 모델의 특징 벡터에서 나온 값을 의미하며, 양수일수록 그 특징이 더 강하게 활성화되었음을 뜻합니다.
  • 예를 들어, 0 근처에서 많은 활성화 값이 존재하고 있으며, 2 이상의 값으로 갈수록 빈도는 줄어듭니다.

3. 그래프의 의미:

  • 대부분의 활성화 값은 0에서 약간 양수 범위에 분포하고 있으며, 강한 활성화(예: 6 이상)는 매우 드뭅니다.
  • 강하게 활성화되는 특징은 상대적으로 적지만, 존재하는 특징이 있을 때 매우 강하게 반응하는 것을 알 수 있습니다.

4. 0.506%의 활성화 값이 양수:

  • 전체 활성화 값 중에서 **0.506%**만이 양수였다는 것을 보여줍니다. 이는 활성화 값의 대부분이 음수이거나 0에 가깝다는 것을 의미합니다. 양수인 활성화는 매우 드문 현상이며, 중요한 특정 패턴에서만 발생하는 것으로 해석할 수 있습니다.

해석 요약:

이 히스토그램은 해당 특징이 특정 상황에서만 양수로 활성화된다는 것을 나타냅니다. 대부분의 경우에는 양수 활성화가 잘 일어나지 않지만, 특정한 상황에서 강하게 반응할 수 있습니다. 활성화 값이 2를 넘는 경우는 상대적으로 드물며, 매우 강한 활성화가 발생하는 사례는 극히 소수에 불과합니다.

top_10_activations = feature_acts_df.sort_values(f"feature_{feature_list[0]}", ascending=False).head(10)
all_token_dfs.iloc[top_10_activations.index] # TODO: double check this is working correctly

 

상위 10개의 로짓 가중치 구하기

마지막 단계로, 상위 10개의 로짓 가중치를 생성하겠습니다. 즉, 우리의 기능 세트에서 각 기능이 가장 강하게 촉진하는 토큰들을 확인할 것입니다.

참고로, 레이어 노름(Layer Norm)을 적용하는 것이 중요합니다. 기본적으로 SAE Lens는 레이어 노름을 적용한 상태로 Transformer를 로드하지만, 때때로 GPU 메모리를 절약하기 위해 전처리를 비활성화하는 경우가 있으며, 이는 로짓 가중치 히스토그램에 약간의 영향을 미칠 수 있습니다.

print(f"Shape of the decoder weights {sae.W_dec.shape})")
print(f"Shape of the model unembed {model.W_U.shape}")
projection_matrix = sae.W_dec @ model.W_U
print(f"Shape of the projection matrix {projection_matrix.shape}")

# then we take the top_k tokens per feature and decode them
top_k = 10
# let's do this for 100 random features
_, top_k_tokens = torch.topk(projection_matrix[feature_list], top_k, dim=1)


feature_df = pd.DataFrame(top_k_tokens.cpu().numpy(), index = [f"feature_{i}" for i in feature_list]).T
feature_df.index = [f"token_{i}" for i in range(top_k)]
feature_df.applymap(lambda x: model.tokenizer.decode(x))

Shape of the decoder weights torch.Size([24576, 768]))
Shape of the model unembed torch.Size([768, 50257])
Shape of the projection matrix torch.Size([24576, 50257])
/tmp/ipykernel_517860/4260783916.py:14: FutureWarning:

DataFrame.applymap has been deprecated. Use DataFrame.map instead.

종합하기: Neuronpedia 대시보드와 비교

이것을 Neuronpedia에서 가져온 대시보드 데이터와 비교하면 어떻게 될까요? 한 번 살펴봅시다.

html = get_dashboard_html(sae_release = "gpt2-small", sae_id=f"{sae.cfg.hook_layer}-res-jb", feature_idx=feature_list[0])
IFrame(html, width=1200, height=600)

복제된 것처럼 보입니다! 이제 대시보드 값이 어떻게 생성되는지 알 수 있습니다.

선택 사항: 동시 발생 네트워크 및 환원 불가능한 부분 공간

우리가 방금 작성한 코드는 "Not All Language Model Features are Linear"에서 일부 분석을 재현하는 데 필요한 코드와 매우 유사하므로, 아래에서는 그들의 멋진 원형 표현을 다시 생성하는 방법을 보여줍니다 (예: 요일과 같은 관련 기능들 간의 기하학적 관계를 시각화).

# only valid for res-jb resid_pre 7. 
# Josh Engel's emailed us these lists. 
day_of_the_week_features = [2592, 4445, 4663, 4733, 6531, 8179, 9566, 20927, 24185]
# months_of_the_year = [3977, 4140, 5993, 7299, 9104, 9401, 10449, 11196, 12661, 14715, 17068, 17528, 19589, 21033, 22043, 23304]
# years_of_10th_century = [1052, 2753, 4427, 6382, 8314, 9576, 9606, 13551, 19734, 20349]

feature_list = day_of_the_week_features

examples_found = 0
all_fired_tokens = []
all_feature_acts = []
all_reconstructions = []
all_token_dfs = []

total_batches = 100
batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts
pbar = tqdm(range(total_batches))
for i in pbar:
    tokens = activation_store.get_batch_tokens()
    tokens_df = make_token_df(tokens)
    tokens_df["batch"] = i

    flat_tokens = tokens.flatten()

    _, cache = model.run_with_cache(tokens, stop_at_layer = sae.cfg.hook_layer + 1, names_filter = [sae.cfg.hook_name])
    sae_in = cache[sae.cfg.hook_name]
    feature_acts = sae.encode(sae_in).squeeze()

    feature_acts = feature_acts.flatten(0,1)
    fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0
    fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])
    reconstruction = feature_acts[fired_mask][:, feature_list] @ sae.W_dec[feature_list]

    token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
    all_token_dfs.append(token_df)
    all_feature_acts.append(feature_acts[fired_mask][:, feature_list])
    all_fired_tokens.append(fired_tokens)
    all_reconstructions.append(reconstruction)

    examples_found += len(fired_tokens)
    # print(f"Examples found: {examples_found}")
    # update description
    pbar.set_description(f"Examples found: {examples_found}")

# flatten the list of lists
all_token_dfs = pd.concat(all_token_dfs)
all_fired_tokens = list_flatten(all_fired_tokens)
all_reconstructions = torch.cat(all_reconstructions)
all_feature_acts = torch.cat(all_feature_acts)
# Using PCA, we can see that these features do indeed lie in a circle!
# do PCA on reconstructions
from sklearn.decomposition import PCA
import plotly.express as px 

pca = PCA(n_components=3)
pca_embedding = pca.fit_transform(all_reconstructions.detach().cpu().numpy())

pca_df = pd.DataFrame(pca_embedding, columns=["PC1", "PC2", "PC3"])
pca_df["tokens"] = all_fired_tokens
pca_df["context"] = all_token_dfs.context.values


px.scatter(
    pca_df, x="PC2", y="PC3",
    hover_data=["context"],
    hover_name="tokens",
    height = 800,
    width = 1200,
    color = "tokens",
    title = "PCA Subspace Reconstructions",
).show()

요일 순서가 올바르게 유지되는 원형 부분 공간을 볼 수 있을 것입니다.

기본 사항: SAE 기능에 대한 개입

기능 조정(Feature Steering)

기능을 찾은 후에 재미있고 때로는 유용한 작업 중 하나는 이를 사용해 모델을 조정하는 것입니다. 이를 위해, 텍스트 집합에서 특정 기능의 최대 활성화를 찾고(위의 활성화 저장소 사용), 이를 기본 스케일로 사용합니다. 그런 다음 이 값을 디코더 가중치에서 추출된 해당 기능을 나타내는 벡터에 곱하고, 마지막으로 우리가 제어하는 매개변수에 곱합니다. 이를 조정하여 텍스트에 미치는 영향을 확인할 수 있습니다. 아래에서는 종종 종교적이거나 철학적인 문장에서 활성화되는 기능(20115)으로 모델을 조정해 보겠습니다. 참고로, 가끔 조정이 GPT-2를 루프에 빠뜨릴 수 있으므로, 여러 번 실행해 보는 것이 좋습니다.

from tqdm import tqdm
from functools import partial 

def find_max_activation(model, sae, activation_store, feature_idx, num_batches=100):
    '''
    Find the maximum activation for a given feature index. This is useful for 
    calibrating the right amount of the feature to add.
    '''
    max_activation = 0.0

    pbar = tqdm(range(num_batches))
    for _ in pbar:
        tokens = activation_store.get_batch_tokens()

        _, cache = model.run_with_cache(
            tokens, 
            stop_at_layer=sae.cfg.hook_layer + 1, 
            names_filter=[sae.cfg.hook_name]
        )
        sae_in = cache[sae.cfg.hook_name]
        feature_acts = sae.encode(sae_in).squeeze()

        feature_acts = feature_acts.flatten(0, 1)
        batch_max_activation = feature_acts[:, feature_idx].max().item()
        max_activation = max(max_activation, batch_max_activation)

        pbar.set_description(f"Max activation: {max_activation:.4f}")

    return max_activation

def steering(activations, hook, steering_strength=1.0, steering_vector=None, max_act=1.0):
    # Note if the feature fires anyway, we'd be adding to that here.
    return activations + max_act * steering_strength * steering_vector

def generate_with_steering(model, sae, prompt, steering_feature, max_act, steering_strength=1.0, max_new_tokens=95):
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    steering_vector = sae.W_dec[steering_feature].to(model.cfg.device)

    steering_hook = partial(
        steering,
        steering_vector=steering_vector,
        steering_strength=steering_strength,
        max_act=max_act
    )

    # standard transformerlens syntax for a hook context for generation
    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos = False if device == "mps" else True,
            prepend_bos = sae.cfg.prepend_bos,
        )

    return model.tokenizer.decode(output[0])

# Choose a feature to steer
steering_feature = steering_feature = 20115  # Choose a feature to steer towards

# Find the maximum activation for this feature
max_act = find_max_activation(model, sae, activation_store, steering_feature)
print(f"Maximum activation for feature {steering_feature}: {max_act:.4f}")

# note we could also get the max activation from Neuronpedia (https://www.neuronpedia.org/api-doc#tag/lookup/GET/api/feature/{modelId}/{layer}/{index})

# Generate text without steering for comparison
prompt = "Once upon a time"
normal_text = model.generate(
    prompt,
    max_new_tokens=95, 
    stop_at_eos = False if device == "mps" else True,
    prepend_bos = sae.cfg.prepend_bos,
)

print("\nNormal text (without steering):")
print(normal_text)

# Generate text with steering
steered_text = generate_with_steering(model, sae, prompt, steering_feature, max_act, steering_strength=2.0)
print("Steered text:")
print(steered_text)

Steered text:
<|endoftext|>Once upon a time, a woman had been created.

A woman, it seemed, was born.

The woman who had been born into the world, the woman who had been born into the world, was born into the world.

She was born into the world, the woman who had been born into the world, was born into the world.

The woman who had been born into the world, the woman who had been born into the world, was born

# Experiment with different steering strengths
print("\nExperimenting with different steering strengths:")
for strength in [-4.0, -2.0, 0.5, 2.0, 4.0]:
    steered_text = generate_with_steering(model, sae, prompt, steering_feature, max_act, steering_strength=strength)
    print(f"\nSteering strength {strength}:")
    print(steered_text)

<|endoftext|>Once upon a time the gods ruled the skies,

And if you want to know,

You must see

The heaven of the gods

And the earth of the gods

And the sea of the gods

And the noblest

I am the king of the gods

And the might of the gods

And the kingdom of the gods

And the kings of the gods

And the prince of the gods

And the

더보기


Experimenting with different steering strengths:
100%|██████████| 95/95 [00:01<00:00, 59.34it/s]

Steering strength -4.0:
<|endoftext|>Once upon a time, there was a situation where a young man had the idea of making a show about the safety of children. He had a different idea. The idea was to teach the kids about the dangers of riding a bicycle, which is an American bicycle safety standard.

The kids were very concerned about the safety of their bicycles. The main reason was that they were riding a bicycle that was not a lot of traction, and that they were getting in and out of their bicycles pretty
100%|██████████| 95/95 [00:01<00:00, 60.29it/s]

Steering strength -2.0:
<|endoftext|>Once upon a time, the price of a car had a huge impact on the marketplace. In the early days of the car, you could buy a new car for a little over $100,000, and still make a good profit. Today, it's around $100,000, and a lot of people still make a decent profit.

But you can't really tell how much the price of a car changed. If you go back and look at the charts, the price of
100%|██████████| 95/95 [00:01<00:00, 60.28it/s]

Steering strength 0.5:
<|endoftext|>Once upon a time, the man that we know as Chris Brown was a little kid. His father was a very good man. His mother was a very good woman. He was the kind of kid who would always tell you, "I'm going to be your dad." That's how he always looked at me.

In that way, he was a very special kid. He was kind of like a bad boy. He was always looking for something to do and wanting to do it
100%|██████████| 95/95 [00:01<00:00, 60.21it/s]

Steering strength 2.0:
<|endoftext|>Once upon a time, the universe was a complex and mysterious place. It was a place where the very things that gave it meaning were lost, forgotten, and buried in the depths of the deep. And now, with the dawn of the new millennium, we find ourselves in a place where we no longer have to be concerned with the past or the future.

But now the universe is changing, and now the universe is changing in ways that we no longer want to be. We are
100%|██████████| 95/95 [00:01<00:00, 60.16it/s]

Steering strength 4.0:
<|endoftext|>Once upon a time there were no kingdoms; they were like one another. But now, as they are, there are none. For they have no god, but only God.

And now they are like one another, and they are as one another:

And now they are like one another:

And now they are like one another:

For they shall not come to know God, but they shall know neither God nor thy God, neither thy God, nor

 

이 출력은 "steering strength"라는 매개변수를 변경하여 GPT-2 모델이 생성한 텍스트를 관찰하는 실험의 결과입니다. Steering strength는 모델이 특정 방향으로 텍스트를 조작하거나 유도할 수 있는 강도를 나타냅니다. 실험 결과는 이 강도를 다양하게 변경하면서 출력된 텍스트가 어떻게 달라지는지 보여줍니다.

세부 설명:

1. Steering Strength -4.0:

  • 내용:이 텍스트는 어린이의 자전거 안전에 관한 내용을 다루고 있습니다. 비교적 구체적이고 일상적인 주제를 다루고 있습니다. 내용이 구체적이고 안전, 규칙 등 실제적인 주제를 이야기하는 방식입니다.
  • Once upon a time, there was a situation where a young man had the idea of making a show about the safety of children. He had a different idea. The idea was to teach the kids about the dangers of riding a bicycle, which is an American bicycle safety standard.
     
  • 해석: Steering strength -4.0은 모델을 특정 방향으로 강하게 유도한 경우입니다. 이 값이 음수일 때, 모델이 일상적이고 구체적인 주제로 집중하게 되는 경향이 있습니다. 이 경우 자전거 안전과 같은 일상적인 주제를 중심으로 이야기가 전개됩니다.

2. Steering Strength -2.0:

  • 내용:이 텍스트는 자동차 시장과 관련된 경제적 내용을 다루고 있습니다. 자동차의 가격 변동과 수익에 대한 이야기가 진행되고 있습니다.
  •  
    Once upon a time, the price of a car had a huge impact on the marketplace. In the early days of the car, you could buy a new car for a little over $100,000, and still make a good profit. Today, it's around $100,000, and a lot of people still make a decent profit.
  • 해석: Steering strength -2.0은 모델이 상대적으로 현실적이고 구체적인 주제에 대해 생성하는 경향이 있지만, -4.0보다 강도가 낮아 더 자유롭습니다. 경제와 관련된 정보나 가격 변동과 같은 좀 더 일반적이고 폭넓은 주제로 확장되었습니다.

3. Steering Strength 0.5:

  • 내용:여기서는 가수 Chris Brown의 어린 시절과 가족 관계에 대한 이야기가 등장합니다. 감정적이고 개인적인 이야기로 내용이 전개됩니다.
  •  
    Once upon a time, the man that we know as Chris Brown was a little kid. His father was a very good man. His mother was a very good woman. He was the kind of kid who would always tell you, "I'm going to be your dad." That's how he always looked at me.
  • 해석: Steering strength 0.5중립적인 강도로 설정되었습니다. 이 경우 모델은 보다 자연스러운 이야기를 생성하는 경향이 있습니다. 특정한 주제에 강하게 유도되지 않고, 비교적 자유롭게 이야기의 흐름을 유지하며 구체적인 인물 이야기를 만들고 있습니다.

4. Steering Strength 2.0:

  • 내용:이 텍스트는 우주의 신비와 변화를 설명하는 좀 더 철학적이고 추상적인 주제를 다룹니다.
  •  
    Once upon a time, the universe was a complex and mysterious place. It was a place where the very things that gave it meaning were lost, forgotten, and buried in the depths of the deep.
  • 해석: Steering strength 2.0은 모델이 추상적이고 철학적인 주제로 향하도록 조정한 결과입니다. 강도가 양수로 높아지면서 구체적인 이야기보다는 보다 광범위하고 개념적인 주제로 전환됩니다. 우주와 같은 추상적인 주제를 다루고 있으며, 현실적인 주제보다 상상력이나 철학적 내용이 더 강조됩니다.

5. Steering Strength 4.0:

  • 내용:이 텍스트는 종교적이고 상징적인 이야기를 다루고 있습니다. 하나님과 왕국에 대한 내용으로, 상당히 추상적이고 심오한 주제입니다.
  •  
    Once upon a time there were no kingdoms; they were like one another. But now, as they are, there are none. For they have no god, but only God.
  • 해석: Steering strength 4.0은 매우 강한 양의 유도로, 매우 추상적이고 상징적인 주제를 다루는 결과를 보여줍니다. 종교적이거나 신화적인 내용을 다루고 있으며, 구체적인 현실보다는 고차원적이고 형이상학적인 개념이 등장합니다.

전반적인 해석:

  • Steering Strength는 모델이 텍스트를 생성할 때 어느 정도 방향성을 가지도록 조작하는 매개변수입니다.
    • 음수의 Steering Strength: 모델이 구체적이고 현실적인 주제를 다루도록 유도됩니다. -4.0에서는 자전거 안전과 같은 매우 구체적이고 실생활에서 흔히 볼 수 있는 이야기가 나왔고, -2.0에서는 자동차 시장과 같은 현실적인 경제적 주제가 나왔습니다.
    • 0.5 (중립적인 Steering Strength): 비교적 일상적인 이야기가 나옵니다. 특정 주제에 너무 치우치지 않고 자연스럽게 이야기를 전개하는 경향이 있습니다.
    • 양수의 Steering Strength: 모델이 추상적이고 철학적인 주제를 다루도록 유도됩니다. 2.0에서는 우주와 관련된 철학적 주제가, 4.0에서는 종교적이고 상징적인 이야기가 나옵니다.

정리:

  • Steering Strength가 음수일수록 모델은 현실적이고 구체적인 이야기를, 양수일수록 추상적이고 철학적인 이야기를 생성하는 경향이 있습니다.
  • 값의 크기가 커질수록 더 극단적인 방향으로 텍스트가 생성됩니다.

우리는 Neuronpedia API를 통해서도 이 작업을 할 수 있으며, 여기 웹사이트에서 할 수도 있습니다. 아래 예시는 매우 특정한 기능을 사용해 몇 개의 토큰만 조정합니다. 이 기능이 무엇을 하고 있는지 알아낼 수 있나요?

import requests
import numpy as np

url = "https://www.neuronpedia.org/api/steer"

payload = {
    # "prompt": "A knight in shining",
    # "prompt": "He had to fight back in self-", 
    "prompt": "In the middle of the universe is the galactic",
    # "prompt": "Oh no. We're running on empty. Its time to fill up the car with",
    # "prompt": "Sure, I'm happy to pay. I don't have any cash on me but let me write you a",
    "modelId": "gpt2-small",
    "features": [
        {
            "modelId": "gpt2-small",
            "layer": "7-res-jb",
            "index": 6770,
            "strength": 8
        }
    ],
    "temperature": 0.2,
    "n_tokens": 2,
    "freq_penalty": 1,
    "seed": np.random.randint(100),
    "strength_multiplier": 4
}
headers = {"Content-Type": "application/json"}

response = requests.post(url, json=payload, headers=headers)

print(response.json())

{'STEERED': 'In the middle of the universe is the galactic centre of', 'DEFAULT': 'In the middle of the universe is the galactic center of', 'id': 'cm0nyvdho0007xoj0f2c5uien', 'shareUrl': 'https://www.neuronpedia.org/steer/cm0nyvdho0007xoj0f2c5uien', 'limit': '57'}

import requests
import numpy as np
url = "https://www.neuronpedia.org/api/steer"

payload = {
    "prompt": "I wrote a letter to my girlfiend. It said \"",
    "modelId": "gpt2-small",
    "features": [
        {
            "modelId": "gpt2-small",
            "layer": "7-res-jb",
            "index": 20115,
            "strength": 4
        }
    ],
    "temperature": 0.7,
    "n_tokens": 120,
    "freq_penalty": 1,
    "seed": np.random.randint(100),
    "strength_multiplier": 4
}
headers = {"Content-Type": "application/json"}

response = requests.post(url, json=payload, headers=headers)

print(response.json())

{'STEERED': 'I wrote a letter to my girlfiend. It said "Sarah, I love you" and that I will always love you and not know her. My life was saved by the strength of the light that binds to her.\n\nI am so thankful for you, Sarah, who gave me so much joy and joy-that I have accepted all your loving hearts that God has given me.\n\nPlease forgive me for what I did and what my faith will do to you / through this time / that which made me man/woman / in Jesus Christ; or let no one have sinned against my name with the sword of God / shall it be', 'DEFAULT': 'I wrote a letter to my girlfiend. It said "Sarah, I love you and I really, really want your help with the development of the game." After he found out that my letter was being used to get me into other projects, he sent me a couple of copies of the game to read through. The first time I ever saw it was during an interview for E3 2011 at PAX East. I was very excited and surprised by what he had done with his top-secret game!\n\nWhen we were talking about making this game, we both knew that if we were going to make it, the best way for us to do it was', 'id': 'clyremdaj000l2k5lxtejxnir', 'shareUrl': 'https://www.neuronpedia.org/steer/clyremdaj000l2k5lxtejxnir', 'limit': '56'}

기능 제거(Feature Ablation)

기능 제거도 살펴볼 가치가 있습니다. 기능 제거는 기능 조정의 특별한 경우로, 기능의 값을 항상 0으로 설정하는 방식입니다.

여기서 우리는 다음을 수행합니다:

  1. 텍스트 생성을 하지 않고, 더 미묘한 차이를 얻기 위해 테스트 프롬프트를 사용합니다.
  2. SAE 기능 활성화에 후크를 연결합니다.
  3. 모든 위치에서 특정 기능의 값을 0으로 설정합니다(기본 기능이 마지막 위치에서 활성화된다는 것을 알고 있습니다).
  4. SAE가 포착하지 못한 정보(오류 항)를 포함할 때, 이 제거가 더 효과적인지 또는 덜 효과적인지 확인합니다.

참고로, 히드라 효과의 존재는 기능 제거 실험에 대한 논리를 어렵게 만들 수 있습니다.

from transformer_lens.utils import test_prompt
from functools import partial

def test_prompt_with_ablation(model, sae, prompt, answer, ablation_features):

    def ablate_feature_hook(feature_activations, hook, feature_ids, position = None):

        if position is None:
            feature_activations[:,:,feature_ids] = 0
        else:
            feature_activations[:,position,feature_ids] = 0

        return feature_activations

    ablation_hook = partial(ablate_feature_hook, feature_ids = ablation_features)

    model.add_sae(sae)
    hook_point = sae.cfg.hook_name + '.hook_sae_acts_post'
    model.add_hook(hook_point, ablation_hook, "fwd")

    test_prompt(prompt, answer, model)

    model.reset_hooks()
    model.reset_saes()

# Example usage in a notebook:

# Assume model and sae are already defined

# Choose a feature to ablate

model.reset_hooks(including_permanent=True)
prompt = "In the beginning, God created the heavens and the"
answer = "earth"
test_prompt(prompt, answer, model)


# Generate text with feature ablation
print("Test Prompt with feature ablation and no error term")
ablation_feature = 16873  # Replace with any feature index you're interested in. We use the religion feature
sae.use_error_term = False
test_prompt_with_ablation(model, sae, prompt, answer, ablation_feature)

print("Test Prompt with feature ablation and error term")
ablation_feature = 16873  # Replace with any feature index you're interested in. We use the religion feature
sae.use_error_term = True
test_prompt_with_ablation(model, sae, prompt, answer, ablation_feature)
더보기

이 코드는 GPT-2 모델에 Sparse Autoencoder (SAE)를 추가하고, 특정 특징(feature)을 "제거(ablating)"한 상태에서 프롬프트에 대해 모델이 어떻게 반응하는지를 실험하는 과정을 다루고 있습니다. Ablation은 모델의 특정 부분, 즉 여기서는 특정 특징(특정 뉴런들)을 끄거나 비활성화하는 방식으로, 모델이 해당 특징 없이도 어떻게 작동하는지를 관찰하는 기법입니다.

주요 목적

이 코드는 특정 특징(feature)을 없애거나 조정한 상태에서 모델이 텍스트를 어떻게 생성하는지 살펴보는 실험입니다. 특히, 종교와 같은 특정 주제를 다루는 뉴런(특징)을 없애면 모델의 출력이 어떻게 달라지는지를 테스트합니다.

코드 세부 설명

1. ablate_feature_hook 함수

def ablate_feature_hook(feature_activations, hook, feature_ids, position=None):
    if position is None:
        feature_activations[:, :, feature_ids] = 0
    else:
        feature_activations[:, position, feature_ids] = 0

    return feature_activations
  • 설명:
    • 이 함수는 특정 특징(feature)을 비활성화하는 역할을 합니다.
    • feature_activations: SAE가 인코딩한 각 토큰의 활성화 값(activation values)을 의미합니다.
    • feature_ids: 비활성화하려는 특정 특징(feature)의 인덱스입니다.
    • position: 비활성화할 위치를 의미합니다. 만약 특정 위치를 지정하지 않으면, 모든 토큰에서 해당 특징을 비활성화합니다.
  • 작동 방식:
    • 특징(특정 뉴런)의 활성화 값을 0으로 만들어 해당 뉴런이 더 이상 기여하지 않도록 만듭니다.
    • 즉, 특정 뉴런을 없앰으로써 모델이 해당 특징을 고려하지 않고 텍스트를 생성하게 합니다.

2. test_prompt_with_ablation 함수

def test_prompt_with_ablation(model, sae, prompt, answer, ablation_features):
    ablation_hook = partial(ablate_feature_hook, feature_ids=ablation_features)

    model.add_sae(sae)
    hook_point = sae.cfg.hook_name + '.hook_sae_acts_post'
    model.add_hook(hook_point, ablation_hook, "fwd")

    test_prompt(prompt, answer, model)

    model.reset_hooks()
    model.reset_saes()
  • 설명:
    • 이 함수는 특정 특징을 제거한 상태에서 모델이 프롬프트에 어떻게 반응하는지 테스트하는 함수입니다.
  • 세부 설명:
    1. ablation_hook는 ablate_feature_hook을 partial로 감싼 것으로, 특정 특징을 비활성화할 후크(hook)를 생성합니다. 여기서 feature_ids는 비활성화할 특징의 인덱스입니다.
    2. SAE 추가: model.add_sae(sae)는 모델에 SAE(Sparse Autoencoder)를 추가합니다. SAE는 모델이 텍스트를 처리할 때 중간에 특정 뉴런(특징)의 활성화 값을 수정할 수 있는 기능을 제공합니다.
    3. 후크 추가: model.add_hook은 모델의 특정 부분에 후크(hook)를 추가하여, 중간 단계에서 활성화 값을 수정하는 과정을 수행합니다.
      • hook_point는 SAE의 활성화 값 중 hook_sae_acts_post라는 위치에서 수정이 이루어지도록 설정합니다.
      • 이 후크는 특정 뉴런(특징)의 활성화 값을 0으로 만들어 해당 뉴런이 모델 예측에 영향을 주지 않게 만듭니다.
    4. test_prompt(prompt, answer, model): 특정 프롬프트와 정답을 가지고 비활성화된 상태에서 모델이 어떻게 반응하는지 테스트합니다.
    5. 후크 및 SAE 초기화: model.reset_hooks()와 model.reset_saes()는 실험 후에 모델에 추가된 후크와 SAE를 초기 상태로 되돌립니다.

3. 예시 사용법

model.reset_hooks(including_permanent=True)
prompt = "In the beginning, God created the heavens and the"
answer = "earth"
test_prompt(prompt, answer, model)

# Generate text with feature ablation
print("Test Prompt with feature ablation and no error term")
ablation_feature = 16873  # Replace with any feature index you're interested in. We use the religion feature
sae.use_error_term = False
test_prompt_with_ablation(model, sae, prompt, answer, ablation_feature)

print("Test Prompt with feature ablation and error term")
ablation_feature = 16873  # Replace with any feature index you're interested in. We use the religion feature
sae.use_error_term = True
test_prompt_with_ablation(model, sae, prompt, answer, ablation_feature)
  • 설명:
    1. 후크 초기화: model.reset_hooks()를 통해 모델에 설정된 모든 후크를 초기화합니다. 실험을 시작하기 전에 깨끗한 상태로 만듭니다.
    2. 기본 텍스트 생성: prompt는 "In the beginning, God created the heavens and the"라는 문장이 주어졌고, 예측되는 단어는 "earth"입니다. 이 프롬프트에 대해 모델이 기본적으로 어떻게 예측하는지를 먼저 확인합니다.
    3. 특징 제거 후 텍스트 생성:
      • 특징 16873 제거: ablation_feature = 16873은 모델이 특정 종교적인 특징을 비활성화하는 실험입니다. 이 특징을 비활성화한 상태에서 모델이 프롬프트에 어떻게 반응하는지 확인합니다.
      • sae.use_error_term = False: 에러 항(term)을 사용하지 않고 실험을 진행합니다. 이 값이 False일 때는 모델이 약간 더 단순한 방식으로 특징을 수정합니다.
      • 다시 test_prompt_with_ablation을 호출하여 이 특징이 제거된 상태에서 모델이 프롬프트에 어떻게 반응하는지 확인합니다.
    4. 에러 항 사용:
      • sae.use_error_term = True: 이번에는 에러 항을 사용하여 특징을 수정합니다. 에러 항은 특징을 제거할 때 더 정밀한 조정이 가능하게 합니다.
      • 동일한 프롬프트에 대해 특정 특징을 제거한 상태에서 모델이 어떻게 반응하는지 테스트합니다.

주요 개념 요약:

  1. 특징 비활성화(Feature Ablation):
    • 특정 뉴런(특징)을 비활성화하여 모델이 해당 특징 없이도 텍스트를 예측하거나 생성하는 방식을 관찰하는 실험입니다. 이 실험을 통해 모델이 어떤 특징에 의존하는지 확인할 수 있습니다.
  2. SAE (Sparse Autoencoder):
    • SAE는 모델의 중간 활성화 값에 **특정 뉴런(특징)**을 수정하거나 제거할 수 있는 기능을 제공합니다. 이 코드는 특정 뉴런의 활성화 값을 0으로 설정하여 모델이 해당 뉴런 없이 텍스트를 생성하도록 만듭니다.
  3. 에러 항 사용:
    • 에러 항을 사용함으로써 모델의 뉴런 비활성화 작업을 더 정교하게 조정할 수 있습니다. 이 항을 사용하면 보다 정밀한 비활성화가 가능해집니다.

이 실험을 통해 특정 주제나 패턴에 해당하는 뉴런을 비활성화하면 모델의 출력이 어떻게 변화하는지를 살펴볼 수 있으며, 이는 모델의 예측 과정에서 특정 뉴런이 어떻게 중요한 역할을 하는지에 대한 통찰을 제공합니다.

 

후크(hook)는 모델의 특정 지점에서 중간 처리 과정을 가로채거나 수정할 수 있는 메커니즘입니다. 즉, 모델이 텍스트를 생성하는 과정의 특정 단계에서 활성화 값을 수정하거나, 원하는 방식으로 간섭할 수 있게 해주는 도구라고 생각하면 됩니다.

이 코드에서 후크가 하는 역할은 다음과 같습니다:

1. 중간 활성화 값 조정

  • 모델이 텍스트를 처리하면서, 각 레이어에서 활성화 값(특정 뉴런의 출력값)을 계산합니다.
  • 후크는 특정 레이어나 단계에서 이 활성화 값을 가로채거나 수정할 수 있습니다.
  • 코드에서는 hook_sae_acts_post라는 지점에서 후크를 추가합니다. 이 지점은 SAE가 특정 뉴런(특징)의 활성화 값을 계산한 이후에 해당합니다.

2. 특정 뉴런의 비활성화 (Ablation)

  • 후크는 활성화 값에 접근하여, 특정 뉴런(특징)의 값을 0으로 설정하는 방식으로 비활성화(ablating)합니다.
  • 이 비활성화는 특정 뉴런이 모델의 최종 출력에 영향을 주지 않도록 만듭니다.
  • 예를 들어, 코드에서 feature_ids에 지정된 뉴런(특징)을 ablation_hook으로 0으로 설정하면, 해당 뉴런이 "꺼져서" 모델이 해당 특징을 더 이상 사용하지 않게 됩니다.

3. 실험 도구로서의 역할

  • 후크는 모델이 특정 뉴런 없이도 제대로 작동하는지, 또는 그 뉴런이 모델의 예측에 얼마나 중요한 역할을 하는지를 실험하는 데 사용됩니다.
  • 이 과정은 특정 주제나 패턴에 해당하는 뉴런이 모델의 예측에 미치는 영향을 분석할 수 있게 합니다.

예를 들어, 종교와 관련된 뉴런(16873번)을 비활성화했을 때 모델이 어떤 반응을 보이는지 관찰할 수 있습니다. 만약 이 후크로 인해 텍스트 생성 방식이 크게 변한다면, 해당 뉴런이 모델의 예측에서 매우 중요한 역할을 했다고 볼 수 있습니다.

요약:

  • 후크는 모델의 특정 지점에서 활성화 값을 가로채고 수정할 수 있는 메커니즘입니다.
  • 이 코드는 후크를 이용해 특정 뉴런(특징)을 비활성화하는 실험을 수행하고, 모델이 해당 뉴런 없이도 텍스트를 생성하는 과정을 관찰합니다.
  • 후크를 통해 모델의 내부 동작을 실시간으로 수정하거나 분석할 수 있으며, 이는 모델 해석특징 분석에 매우 유용한 도구입니다.

Feature Attribution

from dataclasses import dataclass
from functools import partial
from typing import Any, Literal, NamedTuple, Callable

import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint


class SaeReconstructionCache(NamedTuple):
    sae_in: torch.Tensor
    feature_acts: torch.Tensor
    sae_out: torch.Tensor
    sae_error: torch.Tensor


def track_grad(tensor: torch.Tensor) -> None:
    """wrapper around requires_grad and retain_grad"""
    tensor.requires_grad_(True)
    tensor.retain_grad()


@dataclass
class ApplySaesAndRunOutput:
    model_output: torch.Tensor
    model_activations: dict[str, torch.Tensor]
    sae_activations: dict[str, SaeReconstructionCache]

    def zero_grad(self) -> None:
        """Helper to zero grad all tensors in this object."""
        self.model_output.grad = None
        for act in self.model_activations.values():
            act.grad = None
        for cache in self.sae_activations.values():
            cache.sae_in.grad = None
            cache.feature_acts.grad = None
            cache.sae_out.grad = None
            cache.sae_error.grad = None


def apply_saes_and_run(
    model: HookedTransformer,
    saes: dict[str, SAE],
    input: Any,
    include_error_term: bool = True,
    track_model_hooks: list[str] | None = None,
    return_type: Literal["logits", "loss"] = "logits",
    track_grads: bool = False,
) -> ApplySaesAndRunOutput:
    """
    Apply the SAEs to the model at the specific hook points, and run the model.
    By default, this will include a SAE error term which guarantees that the SAE
    will not affect model output. This function is designed to work correctly with
    backprop as well, so it can be used for gradient-based feature attribution.

    Args:
        model: the model to run
        saes: the SAEs to apply
        input: the input to the model
        include_error_term: whether to include the SAE error term to ensure the SAE doesn't affect model output. Default True
        track_model_hooks: a list of hook points to record the activations and gradients. Default None
        return_type: this is passed to the model.run_with_hooks function. Default "logits"
        track_grads: whether to track gradients. Default False
    """

    fwd_hooks = []
    bwd_hooks = []

    sae_activations: dict[str, SaeReconstructionCache] = {}
    model_activations: dict[str, torch.Tensor] = {}

    # this hook just track the SAE input, output, features, and error. If `track_grads=True`, it also ensures
    # that requires_grad is set to True and retain_grad is called for intermediate values.
    def reconstruction_hook(sae_in: torch.Tensor, hook: HookPoint, hook_point: str):  # noqa: ARG001
        sae = saes[hook_point]
        feature_acts = sae.encode(sae_in)
        sae_out = sae.decode(feature_acts)
        sae_error = (sae_in - sae_out).detach().clone()
        if track_grads:
            track_grad(sae_error)
            track_grad(sae_out)
            track_grad(feature_acts)
            track_grad(sae_in)
        sae_activations[hook_point] = SaeReconstructionCache(
            sae_in=sae_in,
            feature_acts=feature_acts,
            sae_out=sae_out,
            sae_error=sae_error,
        )

        if include_error_term:
            return sae_out + sae_error
        return sae_out

    def sae_bwd_hook(output_grads: torch.Tensor, hook: HookPoint):  # noqa: ARG001
        # this just passes the output grads to the input, so the SAE gets the same grads despite the error term hackery
        return (output_grads,)

    # this hook just records model activations, and ensures that intermediate activations have gradient tracking turned on if needed
    def tracking_hook(hook_input: torch.Tensor, hook: HookPoint, hook_point: str):  # noqa: ARG001
        model_activations[hook_point] = hook_input
        if track_grads:
            track_grad(hook_input)
        return hook_input

    for hook_point in saes.keys():
        fwd_hooks.append(
            (hook_point, partial(reconstruction_hook, hook_point=hook_point))
        )
        bwd_hooks.append((hook_point, sae_bwd_hook))
    for hook_point in track_model_hooks or []:
        fwd_hooks.append((hook_point, partial(tracking_hook, hook_point=hook_point)))

    # now, just run the model while applying the hooks
    with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
        model_output = model(input, return_type=return_type)

    return ApplySaesAndRunOutput(
        model_output=model_output,
        model_activations=model_activations,
        sae_activations=sae_activations,
    )
from dataclasses import dataclass
from transformer_lens.hook_points import HookPoint
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal, NamedTuple

import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint

EPS = 1e-8

torch.set_grad_enabled(True)
@dataclass
class AttributionGrads:
    metric: torch.Tensor
    model_output: torch.Tensor
    model_activations: dict[str, torch.Tensor]
    sae_activations: dict[str, SaeReconstructionCache]


@dataclass
class Attribution:
    model_attributions: dict[str, torch.Tensor]
    model_activations: dict[str, torch.Tensor]
    model_grads: dict[str, torch.Tensor]
    sae_feature_attributions: dict[str, torch.Tensor]
    sae_feature_activations: dict[str, torch.Tensor]
    sae_feature_grads: dict[str, torch.Tensor]
    sae_errors_attribution_proportion: dict[str, float]


def calculate_attribution_grads(
    model: HookedSAETransformer,
    prompt: str,
    metric_fn: Callable[[torch.Tensor], torch.Tensor],
    track_hook_points: list[str] | None = None,
    include_saes: dict[str, SAE] | None = None,
    return_logits: bool = True,
    include_error_term: bool = True,
) -> AttributionGrads:
    """
    Wrapper around apply_saes_and_run that calculates gradients wrt to the metric_fn.
    Tracks grads for both SAE feature and model neurons, and returns them in a structured format.
    """
    output = apply_saes_and_run(
        model,
        saes=include_saes or {},
        input=prompt,
        return_type="logits" if return_logits else "loss",
        track_model_hooks=track_hook_points,
        include_error_term=include_error_term,
        track_grads=True,
    )
    metric = metric_fn(output.model_output)
    output.zero_grad()
    metric.backward()
    return AttributionGrads(
        metric=metric,
        model_output=output.model_output,
        model_activations=output.model_activations,
        sae_activations=output.sae_activations,
    )


def calculate_feature_attribution(
    model: HookedSAETransformer,
    input: Any,
    metric_fn: Callable[[torch.Tensor], torch.Tensor],
    track_hook_points: list[str] | None = None,
    include_saes: dict[str, SAE] | None = None,
    return_logits: bool = True,
    include_error_term: bool = True,
) -> Attribution:
    """
    Calculate feature attribution for SAE features and model neurons following
    the procedure in https://transformer-circuits.pub/2024/march-update/index.html#feature-heads.
    This include the SAE error term by default, so inserting the SAE into the calculation is
    guaranteed to not affect the model output. This can be disabled by setting `include_error_term=False`.

    Args:
        model: The model to calculate feature attribution for.
        input: The input to the model.
        metric_fn: A function that takes the model output and returns a scalar metric.
        track_hook_points: A list of model hook points to track activations for, if desired
        include_saes: A dictionary of SAEs to include in the calculation. The key is the hook point to apply the SAE to.
        return_logits: Whether to return the model logits or loss. This is passed to TLens, so should match whatever the metric_fn expects (probably logits)
        include_error_term: Whether to include the SAE error term in the calculation. This is recommended, as it ensures that the SAE will not affecting the model output.
    """
    # first, calculate gradients wrt to the metric_fn.
    # these will be multiplied with the activation values to get the attributions
    outputs_with_grads = calculate_attribution_grads(
        model,
        input,
        metric_fn,
        track_hook_points,
        include_saes=include_saes,
        return_logits=return_logits,
        include_error_term=include_error_term,
    )
    model_attributions = {}
    model_activations = {}
    model_grads = {}
    sae_feature_attributions = {}
    sae_feature_activations = {}
    sae_feature_grads = {}
    sae_error_proportions = {}
    # this code is long, but all it's doing is multiplying the grads by the activations
    # and recording grads, acts, and attributions in dictionaries to return to the user
    with torch.no_grad():
        for name, act in outputs_with_grads.model_activations.items():
            assert act.grad is not None
            raw_activation = act.detach().clone()
            model_attributions[name] = (act.grad * raw_activation).detach().clone()
            model_activations[name] = raw_activation
            model_grads[name] = act.grad.detach().clone()
        for name, act in outputs_with_grads.sae_activations.items():
            assert act.feature_acts.grad is not None
            assert act.sae_out.grad is not None
            raw_activation = act.feature_acts.detach().clone()
            sae_feature_attributions[name] = (
                (act.feature_acts.grad * raw_activation).detach().clone()
            )
            sae_feature_activations[name] = raw_activation
            sae_feature_grads[name] = act.feature_acts.grad.detach().clone()
            if include_error_term:
                assert act.sae_error.grad is not None
                error_grad_norm = act.sae_error.grad.norm().item()
            else:
                error_grad_norm = 0
            sae_out_norm = act.sae_out.grad.norm().item()
            sae_error_proportions[name] = error_grad_norm / (
                sae_out_norm + error_grad_norm + EPS
            )
        return Attribution(
            model_attributions=model_attributions,
            model_activations=model_activations,
            model_grads=model_grads,
            sae_feature_attributions=sae_feature_attributions,
            sae_feature_activations=sae_feature_activations,
            sae_feature_grads=sae_feature_grads,
            sae_errors_attribution_proportion=sae_error_proportions,
        )


# prompt = " Tiger Woods plays the sport of"
# pos_token = model.tokenizer.encode(" golf")[0]
prompt = "In the beginning, God created the heavens and the"
pos_token = model.tokenizer.encode(" earth")
neg_token = model.tokenizer.encode(" sky")
def metric_fn(logits: torch.tensor, pos_token: torch.tensor =pos_token, neg_token: torch.Tensor=neg_token) -> torch.Tensor:
    return logits[0,-1,pos_token] - logits[0,-1,neg_token]

feature_attribution_df = calculate_feature_attribution(
    input = prompt,
    model = model,
    metric_fn = metric_fn,
    include_saes={sae.cfg.hook_name: sae},
    include_error_term=True,
    return_logits=True,
)
from transformer_lens.utils import test_prompt
test_prompt(prompt, model.to_string(pos_token), model)

tokens = model.to_str_tokens(prompt)
unique_tokens = [f"{i}/{t}" for i, t in enumerate(tokens)]

px.bar(x = unique_tokens,
       y = feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0].sum(-1).detach().cpu().numpy())

%autoreload 0
def convert_sparse_feature_to_long_df(sparse_tensor: torch.Tensor) -> pd.DataFrame:
    """
    Convert a sparse tensor to a long format pandas DataFrame.
    """
    df = pd.DataFrame(sparse_tensor.detach().cpu().numpy())
    df_long = df.melt(ignore_index=False, var_name='column', value_name='value')
    df_long.columns = ["feature", "attribution"]
    df_long_nonzero = df_long[df_long['attribution'] != 0]
    df_long_nonzero = df_long_nonzero.reset_index().rename(columns={'index': 'position'})
    return df_long_nonzero

df_long_nonzero = convert_sparse_feature_to_long_df(feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0])
df_long_nonzero.sort_values("attribution", ascending=False)

 

for i, v in df_long_nonzero.query("position==8").groupby("feature").attribution.sum().sort_values(ascending=False).head(5).items():
    print(f"Feature {i} had a total attribution of {v:.2f}")
    html = get_dashboard_html(sae_release = "gpt2-small", sae_id=f"{sae.cfg.hook_layer}-res-jb", feature_idx=int(i))
    display(IFrame(html, width=1200, height=300))

for i, v in df_long_nonzero.groupby("feature").attribution.sum().sort_values(ascending=False).head(5).items():
    print(f"Feature {i} had a total attribution of {v:.2f}")
    html = get_dashboard_html(sae_release = "gpt2-small", sae_id=f"{sae.cfg.hook_layer}-res-jb", feature_idx=int(i))
    display(IFrame(html, width=1200, height=300))

728x90