Repo · MiniMax · published Jun 11, 2026

MiniMax-AI/MSA

Python

Language: Python

License: MIT

Stars: 153

Forks: 11

Open issues: 1

Created: 2026-06-11T09:20:28Z

Pushed: 2026-06-11T09:25:53Z

Default branch: main

Fork: no

Archived: no

README:

MiniMax Sparse Attention (MSA)

MSA (fmha_sm100) ships dense FlashAttention and sparse top-k attention kernels for NVIDIA SM100. Two JIT-compiled stacks share one Python package:

![MSA architecture](docs/architecture.png)

> Algorithm reference: [MiniMax Sparse Attention paper](docs/MiniMaxSparseAttention.pdf).

| Stack | Path | What it gives you |
|---|---|---|
| csrc JIT | python/fmha_sm100/csrc/ | Dense FMHA (fmha_sm100, fmha_sm100_plan) + sparse_topk_select indexer, compiled from Jinja templates by jit.py at runtime. |
| CuTe-DSL | python/fmha_sm100/cute/ | Full sparse attention (forward + paged FP8 decode, BF16 / FP8 / NVFP4 / FP4), compiled at runtime via cute.compile. |
| Bridge | python/fmha_sm100/sparse_fmha_adapter.py | Adapts the fmha_sm100 API to call sparse_atten_func for sparse prefill paths. |

> License: MIT. Self-authored files carry SPDX-License-Identifier: MIT.
> See [LICENSE](LICENSE) and [NOTICE](NOTICE). Bundled / derived third-party
> code retains its own license; see [Third-party licenses](#third-party-licenses).

Requirements

  • GPU: NVIDIA SM100.
  • Toolchain: CUDA Toolkit with nvcc on PATH (or CUDA_HOME / CUDA_PATH set).
  • Python: ≥ 3.10.
  • OS: Linux x86_64 (aarch64 untested; JIT builds may need small Makefile edits on WSL).

Quick sanity check before installing:

nvcc --version # expect ≥ 12.x
nvidia-smi --query-gpu=compute_cap --format=csv | grep "10.0" # confirm SM100
python -c "import sys; print(sys.version_info[:2])" # ≥ (3, 10)
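
The toolchain part of this check can also be scripted. The sketch below is not part of the repository: `find_nvcc` and `python_ok` are hypothetical helper names, and the lookup order (PATH first, then CUDA_HOME / CUDA_PATH) is an assumption based on the requirements list above. It does not cover the GPU compute-capability check.

```python
import os
import shutil
import sys


def find_nvcc(env=None):
    """Return a path to nvcc, or None if it cannot be located."""
    env = os.environ if env is None else env
    # Prefer nvcc on PATH (a missing PATH key falls back to the process PATH).
    on_path = shutil.which("nvcc", path=env.get("PATH"))
    if on_path:
        return on_path
    # Otherwise look under CUDA_HOME / CUDA_PATH, assuming the usual bin/ layout.
    for var in ("CUDA_HOME", "CUDA_PATH"):
        root = env.get(var)
        if root:
            candidate = os.path.join(root, "bin", "nvcc")
            if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                return candidate
    return None


def python_ok(version_info=sys.version_info, minimum=(3, 10)):
    """True when the interpreter meets the minimum (major, minor) version."""
    return tuple(version_info[:2]) >= minimum


if __name__ == "__main__":
    print("nvcc:", find_nvcc() or "not found")
    print("python ok:", python_ok())
```

If `find_nvcc()` returns None, the JIT build cannot start, so fix PATH or CUDA_HOME before installing.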

Install

# --recursive pulls the NVIDIA CUTLASS submodule (python/fmha_sm100/cutlass/),
# whose headers are required for JIT/AOT compilation.
git clone --recursive https://github.com/MiniMax-AI/MSA.git msa
cd msa
# If you cloned without --recursive:
# git submodule update --init --recursive
pip install . # standard install (works from a wheel too)
# or
pip install -e . # editable install for development

This pulls in the CuTe-DSL stack via nvidia-cutlass-dsl and quack-kernels; the csrc kernels are JIT-compiled at first import from sources shipped inside the package.
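
To confirm that the two dependencies named above were actually installed, a small standard-library check may help. This is a sketch, not something shipped with MSA; `installed_versions` is a hypothetical helper, and the distribution names are taken as written in the paragraph above.

```python
from importlib import metadata


def installed_versions(distributions):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in distributions:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Distribution names as given in the install notes.
    found = installed_versions(["nvidia-cutlass-dsl", "quack-kernels"])
    for name, version in found.items():
        print(f"{name}: {version or 'not installed'}")
```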

Verify

Run a small CUDA smoke test. The first run JIT-compiles `sparse_topk_select`, which can take anywhere from 30 s to a few minutes on a cold cache; this is expected and does not indicate a hang. Subsequent runs hit the JIT cache and finish in seconds.

python tests/smoke/test_sparse_topk_forced.py

Usage

import torch
from fmha_sm100 import fmha_sm100, fmha_sm100_plan, sparse_topk_select

# The caller is expected to have prepared the inputs used below:
# qo_lens, kv_lens, q, k_pages, v_pages, proxy_q, proxy_k_pages,
# proxy_v_pages, kv_indices, num_pages.

# Page size and top-k for the sparse prefill path.
page_size, topk = 128, 16

# Dense proxy pass: compute per-block max score from a cheap Q slice.
proxy_plan = fmha_sm100_plan(
    qo_lens, kv_lens, proxy_q.shape[1],
    num_kv_heads=1,
    page_size=page_size,
    output_maxscore=True,
)
_, max_score = fmha_sm100(
    proxy_q, proxy_k_pages, proxy_v_pages, proxy_plan,
    kv_indices=kv_indices,
    output_o=False,
    output_maxscore=True,
)

# max_score -> sparse KV block indexes.
kv_block_indexes = sparse_topk_select(
    max_score.contiguous(), topk, num_valid_pages=num_pages,
)

# Sparse attention with the selected blocks.
sparse_plan = fmha_sm100_plan(
    qo_lens, kv_lens, q.shape[1],
    num_kv_heads=k_pages.shape[1],
    page_size=page_size,
    kv_block_num=topk,
)
out, _ = fmha_sm100(
    q, k_pages, v_pages, sparse_plan,
    kv_indices=kv_indices,
    kv_block_indexes=kv_block_indexes,
)
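
The three stages above (per-block max score, top-k block selection, attention over the selected blocks) can be illustrated with a tiny pure-Python reference. This is a conceptual sketch only and makes several assumptions: it handles one query vector and one head, reuses the full q and K as the proxy, scales logits by 1/sqrt(d), and returns selected pages in ascending order. The actual kernels may differ on every one of these points. All function names here are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def block_max_scores(q, keys, page_size):
    """Per-page maximum of q.k, one entry per KV page."""
    return [
        max(dot(q, k) for k in keys[start:start + page_size])
        for start in range(0, len(keys), page_size)
    ]


def select_topk_blocks(scores, topk):
    """Indices of the topk highest-scoring pages, in ascending page order."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:topk])


def attend(q, keys, values, page_size, blocks):
    """Softmax attention restricted to the tokens of the selected pages."""
    scale = 1.0 / math.sqrt(len(q))
    token_ids = [
        t
        for b in blocks
        for t in range(b * page_size, min((b + 1) * page_size, len(keys)))
    ]
    logits = [dot(q, keys[t]) * scale for t in token_ids]
    peak = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in logits]
    total = sum(weights)
    dim = len(values[0])
    return [
        sum(w * values[t][j] for w, t in zip(weights, token_ids)) / total
        for j in range(dim)
    ]


# Toy data: 4 pages of 2 tokens each, head dimension 2.
q = [1.0, 0.0]
keys = [[0.1, 0.0], [0.2, 0.0], [3.0, 0.0], [2.5, 0.0],
        [-1.0, 0.0], [0.0, 0.0], [1.5, 0.0], [1.0, 0.0]]
values = [[float(i), 1.0] for i in range(len(keys))]

scores = block_max_scores(q, keys, page_size=2)
blocks = select_topk_blocks(scores, topk=2)
sparse_out = attend(q, keys, values, page_size=2, blocks=blocks)
dense_out = attend(q, keys, values, page_size=2, blocks=[0, 1, 2, 3])
```

With this toy data the selection keeps pages 1 and 3, the two pages holding the largest q.k values. Setting topk to the number of pages reproduces dense attention, which is a convenient correctness check for any sparse path.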

For block-sparse prefill with CSR metadata, the FP4 indexer, NVFP4 K/V, and the paged FP8 decode wrapper, see the CuTe-DSL deep dive:

  • [python/fmha_sm100/cute/README.md](python/fmha_sm100/cute/README.md)

Test

# Fast smoke tests.
python -m pytest tests/smoke -q

# API and end-to-end integration tests.
python -m pytest tests/integration -q
python tests/integration/test_proxy_kv_e2e.py

# Large regression suites.
python tests/regression/test_correctness.py
python tests/regression/test_sparse_attn.py

# CuTe-DSL forward-only sparse attention.
cd python/fmha_sm100/cute
python -m pytest test_sparse_atten.py -q

Benchmark

benchmarks/bench_sparse_attention_ops.py covers dense prefill, paged prefill, sparse prefill, dense decode, paged decode, sparse decode, in fp8 and bf16 (nvfp4 is sparse-prefill only).

python benchmarks/bench_sparse_attention_ops.py --help # full flag list

Common invocations (output is TSV):

| Goal | Command |
|---|---|
| FP8 full sweep | python benchmarks/bench_sparse_attention_ops.py --dtype fp8 --sections all --output_mode o -o /tmp/msa_fp8.tsv |
| BF16 full sweep | python benchmarks/bench_sparse_attention_ops.py --dtype bf16 --sections all --output_mode o -o /tmp/msa_bf16.tsv |
| NVFP4 sparse prefill | python benchmarks/bench_sparse_attention_ops.py --dtype nvfp4 --sections sparse_prefill --output_mode o -o /tmp/msa_nvfp4.tsv |
| Quick CI smoke | python benchmarks/bench_sparse_attention_ops.py --dtype fp8 --sections prefill,decode,sparse_decode --seqs 8192,16384 --tp 1,4 --decode-k 8192,131072 --decode-b 32 --dry-run-ms 50 --repeat-ms 200 -o /tmp/msa_smoke.tsv |
| Output-mode checks (dense/paged) | --output_mode maxscore or --output_mode full |
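
Since the output is TSV, it can be loaded with the standard csv module. The sketch below assumes the file starts with a header row, which this README does not confirm. The column names in the sample (`case`, `latency_ms`) are made up for illustration; check the real header before relying on any column.

```python
import csv
import io


def parse_tsv(handle):
    """Parse a tab-separated text stream into (header, list of row dicts)."""
    reader = csv.DictReader(handle, delimiter="\t")
    rows = list(reader)
    return list(reader.fieldnames or []), rows


# Made-up sample; the real column names come from the benchmark script.
sample = io.StringIO("case\tlatency_ms\nexample_a\t1.25\nexample_b\t0.75\n")
header, rows = parse_tsv(sample)
fastest = min(rows, key=lambda row: float(row["latency_ms"]))
```

For a real run, pass an open file instead of the sample, for example `open("/tmp/msa_fp8.tsv", newline="")`, and inspect `header` first.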

Layout

python/fmha_sm100/              Python package
  __init__.py                   Public re-exports (lazy for the CuTe-DSL stack)
  api.py                        fmha_sm100 / fmha_sm100_plan / sparse_topk_select
  jit.py                        Runtime JIT (nvcc + ninja) for the csrc stack
  sparse.py                     Lazy shim that loads the cute/ stack
  sparse_fmha_adapter.py        Bridge: fmha_sm100 API → sparse_atten_func
  csrc/                         CUDA kernels + Jinja templates (JIT-compiled)
    include/                    Vendored FlashInfer / CUTLASS-derived / TRT-LLM headers
  cutlass/                      NVIDIA CUTLASS git submodule (include/ + tools/util/include/)
  cute/                         CuTe-DSL sparse attention (loaded via sys.path)
tests/                          Correctness tests
  smoke/ integration/ regression/
scripts/                        Warmup + cache-management helpers
benchmarks/                     bench_sparse_attention_ops.py
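
The layout mentions lazy re-exports and a lazy shim for the CuTe-DSL stack. One common way to get that behavior in Python is a module-level `__getattr__` (PEP 562). The sketch below shows the pattern, not the package's actual code: `make_lazy_module` is a hypothetical helper, and the standard-library module `colorsys` stands in for a heavy backend.

```python
import importlib
import types


def make_lazy_module(name, target, exported):
    """Build a module whose listed attributes import `target` on first access."""
    module = types.ModuleType(name)

    def __getattr__(attr):
        if attr not in exported:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        real = importlib.import_module(target)  # the import happens here
        value = getattr(real, attr)
        setattr(module, attr, value)  # cache so __getattr__ is not hit again
        return value

    # PEP 562: a __getattr__ stored in the module namespace handles
    # lookups of names that are not otherwise defined on the module.
    module.__getattr__ = __getattr__
    return module


# Stand-in demo: `colorsys` plays the role of a heavy backend package.
lazy = make_lazy_module("lazy_demo", "colorsys", {"rgb_to_hsv"})
cached_before = "rgb_to_hsv" in vars(lazy)
hsv = lazy.rgb_to_hsv(1.0, 0.0, 0.0)
cached_after = "rgb_to_hsv" in vars(lazy)
```

In a real package the same function would normally be written directly at the top level of `__init__.py`. The benefit is that importing the package stays cheap until a symbol from the heavier stack is first used.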

Stacks

  • csrc JIT: dense FlashAttention, paged KV, and the sparse_topk_select indexer. Compiled at runtime from csrc/*.cu.jinja plus csrc/include/. Public entry: fmha_sm100.plan → run.
  • CuTe-DSL: block-sparse prefill, FP8 / NVFP4 / FP4 quantization, paged FP8 decode (SparseDecodePagedAttentionWrapper), FP4 block-score indexer. Public entry: fmha_sm100.sparse_atten_func,...

Notability

notability 6.0/10

New repo by MiniMax, moderate traction (153 stars).