novitalabs/DeepGEMM_swap
forked from Wangzheee/DeepGEMM
Description: DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling
License: MIT
Stars: 0
Forks: 0
Open issues: 0
Created: 2025-10-27T08:42:47Z
Pushed: 2025-09-15T13:33:04Z
Default branch: main
Fork: yes
Parent repository: Wangzheee/DeepGEMM
Archived: no
README:
DeepGEMM
DeepGEMM is a library designed for clean and efficient FP8 General Matrix Multiplications (GEMMs) with fine-grained scaling, as proposed in DeepSeek-V3. It supports both normal and Mixture-of-Experts (MoE) grouped GEMMs. The library is written in CUDA and requires no compilation during installation: all kernels are compiled at runtime by a lightweight Just-In-Time (JIT) module.
Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To compensate for the imprecise accumulation of FP8 tensor cores, it employs two-level accumulation (promotion) on CUDA cores. While it leverages some concepts from CUTLASS and CuTe, it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only one core kernel function. This makes it a clean and accessible resource for learning Hopper FP8 matrix multiplication and optimization techniques.
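The promotion idea can be illustrated with a small pure-Python simulation. This is a sketch, not DeepGEMM's kernel code: the narrow accumulator is emulated by rounding to a deliberately small mantissa width (8 bits, chosen only to make the effect visible, not the hardware's actual width), and the promotion interval of 128 elements is an assumption of the sketch.

```python
import math

def round_to_mantissa_bits(x: float, bits: int) -> float:
    """Round x to `bits` mantissa bits, emulating a narrow accumulator."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)  # x == m * 2**e with 0.5 <= |m| < 1
    scale = 1 << bits
    return math.ldexp(round(m * scale) / scale, e)

def accumulate_narrow(values, bits):
    """Accumulate every value into a single narrow-precision register."""
    acc = 0.0
    for v in values:
        acc = round_to_mantissa_bits(acc + v, bits)
    return acc

def accumulate_promoted(values, bits, interval=128):
    """Accumulate `interval` values in narrow precision, then promote the
    partial sum into a full-precision accumulator (a Python float here)."""
    total = 0.0
    for start in range(0, len(values), interval):
        partial = 0.0
        for v in values[start:start + interval]:
            partial = round_to_mantissa_bits(partial + v, bits)
        total += partial  # promotion: add the partial sum at full precision
    return total

values = [1.0] * 4096
print(accumulate_narrow(values, bits=8))    # 256.0: adding 1.0 to 256 rounds back to 256
print(accumulate_promoted(values, bits=8))  # 4096.0
```

Once the single narrow register grows large, each new addend becomes too small relative to it and is rounded away. Flushing short partial sums into a wider accumulator bounds the magnitude that the narrow register ever has to hold.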
Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes.
News
- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See #95 for details.
- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See #94 for details. Please set DG_JIT_USE_NVRTC=1 to enable it (this may reduce performance in some cases).
- 2025.04.18: DeepGEMM now achieves up to 1550 TFLOPS on H800! See #74, #78, #81, #86 and 340d988 for details.
Roadmap
- [x] More correctness tests for grouped-contiguous layout
- [x] Shared memory swizzling for output
- [ ] Larger block size on N (up to 256)
- [x] MoE scheduler with TMA multicast compatibility
- [x] Fix TMA multicast compatibility for indivisible shapes
- [x] Skip useless computation on M
- [x] NVRTC as a faster compiler
- [ ] Stolen JIT cache
- [ ] Sanitizer for testing
- [x] Weight gradient kernels for dense models
- [x] Weight gradient kernels for MoE models
- [ ] Better get_best_configs modeling
- [ ] Utility kernels for MoE models (maybe with tile-lang)
- [ ] CUDA PDL support
- [ ] More scaling granularity support via templates
- [ ] Larger TMA multicast size for some shapes
- [x] MMA template refactor with CUTLASS
- [ ] Optimizations for power efficiency
- [x] Remove shape limitations on N and K
- [ ] BF16 kernels
- [ ] Split/stream-k optimizations
Quick start
Requirements
- Hopper architecture GPUs, sm_90a must be supported
- Python 3.8 or above
- CUDA 12.3 or above
  - But we highly recommend 12.8 or above for the best performance
- PyTorch 2.1 or above
- CUTLASS 3.6 or above (can be cloned as a Git submodule)
Development
# Submodule must be cloned
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git

# Make symbolic links for third-party (CUTLASS and CuTe) include directories
python setup.py develop

# Test JIT compilation
python tests/test_jit.py

# Test all GEMM implementations (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py
Installation
python setup.py install
Then, import deep_gemm in your Python project, and enjoy!
Interfaces
Notices
This library exclusively contains GEMM kernels. It requires the LHS scaling factor to be TMA-aligned and transposed, and it only supports the NT format (non-transposed LHS and transposed RHS). For transposition or other FP8 casting operations, please implement them independently or fuse them into prior kernels. The library provides some simple PyTorch utility functions, but these may be slow; our primary focus is on optimizing the GEMM kernels themselves.
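Since FP8 casting is left to the caller, the following pure-Python sketch shows what a per-token-group cast might compute. It assumes the 1×128 activation granularity described in DeepSeek-V3 (one scale per row per 128 channels) and the E4M3 maximum finite value of 448. The function name and the `eps` floor are hypothetical choices of this sketch, and rounding to real FP8 bit patterns is omitted.

```python
FP8_E4M3_MAX = 448.0  # largest finite value of the FP8 E4M3 format

def per_token_group_cast(row, group_size=128, eps=1e-4):
    """Hypothetical helper: scale one row so each `group_size` slice fits E4M3.

    Returns (scaled_values, scales) with one scale per group, so that
    scaled_values[i] * scales[i // group_size] approximately equals row[i].
    Rounding to actual FP8 bit patterns is omitted in this sketch.
    """
    if len(row) % group_size != 0:
        raise ValueError("row length must be a multiple of group_size")
    scaled, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        amax = max(max(abs(v) for v in group), eps)  # floor avoids divide-by-zero
        scale = amax / FP8_E4M3_MAX
        scales.append(scale)
        scaled.extend(v / scale for v in group)
    return scaled, scales

values, scales = per_token_group_cast([float(i) for i in range(256)])
print(len(scales))  # 2 scales for 256 channels
```

In a real pipeline this step would run on the GPU, ideally fused into the kernel that produces the activations, as the notice above recommends.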
Normal dense GEMMs (non-grouped)
To perform a basic non-grouped FP8 GEMM, call the deep_gemm.gemm_fp8_fp8_bf16_nt function. For more details, please refer to the function documentation.
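To make the NT convention and fine-grained scaling concrete, here is a slow pure-Python reference of what such a GEMM computes. It is a sketch under assumptions, not the library's API: it assumes per-row, per-K-block LHS scales and per-(N-block, K-block) RHS scales as in DeepSeek-V3, and it ignores the TMA-aligned, transposed LHS scale layout and BF16 output that the real kernel uses. The function name is hypothetical.

```python
def gemm_nt_blockscaled_reference(a, a_scales, b, b_scales, block=128):
    """Hypothetical reference for D = (A * sa) @ (B * sb)^T in NT layout.

    a:        m x k values (LHS, row-major, "non-transposed")
    a_scales: m x ceil(k / block) scales, one per row and K-block
    b:        n x k values (RHS stored as n x k, i.e. "transposed")
    b_scales: ceil(n / block) x ceil(k / block) scales, one per tile
    """
    m, n, k = len(a), len(b), len(a[0])
    k_blocks = (k + block - 1) // block
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for kb in range(k_blocks):
                lo, hi = kb * block, min((kb + 1) * block, k)
                # Unscaled partial dot product over one K-block...
                partial = sum(a[i][t] * b[j][t] for t in range(lo, hi))
                # ...then both scales are applied once per block.
                acc += partial * a_scales[i][kb] * b_scales[j // block][kb]
            out[i][j] = acc
    return out

# Tiny example with block=2 so the scales are easy to follow.
print(gemm_nt_blockscaled_reference(
    [[1, 2, 3, 4]], [[1.0, 2.0]],
    [[1, 1, 1, 1], [0, 1, 0, 1]], [[1.0, 0.5]], block=2))  # [[10.0, 6.0]]
```

Applying the scales once per K-block, rather than per element, matches the per-block promotion described in the introduction.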
Grouped GEMMs (contiguous layout)
Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape.
For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (get_m_alignment_for_contiguous_layout()).
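The padding arithmetic behind the contiguous layout can be sketched as follows. The alignment would come from get_m_alignment_for_contiguous_layout() in practice; here it is a plain parameter. The helper name and the convention of marking padding rows with -1 are hypothetical. How the real function receives per-row group information is described in its documentation.

```python
def build_contiguous_layout(tokens_per_expert, alignment=128):
    """Hypothetical planner for a grouped-contiguous layout on the M axis.

    Each expert's segment is padded up to a multiple of `alignment`.
    Returns (segment_starts, total_m, row_to_expert); padding rows are
    marked -1, which is a convention of this sketch only.
    """
    starts, row_to_expert = [], []
    for expert, count in enumerate(tokens_per_expert):
        aligned = (count + alignment - 1) // alignment * alignment
        starts.append(len(row_to_expert))
        row_to_expert.extend([expert] * count)
        row_to_expert.extend([-1] * (aligned - count))
    return starts, len(row_to_expert), row_to_expert

starts, total_m, rows = build_contiguous_layout([3, 0, 5], alignment=4)
print(starts, total_m)  # [0, 4, 4] 12
```

Because every segment starts on an alignment boundary, each M block of the GEMM belongs to a single expert and can use that expert's weights throughout.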
For more information, please refer to the m_grouped_gemm_fp8_fp8_bf16_nt_contiguous function documentation.
Grouped GEMMs (masked layout)
During the inference decoding phase, when CUDA graphs are enabled and the CPU does not know how many tokens each expert receives, we support masked grouped GEMMs. Given a mask tensor, the kernel computes only the valid portions.
Use m_grouped_gemm_fp8_fp8_bf16_nt_masked for this purpose and consult the relevant documentation. For example, the output of DeepEP's low-latency kernels can be used as input.
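The semantics of a masked grouped GEMM can be sketched in pure Python. Every group has the same padded shape (so a CUDA graph sees fixed shapes), and only the first masked_m[g] rows of each group are computed. The function name is hypothetical, scaling is omitted, and leaving rows past the mask as zeros is a choice of this sketch; the real kernel's handling of those rows is not specified here.

```python
def masked_grouped_gemm_nt_reference(a, b, masked_m):
    """Hypothetical reference for a masked grouped GEMM in NT layout.

    a:        G groups, each max_m x k (only the first masked_m[g] rows are valid)
    b:        G groups, each n x k (RHS stored transposed)
    masked_m: number of valid rows per group, e.g. produced on the GPU
    Rows at or beyond masked_m[g] are skipped and left as zeros here.
    """
    out = []
    for g, (a_g, b_g) in enumerate(zip(a, b)):
        n = len(b_g)
        rows = [[0.0] * n for _ in a_g]
        for i in range(masked_m[g]):
            for j in range(n):
                rows[i][j] = sum(x * y for x, y in zip(a_g[i], b_g[j]))
        out.append(rows)
    return out

a = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # 2 groups, max_m = 2, k = 2
b = [[[1, 1]], [[1, 0]]]                  # 2 groups, n = 1, k = 2
print(masked_grouped_gemm_nt_reference(a, b, masked_m=[1, 2]))
```

The shapes never change between calls; only the contents of masked_m do, which is what makes this pattern compatible with graph capture.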
Utilities
The library provides some utility functions besides the above kernels:
- deep_gemm.set_num_sms: set the maximum SM count to use
- deep_gemm.get_num_sms: get the current maximum SM count
- deep_gemm.get_m_alignment_for_contiguous_layout: get the group-level alignment requirement for the grouped contiguous layout
- deep_gemm.get_tma_aligned_size: get the required TMA alignment size
- deep_gemm.get_col_major_tma_aligned_tensor: get a column-major TMA-aligned tensor
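The alignment arithmetic behind a TMA-aligned, column-major scale tensor can be sketched as below. This assumes TMA requires 16-byte-aligned strides and that the padded dimension is M; the helper names are hypothetical and are not the library's functions.

```python
def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

def tma_aligned_size(x: int, element_size: int, alignment_bytes: int = 16) -> int:
    """Hypothetical helper: round x up so x * element_size is a multiple
    of alignment_bytes (16 bytes is an assumption of this sketch)."""
    assert alignment_bytes % element_size == 0
    step = alignment_bytes // element_size
    return ceil_div(x, step) * step

def col_major_offset(row: int, col: int, leading_dim: int) -> int:
    """Flat offset of (row, col) in column-major storage whose columns
    are padded to `leading_dim` elements."""
    return col * leading_dim + row

# FP32 scales (4 bytes each) for m = 1001 rows: each column is padded to 1004.
aligned_m = tma_aligned_size(1001, element_size=4)
print(aligned_m)                          # 1004
print(col_major_offset(2, 1, aligned_m))  # 1006
```

Storing the LHS scales column by column with a padded leading dimension gives each K-block's scales a contiguous, aligned run of memory, which fits the "TMA-aligned and transposed" requirement in the notices.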
The library also provides some environment variables, which may be useful:
- General
  - DG_JIT_DEBUG: 0 or 1, print more JIT…