openai/circuit_sparsity
Description: Open-source release accompanying Gao et al. 2025
Language: Python
License: Apache-2.0
Stars: 525
Forks: 53
Open issues: 2
Created: 2025-10-30T21:06:34Z
Pushed: 2025-12-11T23:51:46Z
Default branch: main
Fork: no
Archived: no
README:
Circuit Sparsity Visualizer and Models
Tools for inspecting sparse circuit models from Gao et al. 2025. Provides code for running inference as well as a Streamlit dashboard that allows you to interact with task-specific circuits found by pruning. Note: this README was AI-generated and lightly edited.
Installation
pip install -e .
Launching the Visualizer
Start the Streamlit app from the project root:
streamlit run circuit_sparsity/viz.py
The app loads data from the openaipublic blob storage (see Data Layout) and caches it locally. Once the visualizer loads, you can choose a model, dataset, pruning sweep, and node budget k using the controls in the left column. The plots are rendered with Plotly; most elements are interactive and support hover and click exploration.
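The README does not describe how the local cache works. As a minimal sketch of the general download-once pattern, the standard-library helper below stores each URL's contents under a hash of the full URL and reuses the local copy on later calls. The fetch_cached name and its default cache directory are hypothetical and are not part of circuit_sparsity; to force a re-download of the real artifacts, use clear_cache.py.

```python
import hashlib
import shutil
import urllib.request
from pathlib import Path

# Hypothetical default; the visualizer's real cache location is not documented in the README.
DEFAULT_CACHE_DIR = Path.home() / ".cache" / "circuit_sparsity_sketch"


def fetch_cached(url: str, cache_dir: Path = DEFAULT_CACHE_DIR) -> Path:
    """Download `url` once and return the path of the local copy.

    The cache key is a SHA-256 hash of the full URL, so files that share a
    name (for example many viz_data.pkl files) never overwrite each other.
    """
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    key = hashlib.sha256(url.encode("utf-8")).hexdigest()
    local_path = cache_dir / f"{key}-{url.rsplit('/', 1)[-1]}"
    if local_path.exists():
        return local_path  # cache hit: no network access
    tmp_path = local_path.with_name(local_path.name + ".part")
    with urllib.request.urlopen(url) as response, open(tmp_path, "wb") as out:
        shutil.copyfileobj(response, out)  # stream to disk instead of buffering in memory
    tmp_path.replace(local_path)  # rename into place so readers never see a partial file
    return local_path
```

Hashing the whole URL, rather than keying on the file name alone, matters for a layout like this one, where many runs store a file with the same name.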
[Figure: Example view of the Streamlit circuit visualizer (wte/wpe tab) with node ablation deltas and activation previews.]
Running Model Forward Passes
Transformer definitions live in circuit_sparsity.inference.gpt. The module exports:
- GPTConfig / GPT: a lightweight GPT implementation suitable for CPU or GPU inference.
- load_model(model_dir, cuda=False): a convenience loader that expects the beeg_config.json and final_model.pt pair found in models/....
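Because load_model expects both files to sit together in one directory, a pre-flight check can give a clearer error than a failed load. The helper below is hypothetical (missing_model_files is not part of the package); only the two file names come from the README.

```python
from __future__ import annotations

from pathlib import Path

# File names taken from the README; the helper itself is a hypothetical sketch.
REQUIRED_MODEL_FILES = ("beeg_config.json", "final_model.pt")


def missing_model_files(model_dir: str | Path) -> list[str]:
    """Return the required files absent from `model_dir` (empty list if complete)."""
    model_dir = Path(model_dir)
    return [name for name in REQUIRED_MODEL_FILES if not (model_dir / name).is_file()]
```

For example, raise FileNotFoundError with the returned list before calling load_model whenever the list is non-empty.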
Example usage (adapted from tests/test_gpt.py):
import torch
from circuit_sparsity.inference.gpt import GPT, GPTConfig, load_model
from circuit_sparsity.inference.hook_utils import hook_recorder
from circuit_sparsity.registries import MODEL_BASE_DIR
config = GPTConfig(block_size=8, vocab_size=16, n_layer=1, n_head=1, d_model=8)
model = GPT(config)
# random token ids in [0, vocab_size=16), shape (batch=2, seq_len=block_size=8)
idx = torch.randint(0, 16, (2, 8))
targets = torch.randint(0, 16, (2, 8))
logits, loss, _ = model(idx, targets=targets)
# to get activations
with hook_recorder() as rec:
    model(idx)
# rec is a dict that looks like {"0.attn.act_in": tensor(...), ...}
# point this at a model directory under models/ (see registries.py for the list)
pretrained = load_model(f"{MODEL_BASE_DIR}/models/", cuda=False)
Run tests with:
pytest tests/test_gpt.py
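The rec dict from the hook_recorder example above maps hook names to tensors. Assuming every per-layer key starts with the layer index, as in "0.attn.act_in", the hypothetical helper below regroups the entries by layer; keys without a numeric prefix are kept under layer -1 with their full name. The key format beyond that one example is an assumption, so inspect the actual keys in rec first.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Any, Mapping


def group_by_layer(rec: Mapping[str, Any]) -> dict[int, dict[str, Any]]:
    """Regroup hook_recorder-style entries by their leading layer index."""
    grouped: dict[int, dict[str, Any]] = defaultdict(dict)
    for key, value in rec.items():
        head, _, rest = key.partition(".")
        if head.isdigit() and rest:
            # "0.attn.act_in" -> grouped[0]["attn.act_in"]
            grouped[int(head)][rest] = value
        else:
            # no numeric layer prefix: keep the full key under layer -1
            grouped[-1][key] = value
    return dict(grouped)
```

For example, iterating over sorted(group_by_layer(rec).items()) lists each layer's hook names in order.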
Data Layout
Project assets live under https://openaipublic.blob.core.windows.net/circuit-sparsity with the following structure:
- models//
  - beeg_config.json: serialized GPTConfig used to rebuild the model.
  - final_model.pt: checkpoint used by circuit_sparsity.inference.gpt.load_model.
- viz//////
  - viz_data.pkl: primary payload loaded by viz.py (contains circuit masks, activations, samples, importances, etc.).
  - Additional per-run outputs (masks, histograms, sample buckets) are stored under the same tree when produced by the preprocessing scripts.
- train_curves//
  - progress.json: training metrics consumed by the dashboard’s summary table.
- Other experiment-specific directories (for example csp_yolo1/, csp_yolo2/) hold raw artifacts produced while preparing pruning runs.
The file paths surfaced in viz.py and registries.py assume this layout. Update registries.py if you relocate the data.
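A minimal sketch of turning this layout into download URLs follows. It assumes, without confirmation from the README, that the unnamed directory under models/ is a model name such as csp_yolo1 from the Models list; check registries.py for the real paths. artifact_url is a hypothetical helper, and it rejects empty segments so a missing placeholder fails loudly instead of producing a double slash.

```python
from urllib.parse import quote

# Base URL from the README's Data Layout section.
BASE_URL = "https://openaipublic.blob.core.windows.net/circuit-sparsity"


def artifact_url(*parts: str) -> str:
    """Join path segments onto BASE_URL, percent-encoding each segment."""
    if not parts or any(not p for p in parts):
        raise ValueError(f"empty path segment in {parts!r}")
    # safe="" also encodes "/", so each argument stays a single path segment
    return "/".join([BASE_URL, *(quote(p, safe="") for p in parts)])
```

Under that assumption, artifact_url("models", "csp_yolo1", "beeg_config.json") yields the config URL for csp_yolo1.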
Models
We release all of the models used to obtain the results in the paper; see registries.py for the full list. Exact training hyperparameters can be found in [todo].
- csp_yolo1: the model used in the single_double_quote qualitative results. It has 118M total parameters. This is a somewhat older model trained with methods that differ slightly from those in the paper; in particular, it was trained with multiple L0 values at the same time.
- csp_yolo2: the model used in the bracket_counting and set_or_string_fixedvarname qualitative results. It has 475M total parameters.
- csp_sweep1_*: the models used to obtain the Figure 3 results. The name indicates the model size (as an "expansion factor" relative to an arbitrary baseline size), the weight L0, and the activation sparsity level (afrac).
- csp_bridge1: the bridge model used to obtain the results in the paper.
- csp_bridge2: another bridge model.
- dense1_1x: a dense model trained on our dataset.
- dense1_2x: a dense model trained on our dataset, 2x wider.
- dense1_4x: a dense model trained on our dataset, 4x wider.
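To make the two sparsity knobs in the csp_sweep1_* names concrete, the sketch below computes a weight L0 (the count of nonzero weights) and applies a top-magnitude mask that keeps roughly an afrac fraction of activations. Reading afrac as "fraction of activations kept nonzero" is an assumption about the naming, and top-k magnitude masking is one standard way to impose activation sparsity, not necessarily how these models were trained.

```python
from __future__ import annotations


def weight_l0(weights: list[float]) -> int:
    """L0 "norm" of a weight vector: the number of nonzero entries."""
    return sum(1 for w in weights if w != 0.0)


def keep_top_fraction(acts: list[float], afrac: float) -> list[float]:
    """Zero all but the round(afrac * n) largest-magnitude activations (at least one).

    Assumes afrac is the fraction of activations allowed to stay nonzero.
    """
    if not 0.0 < afrac <= 1.0:
        raise ValueError("afrac must be in (0, 1]")
    k = max(1, round(afrac * len(acts)))
    by_magnitude = sorted(range(len(acts)), key=lambda i: abs(acts[i]), reverse=True)
    keep = set(by_magnitude[:k])
    return [a if i in keep else 0.0 for i, a in enumerate(acts)]
```

Using round rather than ceil avoids off-by-one budgets from floating-point products such as 0.3 * 10.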
Sweep ids
- prune_v2: 256 CARBS iterations, bs=16; very old (unpublished) pruning algorithm; targets a fixed k rather than a fixed target loss.
- prune_v3: 256 CARBS iterations, bs=64, epochs=32; old (unpublished) algorithm; targets a fixed target loss.
- prune_v4: 768 CARBS iterations, bs=64, epochs=48; published algorithm; targets a fixed target loss.
- prune_v5_logitscaling: 256 CARBS iterations, bs=32, epochs=32; published algorithm with logit scaling; targets a fixed target loss.
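The difference between targeting a fixed k and targeting a fixed loss can be illustrated on a hypothetical precomputed curve that maps node budget k to the loss of the pruned circuit. This sketch shows only the two framings; it is not the pruning algorithm and says nothing about how CARBS searches hyperparameters.

```python
from __future__ import annotations


def loss_at_fixed_k(curve: dict[int, float], k: int) -> float:
    """Fixed-k framing: report the circuit loss achieved at a given node budget."""
    return curve[k]


def smallest_k_for_target(curve: dict[int, float], target: float) -> int | None:
    """Fixed-target framing: the smallest budget k whose circuit loss is <= target.

    Scans budgets in ascending order, so it does not assume the loss is
    monotone in k. Returns None if no budget reaches the target.
    """
    for k in sorted(curve):
        if curve[k] <= target:
            return k
    return None
```

Under the fixed-target framing, circuits from different tasks are compared by how many nodes they need to reach the same loss, rather than by the loss they reach at the same size.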
Additional Utilities
- per_token_viz_demo.py: minimal examples of token-level visualizations.
- clear_cache.py: deletes locally cached copies of blobstore files (the Streamlit/viz caches and the tiktoken cache); run it if you need to re-fetch fresh artifacts.
The project relies on Streamlit, Plotly, matplotlib, seaborn, and torch (see pyproject.toml for the full dependency list).