MiniMax-AI/MSA
Language: Python
License: MIT
Stars: 153
Forks: 11
Open issues: 1
Created: 2026-06-11T09:20:28Z
Pushed: 2026-06-11T09:25:53Z
Default branch: main
Fork: no
Archived: no
README:
MiniMax Sparse Attention (MSA)
MSA (fmha_sm100) ships dense FlashAttention and sparse top-k attention kernels for NVIDIA SM100. Two JIT-compiled stacks share one Python package:

> Algorithm reference: [MiniMax Sparse Attention paper](docs/MiniMaxSparseAttention.pdf).
| Stack | Path | What it gives you |
|---|---|---|
| csrc JIT | python/fmha_sm100/csrc/ | Dense FMHA (fmha_sm100, fmha_sm100_plan) + sparse_topk_select indexer, compiled from Jinja templates by jit.py at runtime. |
| CuTe-DSL | python/fmha_sm100/cute/ | Full sparse attention (forward + paged FP8 decode, BF16 / FP8 / NVFP4 / FP4), compiled at runtime via cute.compile. |
| Bridge | python/fmha_sm100/sparse_fmha_adapter.py | Adapts the fmha_sm100 API to call sparse_atten_func for sparse prefill paths. |
> License: MIT. Self-authored files carry SPDX-License-Identifier: MIT.
> See [LICENSE](LICENSE) and [NOTICE](NOTICE). Bundled / derived third-party
> code retains its own license; see [Third-party licenses](#third-party-licenses).
Requirements
- GPU: NVIDIA SM100.
- Toolchain: CUDA Toolkit with `nvcc` on `PATH` (or `CUDA_HOME` / `CUDA_PATH` set).
- Python: ≥ 3.10.
- OS: Linux x86_64 (aarch64 untested; JIT builds may need small Makefile edits on WSL).
Quick sanity check before installing:
    nvcc --version                                                   # expect ≥ 12.x
    nvidia-smi --query-gpu=compute_cap --format=csv | grep "10.0"    # confirm SM100
    python -c "import sys; print(sys.version_info[:2])"              # ≥ (3, 10)
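The toolchain and Python requirements above can also be checked from Python before installing. The following is a minimal sketch using only the standard library; `check_prereqs` is a hypothetical helper, not part of the package, and it does not check the GPU compute capability (use the `nvidia-smi` command for that).

```python
import os
import shutil
import sys


def check_prereqs(env=None, which=shutil.which, version=None):
    """Return a list of problems; an empty list means the basics look fine.

    Hypothetical helper: it mirrors the README requirements (Python >= 3.10,
    nvcc on PATH or CUDA_HOME / CUDA_PATH set) and nothing more.
    """
    env = os.environ if env is None else env
    version = sys.version_info[:2] if version is None else tuple(version)
    problems = []
    if version < (3, 10):
        problems.append(f"Python {version[0]}.{version[1]} found, need >= 3.10")
    has_nvcc = which("nvcc") is not None
    has_cuda_var = any(env.get(name) for name in ("CUDA_HOME", "CUDA_PATH"))
    if not (has_nvcc or has_cuda_var):
        problems.append("nvcc not on PATH and neither CUDA_HOME nor CUDA_PATH is set")
    return problems
```

Calling `check_prereqs()` with no arguments inspects the current interpreter and environment; the `env`, `which`, and `version` parameters exist only so the check can be exercised without a CUDA machine.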
Install
    # --recursive pulls the NVIDIA CUTLASS submodule (python/fmha_sm100/cutlass/),
    # whose headers are required for JIT/AOT compilation.
    git clone --recursive https://github.com/MiniMax-AI/MSA.git msa
    cd msa
    # If you cloned without --recursive:
    #   git submodule update --init --recursive
    pip install .       # standard install (works from a wheel too)
    # or
    pip install -e .    # editable install for development
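A clone made without `--recursive` leaves the CUTLASS submodule directory empty, which typically only surfaces later as a compile error. As a rough sketch, the check below looks for the two header directories that the Layout section lists for the submodule; `missing_cutlass_dirs` is a hypothetical helper and the directory list is an assumption drawn from that section.

```python
from pathlib import Path

# Paths as described in this README's Layout section (assumed complete enough
# for a presence check; the build may need more than these two directories).
CUTLASS_ROOT = Path("python/fmha_sm100/cutlass")
REQUIRED_SUBDIRS = ("include", "tools/util/include")


def missing_cutlass_dirs(repo_root):
    """Return the CUTLASS header directories that are absent under repo_root."""
    root = Path(repo_root) / CUTLASS_ROOT
    return [
        (CUTLASS_ROOT / sub).as_posix()
        for sub in REQUIRED_SUBDIRS
        if not (root / sub).is_dir()
    ]
```

If the returned list is non-empty, running `git submodule update --init --recursive` from the repository root should populate the missing directories.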
This pulls in the CuTe-DSL stack via nvidia-cutlass-dsl and quack-kernels; the csrc kernels are JIT-compiled at first import from sources shipped inside the package.
Verify
Run a small CUDA smoke test. The first run JIT-compiles `sparse_topk_select`, which can take from 30 s to a few minutes on a cold nvcc cache; this is normal, not a hang. Subsequent runs hit the JIT cache and finish in seconds.
    python tests/smoke/test_sparse_topk_forced.py
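The cold-versus-warm behaviour comes from caching compiled artifacts. The sketch below shows one common way such a cache can work: render the kernel source, hash it, and rebuild only when no artifact with that hash exists. This is an illustration of the general technique, not the logic in jit.py. It uses `string.Template` as a stand-in for Jinja, and `render_source`, `cached_build`, and `fake_build` are hypothetical names.

```python
import hashlib
from pathlib import Path
from string import Template

# Hypothetical kernel template. The real project renders Jinja templates;
# string.Template is used here only to keep the sketch stdlib-only.
KERNEL_TEMPLATE = Template(
    "// head_dim=$head_dim dtype=$dtype\n"
    "__global__ void kernel_${dtype}_${head_dim}() {}\n"
)


def render_source(head_dim, dtype):
    """Specialize the template for one kernel configuration."""
    return KERNEL_TEMPLATE.substitute(head_dim=head_dim, dtype=dtype)


def cached_build(source, cache_dir, build):
    """Return (artifact_path, was_cache_hit), building only on a miss."""
    key = hashlib.sha256(source.encode()).hexdigest()[:16]
    artifact = Path(cache_dir) / f"{key}.bin"
    if artifact.exists():
        return artifact, True            # warm path: nothing to compile
    artifact.parent.mkdir(parents=True, exist_ok=True)
    artifact.write_bytes(build(source))  # cold path: the slow step
    return artifact, False


def fake_build(source):
    """Stand-in for invoking the real compiler; just echoes the source."""
    return source.encode()
```

Because the cache key is derived from the rendered source, changing any template parameter produces a new key and triggers a fresh build, while repeated runs with the same configuration reuse the existing artifact.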
Usage
    import torch
    from fmha_sm100 import fmha_sm100, fmha_sm100_plan, sparse_topk_select

    # Page size and top-k for the sparse prefill path.
    page_size, topk = 128, 16

    # Dense proxy pass: compute per-block max score from a cheap Q slice.
    proxy_plan = fmha_sm100_plan(
        qo_lens, kv_lens, proxy_q.shape[1],
        num_kv_heads=1, page_size=page_size, output_maxscore=True,
    )
    _, max_score = fmha_sm100(
        proxy_q, proxy_k_pages, proxy_v_pages, proxy_plan,
        kv_indices=kv_indices, output_o=False, output_maxscore=True,
    )

    # max_score -> sparse KV block indexes.
    kv_block_indexes = sparse_topk_select(
        max_score.contiguous(), topk, num_valid_pages=num_pages,
    )

    # Sparse attention with the selected blocks.
    sparse_plan = fmha_sm100_plan(
        qo_lens, kv_lens, q.shape[1],
        num_kv_heads=k_pages.shape[1], page_size=page_size, kv_block_num=topk,
    )
    out, _ = fmha_sm100(
        q, k_pages, v_pages, sparse_plan,
        kv_indices=kv_indices, kv_block_indexes=kv_block_indexes,
    )
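To make the two-stage idea concrete, here is a small CPU-only sketch of the selection step for a single query vector: compute the maximum dot-product score per KV page, then keep the `topk` highest-scoring valid pages. It is a conceptual reference only. The function names are hypothetical, and the tensor layout, score scaling, index ordering, and padding used by the real kernels are not specified here and may differ.

```python
import heapq


def block_max_scores(q, keys, page_size):
    """Max dot-product score of one query vector against each KV page.

    q is a list of floats; keys is a list of key vectors in sequence order.
    """
    scores = []
    for start in range(0, len(keys), page_size):
        page = keys[start:start + page_size]
        scores.append(max(sum(a * b for a, b in zip(q, k)) for k in page))
    return scores


def topk_pages(max_score, topk, num_valid_pages):
    """Indices of the topk highest-scoring pages among the valid ones.

    Returned in ascending index order (an assumption made for readability).
    """
    valid = range(min(num_valid_pages, len(max_score)))
    best = heapq.nlargest(topk, valid, key=max_score.__getitem__)
    return sorted(best)


# Toy data: 8 keys of dimension 2, grouped into pages of 2 keys each.
q = [1.0, 0.0]
keys = [[0.1, 0], [0.2, 0], [0.9, 0], [0.0, 0],
        [0.5, 0], [0.4, 0], [0.3, 0], [0.8, 0]]
scores = block_max_scores(q, keys, page_size=2)
selected = topk_pages(scores, topk=2, num_valid_pages=4)
```

Attention is then computed only over the keys and values in the selected pages, which is what makes the second pass cheaper than a dense one.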
For block-sparse prefill with CSR metadata, the FP4 indexer, NVFP4 K/V, and the paged FP8 decode wrapper, see the CuTe-DSL deep dive:
- [python/fmha_sm100/cute/README.md](python/fmha_sm100/cute/README.md)
Test
    # Fast smoke tests.
    python -m pytest tests/smoke -q
    # API and end-to-end integration tests.
    python -m pytest tests/integration -q
    python tests/integration/test_proxy_kv_e2e.py
    # Large regression suites.
    python tests/regression/test_correctness.py
    python tests/regression/test_sparse_attn.py
    # CuTe-DSL forward-only sparse attention.
    cd python/fmha_sm100/cute
    python -m pytest test_sparse_atten.py -q
Benchmark
benchmarks/bench_sparse_attention_ops.py covers dense prefill, paged prefill, sparse prefill, dense decode, paged decode, sparse decode, in fp8 and bf16 (nvfp4 is sparse-prefill only).
    python benchmarks/bench_sparse_attention_ops.py --help    # full flag list
Common invocations (output is TSV):
| Goal | Command |
|---|---|
| FP8 full sweep | python benchmarks/bench_sparse_attention_ops.py --dtype fp8 --sections all --output_mode o -o /tmp/msa_fp8.tsv |
| BF16 full sweep | python benchmarks/bench_sparse_attention_ops.py --dtype bf16 --sections all --output_mode o -o /tmp/msa_bf16.tsv |
| NVFP4 sparse prefill | python benchmarks/bench_sparse_attention_ops.py --dtype nvfp4 --sections sparse_prefill --output_mode o -o /tmp/msa_nvfp4.tsv |
| Quick CI smoke | python benchmarks/bench_sparse_attention_ops.py --dtype fp8 --sections prefill,decode,sparse_decode --seqs 8192,16384 --tp 1,4 --decode-k 8192,131072 --decode-b 32 --dry-run-ms 50 --repeat-ms 200 -o /tmp/msa_smoke.tsv |
| Output-mode checks (dense/paged) | --output_mode maxscore or --output_mode full |
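Since the benchmark writes TSV, its output can be loaded with the standard `csv` module. The sketch below assumes the file starts with a header row; the column names and values in `SAMPLE` are placeholders invented for illustration and are not the benchmark's actual schema or results.

```python
import csv
import io


def load_tsv(stream):
    """Parse tab-separated text into a list of dicts keyed by the header row."""
    return list(csv.DictReader(stream, delimiter="\t"))


# Placeholder data with hypothetical column names, standing in for a file
# such as /tmp/msa_fp8.tsv. Inspect the real header before relying on names.
SAMPLE = (
    "section\tseq_len\tlatency_ms\n"
    "sparse_prefill\t8192\t1.25\n"
    "sparse_prefill\t16384\t2.50\n"
)

rows = load_tsv(io.StringIO(SAMPLE))
slowest = max(rows, key=lambda r: float(r["latency_ms"]))
```

For a real run, replace the `io.StringIO(SAMPLE)` argument with an open file handle, for example `open("/tmp/msa_fp8.tsv", newline="")`, and adjust the column names to match the header that the script emits.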
Layout
    python/fmha_sm100/            Python package
      __init__.py                 Public re-exports (lazy for the CuTe-DSL stack)
      api.py                      fmha_sm100 / fmha_sm100_plan / sparse_topk_select
      jit.py                      Runtime JIT (nvcc + ninja) for the csrc stack
      sparse.py                   Lazy shim that loads the cute/ stack
      sparse_fmha_adapter.py      Bridge: fmha_sm100 API → sparse_atten_func
      csrc/                       CUDA kernels + Jinja templates (JIT-compiled)
        include/                  Vendored FlashInfer / CUTLASS-derived / TRT-LLM headers
      cutlass/                    NVIDIA CUTLASS git submodule (include/ + tools/util/include/)
      cute/                       CuTe-DSL sparse attention (loaded via sys.path)
    tests/                        Correctness tests
      smoke/ integration/ regression/
    scripts/                      Warmup + cache-management helpers
    benchmarks/                   bench_sparse_attention_ops.py
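The layout mentions lazy re-exports for the CuTe-DSL stack, which keeps a heavy dependency from loading until it is first used. One common way to get that behaviour in Python is a module-level `__getattr__` (PEP 562). The sketch below demonstrates the pattern on a throwaway module; it is an assumption about the general approach, not a copy of what `__init__.py` or `sparse.py` actually do, and `make_lazy_module` is a hypothetical helper.

```python
import importlib
import types


def make_lazy_module(name, lazy_attrs):
    """Build a module whose listed attributes are imported on first access.

    lazy_attrs maps an attribute name to (module path, attribute name).
    """
    module = types.ModuleType(name)

    def __getattr__(attr):
        # Called only when normal attribute lookup on the module fails.
        if attr not in lazy_attrs:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        target_module, target_attr = lazy_attrs[attr]
        value = getattr(importlib.import_module(target_module), target_attr)
        setattr(module, attr, value)  # cache so later lookups skip this hook
        return value

    module.__getattr__ = __getattr__
    return module


# Stand-in example: json.dumps plays the role of an expensive export.
pkg = make_lazy_module("demo_pkg", {"dumps": ("json", "dumps")})
```

In a real package the same hook would be written directly as a top-level `def __getattr__(name)` in `__init__.py`, so that importing the package stays cheap for users who only need the csrc stack.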
Stacks
- csrc JIT: dense FlashAttention, paged KV, and the sparse_topk_select indexer. Compiled at runtime from csrc/*.cu.jinja plus csrc/include/. Public entry: fmha_sm100.plan → run.
- CuTe-DSL: block-sparse prefill, FP8 / NVFP4 / FP4 quantization, paged FP8 decode (SparseDecodePagedAttentionWrapper), FP4 block-score indexer. Public entry: fmha_sm100.sparse_atten_func,...
Notability
Notability score: 6.0/10. New repo by MiniMax, moderate traction (153 stars).