NVIDIA/TransformerEngine
Description: A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit and 4-bit floating point (FP8 and FP4) precision on Hopper, Ada and Blackwell GPUs, to provide better performance with lower memory utilization in both training and inference.
Language: Python
License: Apache-2.0
Stars: 3389
Forks: 746
Open issues: 364
Created: 2022-09-20T15:20:26Z
Pushed: 2026-06-10T19:44:41Z
Default branch: main
Fork: no
Archived: no
README: .. Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Transformer Engine
==================

Quickstart | Installation | User Guide | Examples | Convergence | Integrations | Release notes

Latest News
===========

- [12/2025] NVIDIA Nemotron 3: Efficient and Open Intelligence (trained with NVFP4 on Transformer Engine)
- [11/2025] NVIDIA Blackwell Architecture Sweeps MLPerf Training v5.1 Benchmarks
- [11/2025] Scale Biology Transformer Models with PyTorch and NVIDIA BioNeMo Recipes
- [11/2025] FP8 Training of Large-Scale RL Models
- [09/2025] Pretraining Large Language Models with NVFP4
- [09/2025] Native FP8 Mixed Precision Training for Ling 2.0, Open Sourced!
- [09/2025] Faster Training Throughput in FP8 Precision with NVIDIA NeMo
- [08/2025] How we built DeepL's next-generation LLMs with FP8 for training and inference
- [08/2025] NVFP4 Trains with Precision of 16-bit and Speed and Efficiency of 4-bit

Previous News

What is Transformer Engine?
===========================

.. overview-begin-marker-do-not-remove

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper, Ada, and Blackwell GPUs, to provide better performance with lower memory utilization in both training and inference. On Blackwell GPUs, TE also supports the MXFP8 (Microscaling FP8) and NVFP4 formats for even greater efficiency. TE provides a collection of highly optimized building blocks for popular Transformer architectures and an automatic mixed precision-like API that can be used seamlessly with your framework-specific code. TE also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.
As Transformer models scale to hundreds of billions of parameters across large language models, MoE architectures, and multimodal models, training and inference become increasingly memory- and compute-intensive. Mixed-precision training, which combines single-precision (FP32) with lower-precision formats, delivers significant speedups with minimal impact on accuracy. FP8, introduced with the Hopper GPU architecture, offers further performance gains over FP16 with no degradation in accuracy, and newer formats like MXFP8 and NVFP4 on Blackwell push efficiency even further.
TE integrates with popular LLM frameworks and provides optimizations that make low-precision training work seamlessly with advanced features like MoE, tensor/sequence/context parallelism, and fused operations. It provides a Python API consisting of modules to easily build a Transformer layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8 support. Modules provided by TE internally maintain scaling factors and other values needed for FP8 training, greatly simplifying mixed precision training for users.
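
As a rough illustration of the scaling-factor bookkeeping described above, the sketch below models a delayed-scaling update in plain Python. It assumes the scale is derived from a rolling history of absolute maxima (amax) as ``(FP8_MAX / amax) / 2**margin``, where ``FP8_MAX`` is the largest finite value of the chosen FP8 format. The ``DelayedScaler`` class and its ``history_len`` and ``algo`` parameters are hypothetical names for this sketch, not Transformer Engine APIs, and TE performs this work on the GPU rather than in Python.

.. code-block:: python

  from collections import deque

  # Largest finite magnitudes of the two FP8 encodings.
  FP8_MAX = {"E4M3": 448.0, "E5M2": 57344.0}


  class DelayedScaler:
      """Hypothetical helper (not a TE class): tracks an amax history and
      derives a per-tensor FP8 scaling factor from it."""

      def __init__(self, fp8_format="E4M3", margin=0, history_len=16, algo="max"):
          self.fp8_max = FP8_MAX[fp8_format]
          self.margin = margin
          self.algo = algo  # "max" over the window, or "most_recent"
          self.history = deque(maxlen=history_len)
          self.scale = 1.0  # used until the first amax has been recorded

      def record(self, values):
          # Store the absolute maximum observed in this step's tensor.
          self.history.append(max(abs(v) for v in values))

      def update_scale(self):
          # Recompute the scale that the *next* step will use.
          if not self.history:
              return self.scale
          amax = max(self.history) if self.algo == "max" else self.history[-1]
          if amax > 0:
              self.scale = (self.fp8_max / amax) / (2 ** self.margin)
          return self.scale


  scaler = DelayedScaler(fp8_format="E4M3", margin=0)
  used = []
  for step_values in ([0.5, -2.0], [1.0, 3.5], [0.25, -1.0]):
      used.append(scaler.scale)  # scale comes from earlier steps only
      scaler.record(step_values)
      scaler.update_scale()

  print(used)          # [1.0, 224.0, 128.0]
  print(scaler.scale)  # 128.0

Because each step quantizes with a scale computed from earlier steps, a sudden jump in magnitude is only reflected one step later; a larger ``margin`` gives up some precision in exchange for headroom against such jumps.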
Highlights
==========

- Easy-to-use modules for building Transformer layers with FP8 support
- Optimizations (e.g. fused kernels) for Transformer models
- Support for FP8 on NVIDIA Hopper, Ada, and Blackwell GPUs
- Support for MXFP8 and NVFP4 on NVIDIA Blackwell GPUs
- Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later
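
The hardware split in the list above can be summarized as a lookup on CUDA compute capability. The helper below is a hypothetical sketch, not a Transformer Engine API. It assumes Ampere starts at 8.0, FP8 starts at 8.9 (Ada), consistent with the FP8 note under System Requirements, and that every capability of 10.0 or higher is a Blackwell-class GPU with MXFP8 and NVFP4 support.

.. code-block:: python

  def supported_precisions(major: int, minor: int) -> list[str]:
      """Hypothetical helper: map a compute capability to the precisions
      listed under Highlights. Not part of Transformer Engine."""
      capability = (major, minor)
      formats = []
      if capability >= (8, 0):   # Ampere and later: FP16/BF16 optimizations
          formats += ["FP16", "BF16"]
      if capability >= (8, 9):   # Ada (8.9), Hopper (9.0) and later: FP8
          formats.append("FP8")
      if capability >= (10, 0):  # assumed Blackwell range: MXFP8 and NVFP4
          formats += ["MXFP8", "NVFP4"]
      return formats


  print(supported_precisions(8, 6))   # ['FP16', 'BF16']
  print(supported_precisions(9, 0))   # ['FP16', 'BF16', 'FP8']
  print(supported_precisions(10, 0))  # ['FP16', 'BF16', 'FP8', 'MXFP8', 'NVFP4']

On a machine with PyTorch and a GPU, ``torch.cuda.get_device_capability()`` returns the ``(major, minor)`` tuple to pass in.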
Examples
========

PyTorch
^^^^^^^

.. code-block:: python

  import torch
  import transformer_engine.pytorch as te
  from transformer_engine.common import recipe

  # Set dimensions.
  in_features = 768
  out_features = 3072
  hidden_size = 2048

  # Initialize model and inputs.
  model = te.Linear(in_features, out_features, bias=True)
  inp = torch.randn(hidden_size, in_features, device="cuda")

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

  # Enable autocasting for the forward pass
  with te.autocast(enabled=True, recipe=fp8_recipe):
      out = model(inp)

  loss = out.sum()
  loss.backward()

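
To see why the recipe's scaling factor matters, the pure-Python sketch below emulates a quantize-dequantize round trip through E4M3 (3 mantissa bits, smallest normal exponent of -6, largest finite value 448). It is a numerical illustration only: ``round_to_e4m3`` and ``fake_quant`` are hypothetical helpers, saturating at 448 is a simplifying assumption, and TE performs these casts in fused CUDA kernels rather than in Python. The scale used here is assumed to be ``E4M3_MAX / amax``, the largest E4M3 value divided by the tensor's absolute maximum, with no margin.

.. code-block:: python

  import math

  E4M3_MAX = 448.0  # largest finite E4M3 magnitude


  def round_to_e4m3(x: float) -> float:
      """Round x to the nearest E4M3 value (ties to even), saturating at
      +/-448. Hypothetical illustration, not TE's cast kernel."""
      if x == 0.0 or math.isnan(x):
          return x
      sign = -1.0 if x < 0 else 1.0
      a = abs(x)
      if a >= E4M3_MAX:
          return sign * E4M3_MAX
      _, exp = math.frexp(a)            # a = m * 2**exp with 0.5 <= m < 1
      e = max(exp - 1, -6)              # below 2**-6, keep subnormal spacing
      quantum = 2.0 ** (e - 3)          # spacing of values with 3 mantissa bits
      q = round(a / quantum) * quantum  # built-in round() ties to even
      return sign * min(q, E4M3_MAX)


  def fake_quant(values, scale):
      """Scale into the FP8 range, round to E4M3, then undo the scale."""
      return [round_to_e4m3(v * scale) / scale for v in values]


  values = [0.0005, 0.0003, 0.0001]
  print(fake_quant(values, 1.0))  # [0.0, 0.0, 0.0]: everything underflows

  scale = E4M3_MAX / max(abs(v) for v in values)  # amax-based scale, no margin
  print(fake_quant(values, scale))  # each value recovered to within a few percent

Without scaling, all three values fall below the smallest E4M3 subnormal spacing and round to zero; with the amax-based scale, the largest value maps exactly onto 448 and the others keep most of their precision.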
JAX
^^^

Flax
~~~~

.. code-block:: python

  import flax
  import jax
  import jax.numpy as jnp
  import transformer_engine.jax as te
  import transformer_engine.jax.flax as te_flax
  from transformer_engine.common import recipe

  BATCH = 32
  SEQLEN = 128
  HIDDEN = 1024

  # Initialize RNG and inputs.
  rng = jax.random.PRNGKey(0)
  init_rng, data_rng = jax.random.split(rng)
  inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

  # Enable autocasting for the forward pass
  with te.autocast(enabled=True, recipe=fp8_recipe):
      model = te_flax.DenseGeneral(features=HIDDEN)

      def loss_fn(params, other_vars, inp):
          out = model.apply({'params': params, **other_vars}, inp)
          return jnp.mean(out)

      # Initialize models.
      variables = model.init(init_rng, inp)
      other_variables, params = flax.core.pop(variables, 'params')

      # Construct the forward and backward function
      fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

      for _ in range(10):
          loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)

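
The two examples use different recipe formats: ``Format.E4M3`` in the PyTorch example, and ``Format.HYBRID`` in the JAX example, which uses E4M3 for forward-pass tensors and E5M2 for backward-pass (gradient) tensors. The sketch below derives each encoding's range from its bit layout. It assumes the usual FP8 conventions: E4M3 has no infinities and reserves only its all-ones exponent-and-mantissa pattern for NaN, while E5M2 follows IEEE-style rules and reserves its top exponent for infinities and NaN. ``fp8_range`` is a hypothetical helper written for illustration.

.. code-block:: python

  def fp8_range(exp_bits: int, man_bits: int, ieee_like: bool):
      """Hypothetical helper: return (max_normal, min_normal, min_subnormal)
      for a 1 sign + exp_bits + man_bits floating-point encoding."""
      bias = 2 ** (exp_bits - 1) - 1
      if ieee_like:
          # Top exponent field reserved for inf/NaN; full mantissa usable below it.
          max_exp_field = 2 ** exp_bits - 2
          max_mantissa = 2 - 2.0 ** -man_bits
      else:
          # Top exponent field usable, but its all-ones mantissa encodes NaN.
          max_exp_field = 2 ** exp_bits - 1
          max_mantissa = 2 - 2.0 ** -(man_bits - 1)
      max_normal = max_mantissa * 2.0 ** (max_exp_field - bias)
      min_normal = 2.0 ** (1 - bias)
      min_subnormal = 2.0 ** (1 - bias - man_bits)
      return max_normal, min_normal, min_subnormal


  print(fp8_range(4, 3, ieee_like=False))  # E4M3: (448.0, 0.015625, 0.001953125)
  print(fp8_range(5, 2, ieee_like=True))   # E5M2: (57344.0, 6.103515625e-05, 1.52587890625e-05)

E5M2's largest value is 128 times that of E4M3, at the cost of one mantissa bit. This is the usual rationale for HYBRID: gradients get the wider dynamic range, while forward-pass tensors keep E4M3's extra precision.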
For a more comprehensive tutorial, check out our Getting Started Guide.

.. overview-end-marker-do-not-remove
Installation
============

System Requirements
^^^^^^^^^^^^^^^^^^^

- Hardware: Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere
- OS: Linux (official), WSL2 (limited support)
- Software:

  - CUDA: 12.1+ (Hopper/Ada/Ampere), 12.8+ (Blackwell) with compatible NVIDIA drivers
  - cuDNN: 9.3+
  - Compiler: GCC 9+ or Clang 10+ with C++17 support
  - Python: 3.12 recommended

- Source Build Requirements: CMake 3.18+, Ninja, Git 2.17+, pybind11 2.6.0+
- Notes: FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell)
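
The version floors above can be checked mechanically before a build. The sketch below is a hypothetical pre-flight helper, not part of Transformer Engine. It encodes only the CUDA and cuDNN minimums listed here, compares versions as integer tuples so that, for example, 12.10 correctly ranks above 12.8 (a plain string comparison would not), and assumes Grace Hopper and Grace Blackwell systems are looked up under their GPU names.

.. code-block:: python

  # Minimums taken from the System Requirements list.
  MIN_CUDA = {"Ampere": (12, 1), "Ada": (12, 1), "Hopper": (12, 1), "Blackwell": (12, 8)}
  MIN_CUDNN = (9, 3)


  def parse_version(text: str) -> tuple:
      """Turn a string such as '12.8' or '9.3.0' into a tuple of ints."""
      return tuple(int(part) for part in text.strip().split("."))


  def check_environment(arch: str, cuda: str, cudnn: str) -> list:
      """Hypothetical pre-flight check; returns a list of problems (empty if OK)."""
      problems = []
      need_cuda = MIN_CUDA[arch]
      if parse_version(cuda)[:2] < need_cuda:
          problems.append(f"{arch} needs CUDA {need_cuda[0]}.{need_cuda[1]}+, found {cuda}")
      if parse_version(cudnn)[:2] < MIN_CUDNN:
          problems.append(f"cuDNN {MIN_CUDNN[0]}.{MIN_CUDNN[1]}+ required, found {cudnn}")
      return problems


  print(check_environment("Hopper", "12.4", "9.3.0"))    # []
  print(check_environment("Blackwell", "12.4", "9.1.0"))  # two problems: CUDA and cuDNN too old

With PyTorch installed, ``torch.version.cuda`` reports the CUDA version PyTorch was built against, which may differ from the toolkit used to build Transformer Engine from source.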
Installation Methods
^^^^^^^^^^^^^^^^^^^^

Docker (Recommended)
^^^^^^^^^^^^^^^^^^^^

The quickest way to…