Identifying Functionally Important Features with End-to-End Sparse Dictionary Learning

E2E SAE

Previous work on Sparse Autoencoders (SAEs) has shown that they can decompose neural network activations into interpretable features. However, current SAEs are trained to minimize the Mean Squared Reconstruction Error (MSRE) of the activations, which implicitly assumes that faithfully reconstructing activations is the same as capturing what matters for the network's behavior. In large models, many directions in activation space are redundant or only indirectly related to the output, so minimizing MSRE spends capacity on features that are not functionally important.

The key question: How can we select functionally important features—the ones truly responsible for the network's behavior?

To answer this, the authors propose training SAEs with an objective that keeps the sparsity penalty on the encoded activations but replaces the pure reconstruction error with a KL divergence between the original output logits and the logits obtained when the SAE-reconstructed activations are passed back through the rest of the network.

🗒️ Contribution & Findings

Formulations

☁ Transformer Part

Consider a decoder-only Transformer (e.g., GPT-2) as a composition of layer functions:

\[ a^{(0)}(x) = x, \quad a^{(l)}(x) = f^{(l)}\bigl(a^{(l-1)}(x)\bigr)\ \text{for}\ l=1,\dots,L-1, \quad y = \mathrm{softmax}\bigl(f^{(L)}(a^{(L-1)}(x))\bigr). \]
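As a concrete illustration of this layer-wise view, here is a minimal PyTorch sketch that runs a stack of residual blocks and records every intermediate activation \(a^{(l)}(x)\). The names `blocks` and `unembed` are hypothetical stand-ins for the layers \(f^{(1)},\dots,f^{(L-1)}\) and the final map \(f^{(L)}\); this is not the paper's code.

```python
import torch

def forward_with_activations(blocks, unembed, x):
    """Layer-wise forward pass that keeps every intermediate activation a^(l)(x).

    blocks:  list of callables f^(1), ..., f^(L-1) (hypothetical stand-ins)
    unembed: final map f^(L) producing logits
    x:       input activations a^(0)(x)
    """
    acts = [x]                                    # a^(0)(x) = x
    for block in blocks:
        acts.append(block(acts[-1]))              # a^(l)(x) = f^(l)(a^(l-1)(x))
    y = torch.softmax(unembed(acts[-1]), dim=-1)  # y = softmax(f^(L)(a^(L-1)(x)))
    return y, acts
```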

☁ Encoder & SAE Part

After extracting the activations at a chosen layer \(l\), the SAE encodes and reconstructs them:

\[ \mathrm{Enc}\bigl(a^{(l)}(x)\bigr) = \mathrm{ReLU}\bigl(W_e\,a^{(l)}(x) + b_e\bigr), \quad \mathrm{SAE}\bigl(a^{(l)}(x)\bigr) = D^\top\,\mathrm{Enc}\bigl(a^{(l)}(x)\bigr) + b_d. \]
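A minimal PyTorch sketch of this encoder/decoder, assuming an input width `d_model` and dictionary size `d_dict` (both hypothetical parameters); it is a sketch, not the authors' implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """Enc(a) = ReLU(W_e a + b_e);  SAE(a) = D^T Enc(a) + b_d."""

    def __init__(self, d_model: int, d_dict: int):
        super().__init__()
        self.enc = nn.Linear(d_model, d_dict)                                 # W_e, b_e
        self.D = nn.Parameter(torch.randn(d_dict, d_model) / d_model ** 0.5)  # dictionary
        self.b_d = nn.Parameter(torch.zeros(d_model))

    def encode(self, a: torch.Tensor) -> torch.Tensor:
        return F.relu(self.enc(a))                # sparse codes Enc(a)

    def forward(self, a: torch.Tensor):
        c = self.encode(a)
        a_hat = c @ self.D + self.b_d             # reconstruction D^T Enc(a) + b_d
        return a_hat, c
```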

Training

☁ Method 1: Baseline with \(L_{\mathrm{local}}\)

Train \(\mathrm{SAE}_{\mathrm{local}}\) by minimizing:

\[ L_{\mathrm{local}} = L_{\mathrm{reconstruction}} + \phi\,L_{\mathrm{sparsity}} = \|a^{(l)}(x) - \mathrm{SAE}_{\mathrm{local}}(a^{(l)}(x))\|_2^2 + \phi\,\|\mathrm{Enc}(a^{(l)}(x))\|_1. \]
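Using the hypothetical `SparseAutoencoder` sketch above, this objective could be written as follows (the default value of \(\phi\) is a placeholder, not the paper's setting):

```python
import torch

def local_loss(sae, a_l: torch.Tensor, phi: float = 1e-3) -> torch.Tensor:
    """L_local = ||a^(l) - SAE(a^(l))||_2^2 + phi * ||Enc(a^(l))||_1."""
    a_hat, c = sae(a_l)
    reconstruction = (a_l - a_hat).pow(2).sum(dim=-1).mean()  # squared L2 error
    sparsity = c.abs().sum(dim=-1).mean()                     # L1 penalty on codes
    return reconstruction + phi * sparsity
```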

☁ Method 2: E2E SAE with \(L_{\mathrm{e2e}}\)

Feed the reconstructed activation through the remaining layers and compare the resulting logits with the original ones:

\[ \begin{aligned} \hat a^{(l)}(x) &= \mathrm{SAE}_{\mathrm{e2e}}\bigl(a^{(l)}(x)\bigr),\\ \hat a^{(k)}(x) &= f^{(k)}\bigl(\hat a^{(k-1)}(x)\bigr)\quad (k = l+1,\dots,L-1),\\ \hat y &= \mathrm{softmax}\bigl(f^{(L)}(\hat a^{(L-1)}(x))\bigr). \end{aligned} \]

Then minimize \(L_{\mathrm{e2e}} = D_{\mathrm{KL}}(y\,\|\,\hat y) + \phi\,\|\mathrm{Enc}(a^{(l)}(x))\|_1\): the KL divergence between the original and SAE-path logits, plus the same sparsity penalty as before.
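A sketch of this objective, reusing the hypothetical `blocks`/`unembed` forward pass and `SparseAutoencoder` from above; hyperparameter defaults are placeholders:

```python
import torch
import torch.nn.functional as F

def e2e_loss(sae, blocks, unembed, a_l, layer_l: int, phi: float = 1e-3):
    """L_e2e = D_KL(y || y_hat) + phi * ||Enc(a^(l))||_1."""
    # Clean path: continue from a^(l) without the SAE to get the target logits.
    with torch.no_grad():
        a = a_l
        for block in blocks[layer_l:]:            # f^(l+1), ..., f^(L-1)
            a = block(a)
        log_y = torch.log_softmax(unembed(a), dim=-1)

    # SAE path: replace a^(l) with its reconstruction, then run the same layers.
    a_hat, c = sae(a_l)
    for block in blocks[layer_l:]:
        a_hat = block(a_hat)
    log_y_hat = torch.log_softmax(unembed(a_hat), dim=-1)

    kl = F.kl_div(log_y_hat, log_y, log_target=True, reduction="batchmean")
    return kl + phi * c.abs().sum(dim=-1).mean()
```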

☁ Method 3: E2E SAE with \(L_{\mathrm{e2e+downstream}}\)

Adds a downstream reconstruction penalty on all later layers:

\[ L_{\mathrm{e2e+downstream}} = L_{\mathrm{KL}} + \phi\,\|\mathrm{Enc}(a^{(l)}(x))\|_1 + \frac{\beta_l}{L - l}\sum_{k=l+1}^{L-1} \bigl\|\hat a^{(k)}(x) - a^{(k)}(x)\bigr\|_2^2. \]
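The same sketch extended with the downstream term, using the activation list `acts` returned by the earlier forward-pass sketch; again, the \(\phi\) and \(\beta_l\) defaults are placeholders:

```python
import torch
import torch.nn.functional as F

def e2e_downstream_loss(sae, blocks, unembed, acts, layer_l: int,
                        phi: float = 1e-3, beta_l: float = 1.0):
    """L_e2e+downstream = L_KL + phi * ||Enc(a^(l))||_1
                          + beta_l / (L - l) * sum_k ||a_hat^(k) - a^(k)||_2^2."""
    L = len(blocks) + 1                        # number of layers f^(1), ..., f^(L)
    a_l = acts[layer_l]                        # a^(l)(x) from the clean forward pass

    with torch.no_grad():                      # clean logits as the KL target
        log_y = torch.log_softmax(unembed(acts[-1]), dim=-1)

    a_hat, c = sae(a_l)
    downstream = 0.0
    for k, block in enumerate(blocks[layer_l:], start=layer_l + 1):
        a_hat = block(a_hat)                   # \hat a^(k)(x)
        downstream = downstream + (a_hat - acts[k].detach()).pow(2).sum(dim=-1).mean()
    log_y_hat = torch.log_softmax(unembed(a_hat), dim=-1)

    kl = F.kl_div(log_y_hat, log_y, log_target=True, reduction="batchmean")
    sparsity = c.abs().sum(dim=-1).mean()
    return kl + phi * sparsity + beta_l / (L - layer_l) * downstream
```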

Findings

Finding 1: E2E and E2E+downstream SAEs need far fewer features to reach a comparable cross-entropy loss, i.e. they sit on a better Pareto frontier of performance versus sparsity than local SAEs.

Finding 2: In later layers, E2E+downstream converges to a trade-off similar to the local SAE's but with higher efficiency, highlighting the balance between interpretability and accuracy.

Finding 3: Cosine‐similarity analysis shows bimodal clustering for E2E+downstream, indicating it captures both functional and dataset‐specific variability.

Reference: Braun, Dan, et al. "Identifying Functionally Important Features with End-to-End Sparse Dictionary Learning." arXiv:2405.12241 (2024).