NVIDIA/TransformerEngine v2.14
Repository: NVIDIA/TransformerEngine
Tag: v2.14
Published: 2026-04-21T21:57:36Z
Prerelease: no
Release notes:
Transformer Engine v2.14 Release Notes
Key Features and Enhancements
- [PyTorch] Added multiple CPU overhead optimizations across the framework integration to reduce per-step Python/host overhead. (#2559) (#2724)
- [C, PyTorch] Added BF16 and MXFP8 grouped GEMM support with on-device group sizes. (#2748) (#2669)
- [PyTorch] Added a fused GEMM + SwiGLU grouped MLP for MXFP8 to accelerate MoE forward/backward. (#2769)
- [PyTorch] Added support for a single-parameter GroupedLinear configuration, where the weights of all experts are stored in a single parameter, which reduces CPU overheads. (#2731)
- [PyTorch] Added backwards-compatible checkpoint support for the new single-parameter GroupedLinear. (#2761)
- [PyTorch] Extended the fused attention API to optionally return softmax Stats always and Max when return_max_logit=True, exposing more cuDNN intermediates to users. (#2677)
- [PyTorch] Enabled SM120 support for the fused attention path when cuDNN >= 9.18.1 is available. (#2693)
- [PyTorch] Added support for MXFP8BlockScaling and Float8BlockScaling quantized weights in FusedAdam. (#2753)
- [PyTorch] Added a CUDA Graph-compatible multi_tensor_scale_tensor API in the optimizer. (#2594)
- [PyTorch] Enabled CUDA Graph capture of modules with CPU offloading. (#2435)
- [PyTorch] Added support for non-FP32 params_dtype when using QK-normalization. (#2718)
- [PyTorch] Added precision debug-tools support for quantized model parameters. (#2141)
- [JAX] Added a JAX-side API to invoke the fused MoE router kernels. (#2711)
- [JAX] Integrated BF16 grouped GEMM with on-device group sizes. (#2680)
- [JAX] Added a Collective GEMM (CGEMM) implementation with FP8 and MXFP8 support. (#2740)
- [JAX] Added Shardy support to the Collective GEMM (CGEMM) path. (#2714)
- [JAX] Improved the performance of the permutation kernels for JAX 0.8.0 and newer. (#2741)
- [C] Enabled the fused RMSNorm dLN + add backward path through cuDNN for faster fused-residual normalization. (#2778)
- [C] Added a grouped MXFP8 quantization kernel, including grouped dbias support. (#2738) (#2674)
- [C] Enabled dequantization from an MXFP8 tensor that only carries column-wise data. (#2712)
- [C, PyTorch] Improved the performance of the NVFP4 recipe by fusing row-cast / RHT / transpose / column-cast. (#2555)
- [C] Made the number of Philox rounds for stochastic rounding configurable. (#2751)
- [Documentation] Added a documentation page describing CPU offloading in Transformer Engine. (#2520)
- [Documentation] Updated the documentation to describe the current cuDNN sliding-window attention support. (#2624)
- [Documentation] Improved error messages across the C, PyTorch, and JAX layers. (#2705)
- [Documentation] Added a custom-feature tutorial for the precision debug tools. (#2216)
- [Documentation] Added documentation for the operator fuser API. (#2447)
- [PyTorch, Documentation] Added end-to-end examples for fused_adam, quantized_model_init, and FSDP2 usage. (#2698) (#2662)
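The grouped GEMM and grouped MLP entries above (#2748, #2669, #2769) are easier to read with the underlying operation in mind. The following is a dependency-free sketch of the semantics only, not Transformer Engine's API or kernels: consecutive row blocks of the input are multiplied by per-group (per-expert) weights, and SwiGLU is then applied row by row. The names `matmul`, `grouped_gemm_reference`, and `swiglu` are hypothetical, the gate-then-up ordering of the two halves is an assumption, and the group sizes are an ordinary Python list here rather than a device-resident tensor.

```python
import math


def matmul(a, b):
    """Plain row-major matrix product of two lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def grouped_gemm_reference(x, weights, group_sizes):
    """Multiply consecutive row blocks of x by per-group weights.

    The first group_sizes[0] rows of x use weights[0], the next
    group_sizes[1] rows use weights[1], and so on. Hypothetical helper
    that only illustrates the semantics.
    """
    if len(weights) != len(group_sizes) or sum(group_sizes) != len(x):
        raise ValueError("need one weight per group and sizes covering every row")
    out, start = [], 0
    for size, w in zip(group_sizes, weights):
        if size:  # an empty group contributes no output rows
            out.extend(matmul(x[start:start + size], w))
        start += size
    return out


def swiglu(row):
    """SwiGLU on one row: silu(gate) * up, assuming gate is the first half."""
    half = len(row) // 2
    gate, up = row[:half], row[half:]
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]


# Three tokens routed to two experts: one token to expert 0, two to expert 1.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights = [
    [[1.0, 2.0], [3.0, 4.0]],  # expert 0
    [[0.0, 1.0], [1.0, 0.0]],  # expert 1
]
hidden = grouped_gemm_reference(x, weights, group_sizes=[1, 2])
print(hidden)  # [[1.0, 2.0], [1.0, 0.0], [1.0, 1.0]]
activated = [swiglu(row) for row in hidden]
```

A fused kernel would produce the same values without materializing `hidden` as a separate step; the sketch keeps the two stages apart so that each can be checked on its own.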
Fixed Issues
- [PyTorch] FSDP2 / Megatron-FSDP / DCP (distributed checkpointing): when model parameters are DTensors, ensure optimizer states are also DTensors for correct sharded checkpoints. (#2795)
- [PyTorch] Fixed async DCP checkpointing for Float8Tensor parameters. (#2721)
- [PyTorch] Fixed cross_entropy_forward producing incorrect results for non-contiguous logits. (#2746)
- [PyTorch] Fixed excessive memory usage when using the operator fuser. (#2750)
- [PyTorch] Fixed a precision-debug-tools crash when tp_group=None. (#2733)
- [PyTorch] Fixed Flash Attention 3 API compatibility for the window-size parameters. (#2704)
- [PyTorch] Fixed the initialization of the learnable softmax_offset parameter in DotProductAttention to use zero-initialization. (#2694)
- [PyTorch] Fixed an error with FP8 block scaling when sequence parallelism is enabled and local tensor dimensions are not divisible by 128. (#2637)
- [PyTorch] Added a clear error when constructing LayerNormLinear with row-wise tensor parallelism (an unsupported configuration).…
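The FP8 block scaling fix above (#2637) concerns local tensor dimensions that are not divisible by 128. As a rough illustration of why that case needs care, and not a description of Transformer Engine's implementation, the sketch below computes one scale per 128-element block and lets the final block be shorter. The function name `block_scales` is hypothetical, 1D blocks are an assumption, and 448.0 is the largest finite value of the FP8 E4M3 format.

```python
import math

BLOCK = 128        # block size mentioned in the release note
FP8_E4M3_MAX = 448.0  # largest finite E4M3 value


def block_scales(values):
    """Return one scale per block of up to BLOCK elements.

    Hypothetical helper: the last block may hold fewer than BLOCK
    elements when len(values) is not a multiple of BLOCK.
    """
    n_blocks = math.ceil(len(values) / BLOCK)
    scales = []
    for b in range(n_blocks):
        chunk = values[b * BLOCK:(b + 1) * BLOCK]  # slicing clamps at the end
        amax = max(abs(v) for v in chunk)
        # An all-zero block gets a neutral scale to avoid dividing by zero.
        scales.append(FP8_E4M3_MAX / amax if amax > 0 else 1.0)
    return scales


# 200 elements: one full block of 128 and one partial block of 72.
values = [1.0] * 128 + [2.0] * 72
print(block_scales(values))  # [448.0, 224.0]
```

Code that instead assumes `len(values) // BLOCK` blocks would silently drop the trailing 72 elements in this example.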