
abbrev: SAE

see also: landspace

Often consists of a single hidden layer (a linear encoder with a ReLU, followed by a linear decoder), trained on the activations the main LLM produces on a subset of the datasets it was trained on.

empirical example: if we wish to interpret all features related to the author Camus, we might train an SAE on the activations Llama-3.1 produces over the collected texts of Camus, to interpret “similar” features in Llama-3.1.

definition

We wish to decompose a model’s activation $x \in \mathbb{R}^n$ into a sparse, linear combination of feature directions:

$$
\begin{aligned} x \approx x_{0} + &\sum_{i=1}^{M} f_i(x) d_i \\[8pt] \because \quad &d_i \ (M \gg n):\text{ latent unit-norm feature directions} \\ &f_i(x) \ge 0: \text{ corresponding feature activation for }x \end{aligned}
$$

Thus, the baseline architecture of SAEs is a linear autoencoder with an L1 penalty on the activations:

$$
\begin{aligned} f(x) &\coloneqq \text{ReLU}(W_\text{enc}(x - b_\text{dec}) + b_\text{enc}) \\ \hat{x}(f) &\coloneqq W_\text{dec} f(x) + b_\text{dec} \end{aligned}
$$

It is trained to reconstruct a large dataset of model activations $x \sim \mathcal{D}$ while constraining the hidden representation $f$ to be sparse.

The training loss adds the L1 norm of $f(x)$, weighted by a coefficient $\lambda$, to the reconstruction error:

$$
\begin{aligned} \mathcal{L}(x) &\coloneqq \| x-\hat{x}(f(x)) \|_2^2 + \lambda \| f(x) \|_1 \\[8pt] &\because \|x-\hat{x}(f(x)) \|_2^2 : \text{ reconstruction loss} \end{aligned}
$$
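A minimal NumPy sketch of the baseline SAE forward pass and loss above, assuming a single activation vector, randomly initialised (untrained) weights with unit-norm decoder columns, and hypothetical helper names `sae_forward` and `sae_loss`. It also counts L0, the number of active features:

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    # f(x) = ReLU(W_enc (x - b_dec) + b_enc): non-negative feature activations, shape (M,)
    f = np.maximum(W_enc @ (x - b_dec) + b_enc, 0.0)
    # x_hat(f) = W_dec f(x) + b_dec: reconstruction, shape (n,)
    x_hat = W_dec @ f + b_dec
    return f, x_hat

def sae_loss(x, f, x_hat, lam):
    # squared L2 reconstruction error plus lam * L1 norm of the activations
    recon = np.sum((x - x_hat) ** 2)
    sparsity = lam * np.sum(np.abs(f))
    return recon + sparsity

rng = np.random.default_rng(0)
n, M = 8, 32                                   # M >> n: overcomplete dictionary
W_enc = rng.normal(size=(M, n)) / np.sqrt(n)
b_enc = np.zeros(M)
W_dec = rng.normal(size=(n, M))
W_dec /= np.linalg.norm(W_dec, axis=0, keepdims=True)  # columns d_i have unit norm
b_dec = np.zeros(n)

x = rng.normal(size=n)
f, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
loss = sae_loss(x, f, x_hat, lam=0.1)
l0 = int(np.count_nonzero(f))                  # L0: number of active features
```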

intuition

We want good reconstruction fidelity at a given sparsity level, as measured by L0, but training optimizes a proxy: a mixture of reconstruction fidelity and L1 regularization.

Since the sparsity term can be reduced without affecting reconstruction by scaling up the norm of the decoder weights (and scaling the activations $f(x)$ down to match), we constrain the norms of the columns of $W_\text{dec}$ during training.
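The scaling argument can be checked numerically. This sketch (NumPy, random weights, hypothetical names) scales each decoder column up by a factor and the matching encoder row and bias down by the same factor: because ReLU commutes with positive scaling, the reconstruction is identical while the L1 term shrinks by that factor. `normalize_decoder_columns` is one hypothetical way to enforce the norm constraint, for example by applying it after each optimizer step:

```python
import numpy as np

def forward(W_enc, b_enc, W_dec, b_dec, x):
    # Baseline SAE: f = ReLU(W_enc (x - b_dec) + b_enc), x_hat = W_dec f + b_dec
    f = np.maximum(W_enc @ (x - b_dec) + b_enc, 0.0)
    return f, W_dec @ f + b_dec

def normalize_decoder_columns(W_dec, eps=1e-8):
    # Hypothetical helper: rescale each column (feature direction) to unit L2 norm
    return W_dec / (np.linalg.norm(W_dec, axis=0, keepdims=True) + eps)

rng = np.random.default_rng(1)
n, M = 4, 6
W_enc, b_enc = rng.normal(size=(M, n)), rng.normal(size=M)
W_dec, b_dec = rng.normal(size=(n, M)), rng.normal(size=n)
x = rng.normal(size=n)

# Per-feature factor c_i > 0: decoder column i is scaled up by c_i, encoder row i
# and bias i are scaled down by c_i. Since ReLU(z / c) = ReLU(z) / c for c > 0,
# every f_i shrinks by c_i while W_dec @ f stays the same.
c = np.full(M, 10.0)
f1, xhat1 = forward(W_enc, b_enc, W_dec, b_dec, x)
f2, xhat2 = forward(W_enc / c[:, None], b_enc / c, W_dec * c[None, :], b_dec, x)

same_recon = np.allclose(xhat1, xhat2)   # reconstruction unchanged
l1_before = np.abs(f1).sum()
l1_after = np.abs(f2).sum()              # equals l1_before / 10: the L1 term shrinks
```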

Idea: the encoder output $f(x)$ has two roles:

  • detecting which features are active: here L1 is crucial to ensure sparsity in the decomposition
  • estimating the magnitudes of active features: here L1 is an unwanted bias

Gated SAE

Gated SAEs achieve a Pareto improvement over baseline SAE training by reducing the bias introduced by the L1 penalty (Rajamanoharan et al., 2024).

A clear consequence of this bias during training is shrinkage (Sharkey, 2024).¹

The idea is to use a gated ReLU encoder (Dauphin et al., 2017; Shazeer, 2020):

$$
\tilde{f}(\mathbf{x}) \coloneqq \underbrace{\mathbb{1}[\underbrace{(\mathbf{W}_{\text{gate}}(\mathbf{x} - \mathbf{b}_{\text{dec}}) + \mathbf{b}_{\text{gate}}) > 0}_{\pi_{\text{gate}}(\mathbf{x})}]}_{f_{\text{gate}}(\mathbf{x})} \odot \underbrace{\text{ReLU}(\mathbf{W}_{\text{mag}}(\mathbf{x} - \mathbf{b}_{\text{dec}}) + \mathbf{b}_{\text{mag}})}_{f_{\text{mag}}(\mathbf{x})}
$$

where $\mathbb{1}[\bullet > 0]$ is the (point-wise) Heaviside step function and $\odot$ denotes element-wise multiplication.

| term | annotations |
| --- | --- |
| $f_\text{gate}$ | which features are deemed to be active |
| $f_\text{mag}$ | feature activation magnitudes (for features that have been deemed to be active) |
| $\pi_\text{gate}(x)$ | the $f_\text{gate}$ sub-layer’s pre-activations |
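A forward-pass-only sketch of the gated encoder, assuming NumPy, random weights, and untied (separate) gate and magnitude weights; `gated_encoder` is a hypothetical name. It returns $\tilde{f}(x)$ together with the intermediate terms from the table. How gradients reach the gate through the step function during training is not covered here:

```python
import numpy as np

def gated_encoder(x, W_gate, b_gate, W_mag, b_mag, b_dec):
    x_c = x - b_dec                                # subtract the decoder bias, as in the equation
    pi_gate = W_gate @ x_c + b_gate                # pi_gate(x): gate pre-activations
    f_gate = (pi_gate > 0).astype(float)           # Heaviside step: 1 where a feature is deemed active
    f_mag = np.maximum(W_mag @ x_c + b_mag, 0.0)   # f_mag(x): ReLU magnitudes
    return f_gate * f_mag, f_gate, f_mag, pi_gate  # element-wise product gives f~(x)

rng = np.random.default_rng(3)
n, M = 6, 16
W_gate, b_gate = rng.normal(size=(M, n)), rng.normal(size=M)
W_mag, b_mag = rng.normal(size=(M, n)), rng.normal(size=M)
b_dec = rng.normal(size=n)
x = rng.normal(size=n)

f_tilde, f_gate, f_mag, pi_gate = gated_encoder(x, W_gate, b_gate, W_mag, b_mag, b_dec)
```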

To offset the increase in parameters, use weight sharing:

Scale $W_\text{mag}$ in terms of $W_\text{gate}$ with a vector-valued rescaling parameter $r_\text{mag} \in \mathbb{R}^M$:

$$
(W_\text{mag})_{ij} \coloneqq (\exp (r_\text{mag}))_i \cdot (W_\text{gate})_{ij}
$$

Figure 3: Gated SAE with weight sharing between gating and magnitude paths

Figure 4: A gated encoder becomes a single-layer linear encoder with a JumpReLU (Erichson et al., 2019) activation function $\sigma_\theta$
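A sketch checking, on random inputs, that the weight-shared gated encoder collapses to a single linear layer followed by a JumpReLU. The threshold $\theta = b_\text{mag} - \exp(r_\text{mag}) \odot b_\text{gate}$ is derived in the comments rather than taken from the source; the names are hypothetical and NumPy is assumed. Because the magnitude path's ReLU also zeroes non-positive pre-activations, the effective threshold in this sketch is $\max(\theta, 0)$:

```python
import numpy as np

rng = np.random.default_rng(2)
n, M = 5, 12
W_gate = rng.normal(size=(M, n))
r_mag = rng.normal(size=M)
b_gate, b_mag, b_dec = rng.normal(size=M), rng.normal(size=M), rng.normal(size=n)

# Weight sharing: (W_mag)_ij = exp(r_mag)_i * (W_gate)_ij
W_mag = np.exp(r_mag)[:, None] * W_gate
params_unshared = 2 * M * n   # separate W_gate and W_mag
params_shared = M * n + M     # W_gate plus the rescaling vector r_mag

def gated(x):
    x_c = x - b_dec
    f_gate = (W_gate @ x_c + b_gate > 0).astype(float)
    f_mag = np.maximum(W_mag @ x_c + b_mag, 0.0)
    return f_gate * f_mag

def jump_relu(a, theta):
    # JumpReLU: keep a where it exceeds the threshold, output 0 elsewhere
    return np.where(a > theta, a, 0.0)

# With a = W_mag x_c + b_mag = exp(r_mag) * (W_gate x_c) + b_mag, the gate test
# W_gate x_c + b_gate > 0 holds exactly when a > b_mag - exp(r_mag) * b_gate.
theta = b_mag - np.exp(r_mag) * b_gate

def single_layer(x):
    a = W_mag @ (x - b_dec) + b_mag   # one linear encoder
    # The magnitude ReLU also zeroes a <= 0, hence the threshold max(theta, 0)
    return jump_relu(a, np.maximum(theta, 0.0))

X = rng.normal(size=(100, n))
max_diff = max(np.abs(gated(x) - single_layer(x)).max() for x in X)
```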

feature suppression

See also: link

The loss function of SAEs combines an MSE reconstruction loss with a sparsity term:

$$
\begin{aligned} L(x, f(x), y) &= \|y-x\|^2/d + c\,|f(x)| \\[8pt] &\because d: \text{ dimensionality of }x, \quad y: \text{ reconstruction of }x \end{aligned}
$$

The reconstruction is not perfect, given that only one of the two loss terms rewards reconstruction. Features with small activations $f(x)$ will be suppressed, since the sparsity term pushes them toward zero.
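A toy check of the suppression claim, with hypothetical numbers: one unit-norm feature with true activation 1.0, a fixed decoder so that $y = f \cdot e_1$, $d = 4$ and $c = 0.05$, reading $|f(x)|$ as the L1 norm (an assumption). Shrinking the activation to 0.9 costs a little reconstruction error but saves more in the sparsity term, so the loss prefers it (about 0.0475 against 0.05):

```python
import numpy as np

def suppression_loss(x, f, y, c):
    # ||y - x||^2 / d + c * |f(x)|, with |f(x)| read as the L1 norm (assumption)
    d = x.shape[-1]                                   # dimensionality of x
    return np.sum((y - x) ** 2) / d + c * np.sum(np.abs(f))

# x is exactly one unit-norm feature direction e_1 with true activation 1.0
x = np.array([1.0, 0.0, 0.0, 0.0])
e1 = np.array([1.0, 0.0, 0.0, 0.0])
c = 0.05

loss_exact = suppression_loss(x, np.array([1.0]), 1.0 * e1, c)   # perfect reconstruction
loss_shrunk = suppression_loss(x, np.array([0.9]), 0.9 * e1, c)  # suppressed activation
```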

How do we fix feature suppression in training SAEs?

Introduce an element-wise scaling factor per feature between the encoder and decoder, represented by the vector $s$:

$$
\begin{aligned} f(x) &= \text{ReLU}(W_e x + b_e) \\ f_s(x) &= s \odot f(x) \\ y &= W_d f_s(x) + b_d \end{aligned}
$$
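A sketch of the scaled forward pass, assuming NumPy; `scaled_sae_forward` is a hypothetical name. The toy encoder (hypothetical numbers) outputs a shrunk activation 0.9 for a feature whose true activation is 1.0. As an assumption for illustration only, $s$ is fitted by least squares on reconstruction alone with the encoder and decoder frozen, which recovers $s \approx 1/0.9$ and an exact reconstruction:

```python
import numpy as np

def scaled_sae_forward(x, W_e, b_e, s, W_d, b_d):
    f = np.maximum(W_e @ x + b_e, 0.0)   # f(x) = ReLU(W_e x + b_e)
    f_s = s * f                          # f_s(x) = s * f(x), one scale per feature
    y = W_d @ f_s + b_d                  # y = W_d f_s(x) + b_d
    return f, f_s, y

# One feature along e_1 with true activation 1.0; the encoder bias -0.1
# makes the encoder output a shrunk activation of 0.9.
x = np.array([1.0, 0.0, 0.0, 0.0])
W_e, b_e = np.array([[1.0, 0.0, 0.0, 0.0]]), np.array([-0.1])
W_d, b_d = np.array([[1.0], [0.0], [0.0], [0.0]]), np.zeros(4)

f, _, y_unscaled = scaled_sae_forward(x, W_e, b_e, np.ones(1), W_d, b_d)

# Assumption: fit s on reconstruction only, encoder and decoder frozen.
# For a single active feature the least-squares optimum is <x - b_d, g> / ||g||^2, g = W_d f.
g = W_d @ f
s_fit = np.array([np.dot(x - b_d, g) / np.dot(g, g)])
_, _, y_fit = scaled_sae_forward(x, W_e, b_e, s_fit, W_d, b_d)
```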

Bibliography

  • Dauphin, Y. N., Fan, A., Auli, M., & Grangier, D. (2017). Language Modeling with Gated Convolutional Networks. arXiv preprint arXiv:1612.08083 [arxiv]
  • Erichson, N. B., Yao, Z., & Mahoney, M. W. (2019). JumpReLU: A Retrofit Defense Strategy for Adversarial Attacks. arXiv preprint arXiv:1904.03750 [arxiv]
  • Rajamanoharan, S., Conmy, A., Smith, L., Lieberum, T., Varma, V., Kramár, J., Shah, R., & Nanda, N. (2024). Improving Dictionary Learning with Gated Sparse Autoencoders. arXiv preprint arXiv:2404.16014 [arxiv]
  • Sharkey, L. (2024). Addressing Feature Suppression in SAEs. AI Alignment Forum. [post]
  • Shazeer, N. (2020). GLU Variants Improve Transformer. arXiv preprint arXiv:2002.05202 [arxiv]

Remarks

  1. If we hold $\hat{x}(\bullet)$ fixed, the L1 term pushes $f(x) \to 0$, while the reconstruction loss pushes $f(x)$ high enough to produce an accurate reconstruction.

    The optimal value lies somewhere in between, so the learned activations end up systematically smaller than the ones that would reconstruct best: this is shrinkage.

    However, rescaling the shrunk feature activations (Sharkey, 2024) is not necessarily enough to overcome the bias induced by L1: an SAE might learn sub-optimal encoder and decoder directions that this fix does not improve.
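A one-feature worked case makes remark 1 concrete. Assume, for illustration only, a single active feature with unit-norm decoder direction $d_i$, an input $x = b_\text{dec} + a\,d_i$ with $a > 0$, and a fixed decoder so that $\hat{x}(f) = f\,d_i + b_\text{dec}$. The baseline loss then reduces to a scalar problem whose minimiser sits strictly below the true activation $a$ whenever $\lambda > 0$:

$$
\begin{aligned}
\mathcal{L}(f) &= \| a\,d_i - f\,d_i \|_2^2 + \lambda f = (a - f)^2 + \lambda f, \qquad f \ge 0 \\[4pt]
\frac{\partial \mathcal{L}}{\partial f} &= -2(a - f) + \lambda = 0 \;\Rightarrow\; f^\star = \max\left(a - \tfrac{\lambda}{2},\, 0\right) < a
\end{aligned}
$$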