E2E SAE (End-to-End Sparse Autoencoders)
Previous work on Sparse Autoencoders (SAEs) has shown that they can make the internal activations of neural networks remarkably interpretable. However, current SAEs are trained to minimize the Mean Squared Reconstruction Error (MSRE) of the activations, which implicitly assumes that every direction in the activation space matters for the network's performance. In large models, many of these features are redundant or only indirectly related to the output, so minimizing MSRE spends capacity on features that are functionally irrelevant.
The key question: how can we identify the functionally important features, i.e., the ones actually responsible for the network's behavior?
To answer this, the authors propose training methods that keep a sparsity penalty on the SAE's feature activations but replace the pure reconstruction error with the KL divergence between the model's original output distribution and the output distribution obtained when the SAE-reconstructed activations are passed through the rest of the network.
🗒️ Contribution & Findings
- Introduces a novel training objective that minimizes the KL divergence between the original model's outputs and the outputs obtained with SAE-reconstructed activations.
- Encourages the learned features to be tied directly to the model's functional behavior.
- Shows that E2E SAEs explain more of the network's performance while using fewer features, both in total and per sample.
- Bridges the gap between interpretability and performance explanation in neural networks.
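The "fewer total and per-sample features" claim refers to two sparsity metrics. The sketch below shows one plausible way to compute them from a matrix of SAE feature activations `c` (rows are samples, columns are dictionary elements); the function names are hypothetical, and in practice the "alive" count would be taken over a whole evaluation dataset rather than one array.

```python
import numpy as np

def l0_per_sample(c):
    """Average number of active (non-zero) SAE features per sample."""
    return np.count_nonzero(c, axis=-1).mean()

def n_alive_features(c):
    """Number of dictionary elements that fire on at least one sample."""
    return int(np.count_nonzero(np.any(c != 0, axis=0)))

# Two samples, four dictionary elements: sample 0 uses 2 features, sample 1 uses 1,
# and only elements 1 and 3 ever fire.
c = np.array([[0.0, 1.0, 0.0, 2.0],
              [0.0, 0.0, 0.0, 3.0]])
print(l0_per_sample(c), n_alive_features(c))
```

A method is "more efficient" in this sense if it reaches the same downstream loss with a lower value of both metrics.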
Formulations
☁ Transformer Part
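Considering a decoder-only Transformer (e.g. GPT-2) with \(L\) layers, token embedding \(W_E\) and unembedding \(W_U\), the residual stream \(h\) for input tokens \(t\) evolves as:
\[
h_0 = W_E\, t, \qquad h_k = \mathrm{Block}_k(h_{k-1}) \;\; (k = 1, \dots, L), \qquad y = \mathrm{softmax}\left(W_U\, \mathrm{LN}(h_L)\right)
\]
where each block adds its attention and MLP outputs to the residual stream and \(y\) is the model's output distribution.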
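To make the forward pass concrete, here is a toy numpy sketch under loudly stated assumptions: `ToyResidualModel` is hypothetical, attention is omitted entirely (each block is a pre-LayerNorm ReLU MLP added to the residual stream), and the weights are random. It is not GPT-2; it only mirrors the residual-stream structure and exposes a way to overwrite \(h_l\) mid-pass.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each position's vector to zero mean, unit variance (no learned gain/bias).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(z):
    # Numerically stable softmax over the vocabulary axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

class ToyResidualModel:
    """Hypothetical stand-in for a decoder-only transformer.

    Attention is omitted: each block is a pre-LN ReLU MLP whose output is added
    to the residual stream, so h_k = h_{k-1} + MLP_k(LN(h_{k-1})).
    """

    def __init__(self, vocab=16, d=8, n_layers=3, seed=0):
        rng = np.random.default_rng(seed)
        self.n_layers = n_layers
        self.W_E = rng.normal(0.0, 0.5, (vocab, d))              # token embedding
        self.W_in = rng.normal(0.0, 0.5, (n_layers, d, 4 * d))   # MLP input weights
        self.W_out = rng.normal(0.0, 0.5, (n_layers, 4 * d, d))  # MLP output weights
        self.W_U = rng.normal(0.0, 0.5, (d, vocab))              # unembedding

    def block(self, k, h):
        # Block k (1-indexed) writes its MLP output back into the residual stream.
        hidden = np.maximum(layer_norm(h) @ self.W_in[k - 1], 0.0)
        return h + hidden @ self.W_out[k - 1]

    def forward(self, tokens, replace_layer=None, replacement=None):
        """Return the residual streams [h_0, ..., h_L] and the output distribution y.

        If replace_layer=l is given, h_l is overwritten with `replacement`
        (e.g. an SAE reconstruction) before blocks l+1..L run.
        """
        h = self.W_E[tokens]                   # h_0, shape (seq_len, d)
        if replace_layer == 0:
            h = replacement
        hs = [h]
        for k in range(1, self.n_layers + 1):
            h = self.block(k, h)
            if k == replace_layer:
                h = replacement
            hs.append(h)
        y = softmax(layer_norm(h) @ self.W_U)  # (seq_len, vocab), rows sum to 1
        return hs, y

model = ToyResidualModel()
hs, y = model.forward(np.array([1, 2, 3, 4]))
```

The `replace_layer` / `replacement` arguments are the splice that the end-to-end objectives in the Training section depend on.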
☁ Encoder & SAE Part
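After extracting the residual-stream activation \(a^{(l)} = h_l\) at a chosen layer \(l\) (one SAE per layer), the SAE encodes it into non-negative feature activations \(c\) and decodes a reconstruction \(\hat a^{(l)}\):
\[
c = \mathrm{ReLU}\left(W_e\, a^{(l)} + b_e\right), \qquad \hat a^{(l)} = W_d\, c + b_d
\]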
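A minimal numpy sketch of this encoder/decoder pair follows. The class name, the random initialization, and the choice to normalize each decoder column (dictionary element) to unit length are illustrative assumptions, not details taken from these notes.

```python
import numpy as np

class SparseAutoencoder:
    """Sketch of an SAE: c = ReLU(W_e a + b_e), a_hat = W_d c + b_d (row-vector form)."""

    def __init__(self, d_model, n_dict, seed=0):
        rng = np.random.default_rng(seed)
        self.W_e = rng.normal(0.0, 0.1, (n_dict, d_model))  # encoder weights
        self.b_e = np.zeros(n_dict)                           # encoder bias
        W_d = rng.normal(0.0, 1.0, (d_model, n_dict))
        # Assumption: dictionary elements (decoder columns) are kept at unit norm.
        self.W_d = W_d / np.linalg.norm(W_d, axis=0, keepdims=True)
        self.b_d = np.zeros(d_model)                          # decoder bias

    def encode(self, a):
        # a: (batch, d_model) -> c: (batch, n_dict), non-negative feature activations
        return np.maximum(a @ self.W_e.T + self.b_e, 0.0)

    def decode(self, c):
        # c: (batch, n_dict) -> a_hat: (batch, d_model)
        return c @ self.W_d.T + self.b_d

    def __call__(self, a):
        c = self.encode(a)
        return self.decode(c), c

sae = SparseAutoencoder(d_model=8, n_dict=32)
a_hat, c = sae(np.random.default_rng(1).normal(size=(5, 8)))
```

Using an overcomplete dictionary (`n_dict` larger than `d_model`) is what lets the SAE represent more features than there are activation dimensions.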
Training
☁ Method 1: Baseline with \(L_{\mathrm{local}}\)
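Train \(\mathrm{SAE}_{\mathrm{local}}\) by minimizing the reconstruction error plus an L1 sparsity penalty on the feature activations \(c\), weighted by \(\lambda\):
\[
L_{\mathrm{local}} = \left\| a^{(l)} - \hat a^{(l)} \right\|_2^2 + \lambda \left\| c \right\|_1
\]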
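The local objective only needs the activation, its reconstruction, and the feature activations. A minimal sketch, assuming the loss is averaged over a batch of samples (the function name is hypothetical):

```python
import numpy as np

def local_sae_loss(a, a_hat, c, lam):
    """L_local = ||a - a_hat||_2^2 + lam * ||c||_1, averaged over the batch.

    a, a_hat: (batch, d_model) original and reconstructed activations
    c:        (batch, n_dict) SAE feature activations
    lam:      sparsity coefficient (lambda)
    """
    mse = np.sum((a - a_hat) ** 2, axis=-1).mean()  # squared L2 error per sample
    l1 = np.sum(np.abs(c), axis=-1).mean()           # L1 norm of the features per sample
    return mse + lam * l1

a = np.array([[1.0, 2.0]])
a_hat = np.array([[1.0, 1.0]])
c = np.array([[0.5, 0.0, 1.5]])
loss = local_sae_loss(a, a_hat, c, lam=0.1)  # ≈ 1.2 (squared error 1.0 + 0.1 * L1 norm 2.0)
```

Note that nothing in this loss looks at the model's output, which is exactly the limitation the E2E variants address.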
☁ Method 2: E2E SAE with \(L_{\mathrm{e2e}}\)
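Feed the reconstructed activation \(\hat a^{(l)}\) through the remaining layers \(l+1, \dots, L\) in place of \(a^{(l)}\) and compute the resulting output distribution:
\[
\hat h_l = \hat a^{(l)}, \qquad \hat h_k = \mathrm{Block}_k(\hat h_{k-1}) \;\; (k = l+1, \dots, L), \qquad \hat y = \mathrm{softmax}\left(W_U\, \mathrm{LN}(\hat h_L)\right)
\]
Then minimize the KL divergence \(D_{\mathrm{KL}}(y\|\hat y)\) between the original and the SAE-spliced output distributions, keeping the same L1 penalty (weight \(\lambda\)) on the feature activations \(c\):
\[
L_{\mathrm{e2e}} = D_{\mathrm{KL}}(y\|\hat y) + \lambda \left\| c \right\|_1
\]
Only the SAE parameters are trained; the model itself stays frozen.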
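A minimal numpy sketch of the end-to-end loss. Here `run_from_layer` is a hypothetical callable that runs layers \(l+1, \dots, L\) and returns logits, and `sae_encode` / `sae_decode` stand for the SAE's two halves; the small `eps` inside the logarithms is an added numerical-safety assumption.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last (vocabulary) axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kl_divergence(p, q, eps=1e-12):
    # Mean over positions of D_KL(p || q) = sum_v p_v * (log p_v - log q_v).
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1).mean()

def e2e_loss(a, run_from_layer, sae_encode, sae_decode, lam):
    """L_e2e = D_KL(y || y_hat) + lam * ||c||_1.

    a:              (batch, d_model) activations at layer l
    run_from_layer: hypothetical callable running layers l+1..L, returning logits
    sae_encode / sae_decode: the SAE's encoder and decoder
    """
    y = softmax(run_from_layer(a))                  # original output distribution
    c = sae_encode(a)                               # feature activations
    y_hat = softmax(run_from_layer(sae_decode(c)))  # output with the reconstruction spliced in
    l1 = np.sum(np.abs(c), axis=-1).mean()
    return kl_divergence(y, y_hat) + lam * l1

# Toy usage: a random "rest of the network" and an SAE that is a perfect identity map.
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 6))
loss = e2e_loss(np.abs(rng.normal(size=(3, 4))), lambda h: np.tanh(h) @ W,
                lambda x: np.maximum(x, 0.0), lambda c: c, lam=0.0)
```

In an autodiff framework the same computation would let gradients flow through the frozen downstream layers into the SAE parameters, so the SAE is rewarded only for preserving what changes the output.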
☁ Method 3: E2E SAE with \(L_{\mathrm{e2e+downstream}}\)
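Adds a downstream reconstruction penalty on all later layers, averaged over the \(L - l\) downstream layers and weighted by \(\beta\):
\[
L_{\mathrm{e2e+downstream}} = D_{\mathrm{KL}}(y\|\hat y) + \lambda \left\| c \right\|_1 + \frac{\beta}{L - l} \sum_{k=l+1}^{L} \left\| h_k - \hat h_k \right\|_2^2
\]
where \(h_k\) and \(\hat h_k\) are the residual streams of the original and the SAE-spliced forward passes.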
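The extra term only needs the downstream residual streams from the two forward passes. A minimal sketch, assuming both passes have already been run and their layer \(l+1, \dots, L\) activations collected into lists (the function name and argument layout are hypothetical):

```python
import numpy as np

def kl_divergence(p, q, eps=1e-12):
    # Mean over positions of D_KL(p || q).
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1).mean()

def e2e_downstream_loss(hs, hs_hat, y, y_hat, c, lam, beta):
    """KL + lam * ||c||_1 + beta / (L - l) * sum_k ||h_k - h_hat_k||_2^2.

    hs, hs_hat: lists [h_{l+1}, ..., h_L] from the original and SAE-spliced
                passes, each of shape (batch, d_model)
    y, y_hat:   (batch, vocab) output distributions of the two passes
    c:          (batch, n_dict) SAE feature activations
    """
    kl = kl_divergence(y, y_hat)
    l1 = np.sum(np.abs(c), axis=-1).mean()
    n_downstream = len(hs)  # equals L - l
    downstream = sum(
        np.sum((h - h_hat) ** 2, axis=-1).mean() for h, h_hat in zip(hs, hs_hat)
    ) / n_downstream
    return kl + lam * l1 + beta * downstream
```

Averaging over the downstream layers keeps \(\beta\) comparable across SAEs placed at different depths, since earlier layers have more downstream terms.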
Findings
Finding 1: E2E and E2E+downstream SAEs require far fewer features than local SAEs to reach a comparable cross-entropy loss, i.e., they lie on a better (Pareto-efficient) sparsity vs. loss frontier.
Finding 2: In later layers, E2E+downstream converges to a trade-off similar to that of the local SAE, but with higher efficiency, which highlights the balance between interpretability and accuracy.
Finding 3: A cosine-similarity analysis shows bimodal clustering for E2E+downstream, suggesting that it captures both functional and dataset-specific variability.
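These notes do not spell out how the cosine similarities are computed. One common approach, assumed here, is to compare decoder dictionary directions and look at the distribution of each element's maximum similarity to another dictionary (or to the other elements of its own dictionary). The function below is a hypothetical sketch of that computation only; it reproduces no result.

```python
import numpy as np

def max_cosine_similarities(D_a, D_b=None):
    """For each dictionary element (column) of D_a, its max cosine similarity
    to any column of D_b. If D_b is None, compare D_a against itself and
    ignore each element's similarity with itself.

    D_a: (d_model, n_a), D_b: (d_model, n_b)
    """
    A = D_a / np.linalg.norm(D_a, axis=0, keepdims=True)  # unit-norm columns
    same = D_b is None
    B = A if same else D_b / np.linalg.norm(D_b, axis=0, keepdims=True)
    sims = A.T @ B                                          # (n_a, n_b) pairwise cosines
    if same:
        np.fill_diagonal(sims, -np.inf)                     # exclude self-matches
    return sims.max(axis=1)

D = np.array([[1.0, 0.0], [0.0, 1.0]])
E = np.array([[1.0, 1.0], [0.0, 1.0]])
sims = max_cosine_similarities(D, E)
```

A histogram of such values (for example via `np.histogram`) is where a bimodal shape would become visible.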