OpenAI Trains Weight-Sparse Transformers to Reveal Compact, Interpretable Circuits
OpenAI enforces extreme weight sparsity in GPT-2 style transformers to recover compact, verifiable circuits that explain specific model behaviors on algorithmic Python tasks.
OpenAI researchers trained transformer language models with extreme weight sparsity so that the internal mechanisms behind specific behaviors can be isolated and interpreted as small, explicit circuits.
Making transformers weight-sparse
Most transformer language models are dense: neurons read from and write to many residual channels, and features are often stored in superposition. That design makes circuit-level analysis hard. Instead of applying sparsity post hoc, the team trains GPT-2 style decoder-only transformers with a weight-sparsity constraint enforced during training.
After each AdamW optimizer step, they keep only the largest-magnitude entries in every weight matrix and bias, including the token embeddings, and zero out the rest. An annealing schedule gradually reduces the fraction of nonzero parameters until the model reaches a target sparsity. In the most extreme experiments roughly 1 in 1,000 weights remains nonzero, and activations are also sparse, with about 1 in 4 activations nonzero at a typical node. The result is a very thin effective connectivity graph that encourages disentangled features aligned with individual residual channels.
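As a rough sketch, the post-step sparsification amounts to a magnitude top-k projection applied to every parameter tensor. The PyTorch snippet below illustrates that idea; the function name, the choice to prune each tensor independently, and the commented annealing loop are assumptions for illustration, not the paper's exact implementation.

```python
import torch

def apply_topk_sparsity(model, keep_fraction):
    """Keep only the largest-magnitude entries of each parameter; zero the rest."""
    with torch.no_grad():
        for param in model.parameters():
            k = max(1, int(keep_fraction * param.numel()))
            # Threshold at the k-th largest absolute value in this tensor.
            threshold = torch.topk(param.abs().flatten(), k).values.min()
            param.mul_((param.abs() >= threshold).to(param.dtype))

# Hypothetical training loop with an annealing schedule that lowers the
# keep fraction from 1.0 (dense) toward the target (~1e-3 nonzero weights):
#   loss.backward(); optimizer.step(); optimizer.zero_grad()
#   apply_topk_sparsity(model, keep_fraction=anneal(step))
```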
Measuring interpretability via task-specific pruning
To move beyond qualitative examples, the researchers evaluate interpretability with a suite of simple algorithmic Python next-token prediction tasks. Examples include single_double_quote, which asks the model to close a Python string with the correct quote character, and set_or_string, which requires choosing between .add and += depending on whether a variable is a set or a string.
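The prompts below are illustrative of what such tasks look like; they are written for this article rather than taken from the paper's dataset, but in each case the correct next token is fully determined by local Python syntax or by a variable's type earlier in the snippet.

```python
# single_double_quote: close the string with the quote character that opened it.
prompt_a = "x = 'hello world"     # correct next token: '
prompt_b = 'y = "hello world'     # correct next token: "

# set_or_string: continue with .add(...) or += depending on the variable's type.
prompt_c = "s = set()\nfor item in items:\n    s"   # correct continuation: .add(item)
prompt_d = "s = ''\nfor item in items:\n    s"      # correct continuation: += item
```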
For each task they search for the smallest subnetwork, or circuit, that still achieves a fixed loss threshold. Pruning is node-based: a node can be an MLP neuron at a particular layer, an attention head, or a residual channel. Pruned nodes are mean-ablated, i.e. their activation is replaced with its mean over the pretraining distribution. The search uses continuous mask parameters with a Heaviside-style gate optimized via a straight-through-estimator surrogate gradient. Circuit complexity is measured by the number of active edges between retained nodes, and the main interpretability metric is the geometric mean of edge counts across all tasks.
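The gating mechanism can be sketched as follows, assuming one learnable logit per node, a hard 0/1 Heaviside forward pass with a sigmoid straight-through surrogate for gradients, and mean ablation of pruned nodes; the class name, the 0.5 threshold, and the interface are illustrative assumptions rather than the paper's exact code.

```python
import torch

class NodeGate(torch.nn.Module):
    """Per-node pruning gate: hard 0/1 in the forward pass, soft gradients."""

    def __init__(self, num_nodes, mean_activation):
        super().__init__()
        self.logits = torch.nn.Parameter(torch.zeros(num_nodes))
        # Mean activation of each node over the pretraining distribution,
        # used as the replacement value for pruned nodes (mean ablation).
        self.register_buffer("mean_activation", mean_activation)

    def forward(self, activation):
        soft = torch.sigmoid(self.logits)            # surrogate used for gradients
        hard = (soft > 0.5).to(activation.dtype)     # Heaviside-style 0/1 gate
        gate = hard + soft - soft.detach()           # straight-through estimator
        # Kept nodes pass through; pruned nodes fall back to their mean.
        return gate * activation + (1 - gate) * self.mean_activation
```

During the circuit search, logits like these would be optimized to keep task loss under the fixed threshold while penalizing the number of open gates, yielding the smallest subnetwork that still performs the behavior.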
Concrete, reverse-engineerable circuits
On tasks like matching quote characters, the sparse models yield compact, fully interpretable circuits. In one example, an early MLP neuron acts as a quote detector, a second neuron classifies the quote type, and a later attention head attends back to the opening quote and copies its type to the closing position. The operational circuit consists of just a few residual channels, two MLP neurons in layer 0, and one attention head with a single query-key channel and a single value channel. In the paper's testing framework, that subgraph alone is both sufficient and necessary to solve the task.
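Read purely as an algorithm, the mechanism described above computes something like the plain-Python paraphrase below; this is a conceptual restatement for intuition, not the model's actual computation.

```python
def closing_quote(tokens):
    # "Quote detector" neuron: find the token that opened the string.
    opening = next(t for t in tokens if t in ("'", '"'))
    # "Quote classifier" neuron: single vs. double quote.
    is_double = opening == '"'
    # Attention head: attend back to the opening quote and copy its type
    # to the closing position.
    return '"' if is_double else "'"

assert closing_quote(list("x = 'hello")) == "'"
```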
For more complex behaviors, such as tracking a variable's type across a function, the recovered circuits are larger and only partially understood. Still, the researchers show relatively small graphs in which one attention head writes the variable name into the token at its definition and another copies the type information to a later use of the variable, yielding a compact mechanistic picture.
Sparsity improves interpretability at modest cost
At matched pretraining loss, weight-sparse models admit circuits roughly 16 times smaller than those recovered from dense baselines. This traces out a frontier between capability and interpretability: increasing sparsity makes mechanistic analysis easier while incurring some drop in raw capability. The trained models are small and inefficient from a performance perspective, but they produce clear connectivity graphs with measurable edge counts and rigorous sufficiency and necessity tests.
Why this matters
By enforcing sparsity as a training-time design choice, the work turns abstruse discussions of circuits into concrete, testable graphs that can aid safety audits, debugging workflows, and mechanistic research. Treating interpretability as a first-class constraint could make future models easier to inspect and verify, even if the approach is not yet optimized for production efficiency.
For the full details see the paper and supplementary materials linked by the authors.