
maturity

a research preview from Anthropic; this is still very much a work in progress

see also: a reproduction on Gemma 2B and GitHub

A variant of the sparse autoencoder that reads from and writes to multiple layers (Lindsey et al., 2024).

Crosscoders produce shared features across layers and even across models.

motivations

Resolve:

  • cross-layer features: resolve cross-layer superposition

  • circuit simplification: remove redundant features from analysis and allow circuits to jump across many uninteresting identity circuit connections

  • model diffing: produce shared sets of features across models. This includes one model at different points in training, as well as completely independent models with different architectures.

cross-layer superposition

given the additive properties of transformers' residual stream, adjacent layers in larger transformers can be thought of as "almost parallel"

If we think of adjacent layers as “almost parallel branches that potentially have superposition between them”, then we can apply dictionary learning to them jointly. [1]

persistent features and complexity

A current drawback of sparse autoencoders is that each one has to be trained on the activations of a specific layer to extract features. When SAEs are trained on the residual stream at every layer, we end up with many duplicate features across layers.

Crosscoders can simplify the resulting circuits, provided we use an appropriate architecture. [2]

The motivation is that some features persist across the residual stream, so per-layer SAEs end up learning the same feature multiple times.

setup.

Autoencoders and transcoders as special cases of crosscoders.

  • autoencoders: read and predict the same layer
  • transcoders: read from layer $n$ and predict layer $n+1$

Crosscoders read from and write to many layers, subject to causality constraints.
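The special cases above can be written down as read/write layer sets. This is a toy illustration, not code from Lindsey et al.: the function name `read_write_layers` and the variant labels are hypothetical, and it assumes that an acausal crosscoder (the variant used in the cross-layer experiments below) reads from and writes to every layer.

```python
from typing import List, Tuple


def read_write_layers(variant: str, n_layers: int, layer: int = 0) -> Tuple[List[int], List[int]]:
    """Return (read_layers, write_layers) for a dictionary anchored at `layer`.

    Hypothetical helper; layers are indexed 0 .. n_layers - 1.
    """
    if not 0 <= layer < n_layers:
        raise ValueError("layer out of range")
    if variant == "sae":
        # autoencoder: reads and predicts the same layer
        return [layer], [layer]
    if variant == "transcoder":
        # transcoder: reads layer n and predicts layer n + 1
        if layer + 1 >= n_layers:
            raise ValueError("a transcoder needs a later layer to predict")
        return [layer], [layer + 1]
    if variant == "acausal_crosscoder":
        # acausal crosscoder (assumed): reads from and writes to every layer
        every_layer = list(range(n_layers))
        return every_layer, every_layer
    raise ValueError(f"unknown variant: {variant!r}")


print(read_write_layers("sae", n_layers=4, layer=2))         # ([2], [2])
print(read_write_layers("transcoder", n_layers=4, layer=1))  # ([1], [2])
print(read_write_layers("acausal_crosscoder", n_layers=3))   # ([0, 1, 2], [0, 1, 2])
```

The causality constraints would shrink the write set relative to the read set; they are deliberately not modeled in this sketch.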

crosscoders

Let one compute the vector of feature activations $f(x_j)$ on data point $x_j$ by summing over the contributions of the activations $a^l(x_j)$ of the different layers $l \in L$:

$$
f(x_j) = \text{ReLU}\Big(\sum_{l\in L} W_{\text{enc}}^l\, a^l(x_j) + b_{\text{enc}}\Big)
$$

where $W^l_{\text{enc}}$ are the encoder weights at layer $l$ and $a^l(x_j)$ is the activation on datapoint $x_j$ at layer $l$.
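A minimal, dependency-free sketch of this encoder for a single datapoint, assuming plain Python lists for vectors and dicts keyed by layer index; the names `encode` and `matvec` are hypothetical. It shows that every layer feeds the same feature vector through its own encoder matrix, followed by one shared bias and a single ReLU.

```python
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]  # list of rows


def matvec(W: Matrix, a: Vector) -> Vector:
    # plain-Python matrix-vector product W @ a
    return [sum(w * x for w, x in zip(row, a)) for row in W]


def encode(W_enc: Dict[int, Matrix], b_enc: Vector, acts: Dict[int, Vector]) -> Vector:
    """f(x_j) = ReLU(sum_l W_enc^l a^l(x_j) + b_enc).

    W_enc[l]: n_features x d_model encoder matrix for layer l
    acts[l]:  activation vector a^l(x_j) for layer l
    """
    pre = list(b_enc)  # start from the shared encoder bias
    for l, a in acts.items():
        # each layer adds its own linear read-out into the same pre-activation
        pre = [p + c for p, c in zip(pre, matvec(W_enc[l], a))]
    return [max(0.0, p) for p in pre]  # ReLU


W_enc = {0: [[1.0, 0.0], [0.0, 1.0]], 1: [[1.0, 1.0], [-1.0, 0.0]]}
b_enc = [0.0, -2.0]
acts = {0: [1.0, 2.0], 1: [0.5, -1.0]}
print(encode(W_enc, b_enc, acts))  # [0.5, 0.0]
```

The dict keys are only labels, so the same loop would run if they named (model, layer) pairs instead, which is one way to picture a crosscoder that reads from several models at once (the type hints would just need widening).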

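The loss is

$$
L = \sum_{l\in L} \|a^l(x_j) - a^{l'}(x_j)\|^2 + \sum_{l\in L}\sum_i f_i(x_j)\, \|W^l_{\text{dec},i}\|
$$

where $a^{l'}(x_j) = W^l_{\text{dec}} f(x_j) + b^l_{\text{dec}}$ is the crosscoder's reconstruction of layer $l$, and $W^l_{\text{dec},i}$ is the decoder vector of feature $i$ at layer $l$.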
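A sketch of this loss for a single datapoint, assuming the reconstruction is the linear decoder $W^l_{\text{dec}} f(x_j) + b^l_{\text{dec}}$ with one decoder matrix and bias per layer. The helper names are hypothetical, and `l1_coeff` is an assumed knob (default 1.0, which reproduces the formula exactly) standing in for the fixed L1 coefficient mentioned in the experiments below.

```python
import math
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]  # list of rows


def decode(W_dec: Dict[int, Matrix], b_dec: Dict[int, Vector], f: Vector) -> Dict[int, Vector]:
    # a^{l'}(x_j) = W_dec^l f(x_j) + b_dec^l, with W_dec^l of shape d_model x n_features
    return {
        l: [sum(w * fi for w, fi in zip(row, f)) + b for row, b in zip(W, b_dec[l])]
        for l, W in W_dec.items()
    }


def column_norm(W: Matrix, i: int) -> float:
    # ||W_dec,i^l||: L2 norm of feature i's decoder vector (column i) at one layer
    return math.sqrt(sum(row[i] ** 2 for row in W))


def crosscoder_loss(acts, W_dec, b_dec, f, l1_coeff: float = 1.0) -> float:
    recon = decode(W_dec, b_dec, f)
    # reconstruction term: squared error, summed over every layer
    mse = sum(sum((a - r) ** 2 for a, r in zip(acts[l], recon[l])) for l in acts)
    # sparsity term: each activation weighted by its decoder norm at every layer
    sparsity = sum(f[i] * column_norm(W, i) for W in W_dec.values() for i in range(len(f)))
    return mse + l1_coeff * sparsity


acts = {0: [1.0, 1.0], 1: [3.0, 4.0]}
W_dec = {0: [[1.0, 0.0], [0.0, 1.0]], 1: [[3.0, 0.0], [4.0, 0.0]]}
b_dec = {0: [0.0, 0.0], 1: [0.0, 0.0]}
f = [1.0, 0.0]
print(crosscoder_loss(acts, W_dec, b_dec, f))  # 7.0
```

Here the reconstruction error is 1.0 (layer 0 misses one coordinate) and the only active feature pays its decoder norm at both layers, 1.0 + 5.0.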

and the regularization term can be rewritten as:

$$
\sum_{l\in L}\sum_i f_i(x_j)\, \|W^l_{\text{dec},i}\| = \sum_i f_i(x_j)\Big(\sum_{l\in L} \|W^l_{\text{dec},i}\|\Big)
$$

This weights each feature's L1 regularization penalty by the L1 norm of its per-layer decoder weight norms, $\sum_{l\in L}\|W^l_{\text{dec},i}\|$. [3]
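The regrouping is just a reordering of a finite sum, which a quick numerical check on random toy weights confirms. This is a sketch with hypothetical helper names; the decoder matrices are stored as d_model x n_features lists of rows, and the activations are clipped at zero to mimic a ReLU.

```python
import math
import random


def column_norm(W, i):
    # ||W_dec,i^l||: L2 norm of column i of a d_model x n_features decoder matrix
    return math.sqrt(sum(row[i] ** 2 for row in W))


def penalty_weights(W_dec, n_features):
    # per-feature weight sum_l ||W_dec,i^l|| (an L1 norm over the per-layer norms)
    return [sum(column_norm(W, i) for W in W_dec.values()) for i in range(n_features)]


def penalty_double_sum(f, W_dec):
    # left-hand side: sum_l sum_i f_i ||W_dec,i^l||
    return sum(f[i] * column_norm(W, i) for W in W_dec.values() for i in range(len(f)))


def penalty_regrouped(f, W_dec):
    # right-hand side: sum_i f_i (sum_l ||W_dec,i^l||)
    return sum(fi * wi for fi, wi in zip(f, penalty_weights(W_dec, len(f))))


# random toy decoder and non-negative (ReLU-like) activations
random.seed(0)
n_layers, d_model, n_features = 3, 4, 5
W_dec = {
    l: [[random.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(d_model)]
    for l in range(n_layers)
}
f = [max(0.0, random.gauss(0.0, 1.0)) for _ in range(n_features)]
print(math.isclose(penalty_double_sum(f, W_dec), penalty_regrouped(f, W_dec)))  # True
```

The regrouped form makes the design visible: each feature's activation is multiplied by one scalar that aggregates its decoder norms over all layers.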

We use the L1 norm of the per-layer decoder norms (rather than their L2 norm) for two reasons:

  • baseline loss comparison: with L2, the crosscoder would reach a lower loss than the sum of the per-layer SAE losses, because it would effectively obtain a loss “bonus” by spreading features across layers; L1 keeps the two directly comparable.

  • layer-wise sparsity surfaces layer-specific features: empirically, in model diffing, L1 uncovers a mix of shared and model-specific features, whereas L2 tends to uncover only shared features.
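The “bonus” in the first point can be seen with toy numbers under an assumed setup: one feature with activation 1 that persists unchanged across 4 layers, with a unit-norm decoder vector at each. Separate per-layer SAEs would each learn a copy of the feature and each pay a penalty of 1. The helper names below are hypothetical.

```python
import math


def l1_of_norms(norms):
    # sum_l ||W_dec,i^l||: the aggregation used in the crosscoder loss
    return sum(norms)


def l2_of_norms(norms):
    # sqrt(sum_l ||W_dec,i^l||^2): the alternative aggregation
    return math.sqrt(sum(n * n for n in norms))


activation = 1.0
per_layer_norms = [1.0, 1.0, 1.0, 1.0]

# 4 per-layer SAEs: each learns its own copy of the feature and pays f * ||W_dec||
sae_total = sum(activation * n for n in per_layer_norms)
# one crosscoder feature: pays f times its aggregated decoder norm
crosscoder_l1 = activation * l1_of_norms(per_layer_norms)
crosscoder_l2 = activation * l2_of_norms(per_layer_norms)

print(sae_total, crosscoder_l1, crosscoder_l2)  # 4.0 4.0 2.0
```

With equal per-layer norms, the L2 penalty grows like the square root of the number of layers instead of linearly, so a crosscoder trained with it would look better than the per-layer SAE baseline for reasons unrelated to finding better features.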

variants

good to explore:

  1. strictly causal crosscoders to capture MLP computation and treat computation performed by attention layers as linear
  2. combine strictly causal crosscoders for MLP outputs with weakly causal crosscoders for attention outputs
  3. interpretable attention replacement layers that could be used in combination with strictly causal crosscoders for a “replacement model”

Cross-layer Features

How can we discover cross-layer structure?

  • train a global, acausal crosscoder on the residual stream activations of an 18-layer model
  • compare it against 18 SAEs, one trained on each residual stream layer
  • use a fixed L1 coefficient for the sparsity penalty
  • loss: MSE + decoder-norm-weighted L1 norm

model diffing

see also: model stitching and SVCCA

(Laakso & Cottrell, 2000) propose comparing representations by transforming them into representations of distances between data points. [4]
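A small sketch of the distance-based idea, in the spirit of (Laakso & Cottrell, 2000) rather than a reproduction of their exact procedure: it assumes Euclidean distances and a Pearson correlation between the two lists of pairwise distances, and the function names are hypothetical. Because only distances between data points are compared, the two representations may have different dimensionalities.

```python
import math
from itertools import combinations


def pairwise_distances(reps):
    # Euclidean distance for every unordered pair of data points, in a fixed order
    return [math.dist(reps[p], reps[q]) for p, q in combinations(range(len(reps)), 2)]


def pearson(x, y):
    mean_x, mean_y = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    sd_x = math.sqrt(sum((a - mean_x) ** 2 for a in x))
    sd_y = math.sqrt(sum((b - mean_y) ** 2 for b in y))
    return cov / (sd_x * sd_y)


def representational_similarity(reps_a, reps_b):
    # reps_a[k] and reps_b[k] must describe the same input k in the two networks
    return pearson(pairwise_distances(reps_a), pairwise_distances(reps_b))


# network B's 3-d representation is network A's 2-d one, scaled by 2,
# with a constant third coordinate added
reps_a = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
reps_b = [[0.0, 0.0, 5.0], [2.0, 0.0, 5.0], [0.0, 4.0, 5.0]]
print(round(representational_similarity(reps_a, reps_b), 6))  # 1.0
```

A score near 1 means the two networks arrange the inputs in nearly the same relative geometry, even though their coordinates cannot be compared directly.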

questions

How do features change over model training? When do they form?

As we make a model wider, do we get more features? Or are they largely the same, just packed less densely?

Bibliography

  • Gorton, L. (2024). The Missing Curve Detectors of InceptionV1: Applying Sparse Autoencoders to InceptionV1 Early Vision. arXiv preprint arXiv:2406.03662.
  • Laakso, A., & Cottrell, G. (2000). Content and cluster analysis: Assessing representational similarity in neural systems. Philosophical Psychology, 13(1), 47–76. https://doi.org/10.1080/09515080050002726
  • Lindsey, J., Templeton, A., Marcus, J., Conerly, T., Batson, J., & Olah, C. (2024). Sparse Crosscoders for Cross-Layer Features and Model Diffing. Transformer Circuits Thread.

Notes

  1. (Gorton, 2024) notes that cross-branch superposition is significant when interpreting models with parallel branches (InceptionV1).

  2. The causal description a crosscoder provides likely differs from that of the underlying model.

  3. $\|W_{\text{dec},i}^l\|$ is the L2 norm of a single feature’s decoder vector at a given layer.

    In principle, one might have expected to use the L2 norm of the per-layer norms, $\sqrt{\sum_{l \in L} \|W_{\text{dec},i}^l\|^2}$.

  4. Chris Olah’s blog post explains how t-SNE can be used to visualize collections of networks in a function space.