Fork by Baseten, published Dec 19, 2025

basetenlabs/TransformerEngine

forked from NVIDIA/TransformerEngine


Description: A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit and 4-bit floating point (FP8 and FP4) precision on Hopper, Ada and Blackwell GPUs, to provide better performance with lower memory utilization in both training and inference.

License: Apache-2.0

Stars: 0

Forks: 0

Open issues: 1

Created: 2025-12-19T00:31:29Z

Pushed: 2026-04-08T08:15:59Z

Default branch: main

Fork: yes

Parent repository: NVIDIA/TransformerEngine

Archived: no

README: .. Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.


Transformer Engine
==================

Quickstart | Installation | User Guide | Examples | FP8 Convergence | Integrations | Release notes

Latest News
===========

  • [09/2025] Pretraining Large Language Models with NVFP4
  • [09/2025] Native FP8 Mixed Precision Training for Ling 2.0, Open Sourced!
  • [09/2025] Faster Training Throughput in FP8 Precision with NVIDIA NeMo
  • [08/2025] How we built DeepL's next-generation LLMs with FP8 for training and inference
  • [08/2025] NVFP4 Trains with Precision of 16-bit and Speed and Efficiency of 4-bit
  • [06/2025] Floating Point 8: An Introduction to Efficient, Lower-Precision AI Training
  • [05/2025] Advanced Optimization Strategies for LLM Training on NVIDIA Grace Hopper
  • [03/2025] Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025
  • [03/2025] Measure and Improve AI Workload Performance with NVIDIA DGX Cloud Benchmarking

.. image:: docs/examples/comparison-fp8-bf16-training-nvidia-dgx-cloud-benchmarking-performance-explorer.jpg
  :width: 600
  :alt: Comparison of FP8 versus BF16 training, as seen in NVIDIA DGX Cloud Benchmarking Performance Explorer

  • [02/2025] Understanding the Language of Life's Biomolecules Across Evolution at a New Scale with Evo 2
  • [02/2025] NVIDIA DGX Cloud Introduces Ready-To-Use Templates to Benchmark AI Platform Performance
  • [01/2025] Continued Pretraining of State-of-the-Art LLMs for Sovereign AI and Regulated Industries with iGenius and NVIDIA DGX Cloud

Previous News

What is Transformer Engine?
===========================
.. overview-begin-marker-do-not-remove

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper, Ada, and Blackwell GPUs, to provide better performance with lower memory utilization in both training and inference. TE provides a collection of highly optimized building blocks for popular Transformer architectures and an API similar to automatic mixed precision (AMP) that works alongside your framework-specific code. TE also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for architectures such as BERT, GPT, and T5 become very memory- and compute-intensive. Most deep learning frameworks train with FP32 by default, but full FP32 precision is not required to achieve full accuracy for many deep learning models. Mixed-precision training, which combines single precision (FP32) with a lower-precision format (e.g., FP16), yields significant speedups with minimal differences in accuracy compared to FP32 training. The Hopper GPU architecture introduced FP8 precision, which offers higher performance than FP16 with no degradation in accuracy. Although all major deep learning frameworks support FP16, FP8 support is not available natively in frameworks today.
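To make the range/precision trade-off concrete, the following plain-Python sketch derives three values for FP16 and for the two FP8 encodings, E4M3 and E5M2. It computes the largest finite value and the smallest normal and subnormal magnitudes directly from each format's bit layout. The sketch is not part of TE, and ``FloatFormat`` is a hypothetical helper. The E4M3 rule it encodes (no infinities, with only the all-ones exponent-and-mantissa pattern reserved for NaN) is assumed from the FP8 format commonly used on Hopper and later GPUs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FloatFormat:
    """Bit layout of a small binary floating-point format (illustrative sketch)."""
    name: str
    exp_bits: int
    man_bits: int
    ieee_like: bool  # True if the all-ones exponent is reserved for inf/NaN

    @property
    def bias(self) -> int:
        return 2 ** (self.exp_bits - 1) - 1

    def max_finite(self) -> float:
        if self.ieee_like:
            # Largest usable exponent field is all-ones minus one; mantissa all ones.
            top_exp = 2 ** self.exp_bits - 2
            top_man = 2.0 - 2.0 ** -self.man_bits
        else:
            # E4M3-style: the all-ones exponent still holds normal values and
            # only the all-ones mantissa under it is NaN, so drop one mantissa step.
            top_exp = 2 ** self.exp_bits - 1
            top_man = 2.0 - 2.0 * 2.0 ** -self.man_bits
        return top_man * 2.0 ** (top_exp - self.bias)

    def min_normal(self) -> float:
        return 2.0 ** (1 - self.bias)

    def min_subnormal(self) -> float:
        # Subnormals share the minimum exponent and step by one mantissa ulp.
        return self.min_normal() * 2.0 ** -self.man_bits


FORMATS = [
    FloatFormat("FP16", exp_bits=5, man_bits=10, ieee_like=True),
    FloatFormat("FP8 E5M2", exp_bits=5, man_bits=2, ieee_like=True),
    FloatFormat("FP8 E4M3", exp_bits=4, man_bits=3, ieee_like=False),
]

for f in FORMATS:
    print(f"{f.name:9s} max={f.max_finite():<8g} "
          f"min_normal={f.min_normal():.3g} min_subnormal={f.min_subnormal():.3g}")
```

The computed maxima are 65504 (FP16), 57344 (E5M2), and 448 (E4M3). E4M3 in particular covers a much narrower range than FP16, which is why FP8 training depends on scaling factors.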

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language Model (LLM) libraries. It provides a Python API consisting of modules for easily building a Transformer layer, as well as a framework-agnostic C++ library that includes the structs and kernels needed for FP8 support. Modules provided by TE internally maintain the scaling factors and other values needed for FP8 training, greatly simplifying mixed-precision training for users.
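The role of those scaling factors can be illustrated without a GPU. The sketch below simulates E4M3 rounding in plain Python, saturating at ±448 and rounding to nearest even. It then applies a single per-tensor scale derived from the tensor's absolute maximum (amax). The helper names are hypothetical and the rounding model is a simplification. TE performs the equivalent work in its own kernels and manages the scales itself.

```python
import math

E4M3_MAX = 448.0   # largest finite FP8 E4M3 value
E4M3_MAN_BITS = 3  # explicit mantissa bits
E4M3_MIN_EXP = -6  # exponent of the smallest normal value (2**-6)


def round_to_e4m3(x: float) -> float:
    """Simulate rounding a Python float to FP8 E4M3 (saturating, nearest-even)."""
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), E4M3_MAX)            # saturate instead of overflowing
    _, e = math.frexp(a)                 # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, E4M3_MIN_EXP)       # subnormals share the minimum exponent
    step = 2.0 ** (exp - E4M3_MAN_BITS)  # spacing between neighbouring values
    return sign * min(round(a / step) * step, E4M3_MAX)


def fp8_roundtrip(values, amax):
    """Scale into the E4M3 range, round, then unscale (per-tensor scaling)."""
    scale = E4M3_MAX / amax
    return [round_to_e4m3(v * scale) / scale for v in values]


acts = [1e-4, 5e-5, -2e-5]
unscaled = [round_to_e4m3(v) for v in acts]
scaled = fp8_roundtrip(acts, amax=max(abs(v) for v in acts))
print("no scaling:  ", unscaled)
print("with scaling:", scaled)
```

Without scaling, all three small values fall below E4M3's smallest subnormal and round to zero. With the scale chosen from amax, the largest value maps onto 448 and the others survive with a relative error of a few percent at most.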

Highlights
==========

  • Easy-to-use modules for building Transformer layers with FP8 support
  • Optimizations (e.g. fused kernels) for Transformer models
  • Support for FP8 on NVIDIA Hopper, Ada, and Blackwell GPUs
  • Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later

Examples
========

PyTorch
^^^^^^^

.. code-block:: python

  import torch
  import transformer_engine.pytorch as te
  from transformer_engine.common import recipe

  # Set dimensions.
  in_features = 768
  out_features = 3072
  hidden_size = 2048

  # Initialize model and inputs.
  model = te.Linear(in_features, out_features, bias=True)
  inp = torch.randn(hidden_size, in_features, device="cuda")

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

  # Enable autocasting for the forward pass
  with te.autocast(enabled=True, recipe=fp8_recipe):
      out = model(inp)

  loss = out.sum()
  loss.backward()
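``DelayedScaling`` in the example above computes each FP8 scaling factor from a history of amax values recorded at previous iterations, rather than from the tensor currently being cast. The recipe's documentation describes the update as ``(FP8_MAX / amax) / 2**margin``. The ``amax_history_len`` and ``amax_compute_algo`` arguments control the window and how it is reduced. The pure-Python class below is a hypothetical toy model of that bookkeeping for a single tensor. It does not reproduce TE's internals, which keep this metadata on the GPU and also cover the backward pass.

```python
from collections import deque

FP8_MAX = {"E4M3": 448.0, "E5M2": 57344.0}


class DelayedScalingSketch:
    """Toy model of delayed scaling: the scale used at step t comes from
    amax values observed at earlier steps, not from the current tensor."""

    def __init__(self, fp8_format="E4M3", margin=0, amax_history_len=1024,
                 amax_compute_algo="max"):
        self.fp8_max = FP8_MAX[fp8_format]
        self.margin = margin
        self.algo = amax_compute_algo                  # "max" or "most_recent"
        self.history = deque(maxlen=amax_history_len)  # rolling amax window
        self.scale = 1.0                               # used before any history exists

    def step(self, tensor):
        """Return the scale applied to this tensor, then update it for the next step."""
        used = self.scale
        self.history.append(max(abs(v) for v in tensor))
        amax = max(self.history) if self.algo == "max" else self.history[-1]
        if amax > 0.0:
            # new_scale = (FP8_MAX / amax) / 2**margin
            self.scale = (self.fp8_max / amax) / (2 ** self.margin)
        return used


ds = DelayedScalingSketch(margin=0, amax_history_len=2)
for t, tensor in enumerate([[2.0, -1.0], [0.5, -4.0], [1.0], [1.0]]):
    used = ds.step(tensor)
    print(f"step {t}: used scale {used:g}, next scale {ds.scale:g}")
```

With a history window of two, the large amax of 4.0 seen at step 1 keeps the scale at 112 for one extra step before it ages out. A positive ``margin`` shrinks every scale by a further power of two, leaving headroom for values that grow between updates.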

JAX
^^^

Flax
~~~~

.. code-block:: python

  import flax
  import jax
  import jax.numpy as jnp
  import transformer_engine.jax as te
  import transformer_engine.jax.flax as te_flax
  from transformer_engine.common import recipe

  BATCH = 32
  SEQLEN = 128
  HIDDEN = 1024

  # Initialize RNG and inputs.
  rng = jax.random.PRNGKey(0)
  init_rng, data_rng = jax.random.split(rng)
  inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

  # Enable autocasting for the forward pass
  with te.autocast(enabled=True, recipe=fp8_recipe):
      model = te_flax.DenseGeneral(features=HIDDEN)

      def loss_fn(params, other_vars, inp):
          out = model.apply({'params': params, **other_vars}, inp)
          return jnp.mean(out)

      # Initialize models.
      variables = model.init(init_rng, inp)
      other_variables, params = flax.core.pop(variables, 'params')

      # Construct the forward and backward function
      fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

      for _ in range(10):
          loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
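The JAX example uses ``recipe.Format.HYBRID``, whereas the PyTorch example uses ``recipe.Format.E4M3``. According to TE's recipe documentation, ``HYBRID`` keeps forward-pass FP8 tensors in E4M3 and uses E5M2 for backward-pass tensors. This trades a mantissa bit for a much wider exponent range where gradients need it. The plain-Python sketch below uses a hypothetical helper with a simplified rounding model and no scaling. It shows that trade-off on three sample values.

```python
import math


def round_to_fp8(x: float, man_bits: int, min_exp: int, max_val: float) -> float:
    """Simulate saturating round-to-nearest-even into a small float format."""
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), max_val)                  # saturate at the largest finite value
    _, e = math.frexp(a)                      # a = m * 2**e, 0.5 <= m < 1
    step = 2.0 ** (max(e - 1, min_exp) - man_bits)
    return sign * min(round(a / step) * step, max_val)


# (mantissa bits, exponent of smallest normal, largest finite value)
E4M3 = (3, -6, 448.0)     # forward-pass tensors under Format.HYBRID
E5M2 = (2, -14, 57344.0)  # backward-pass (gradient) tensors under Format.HYBRID

for x in [0.1, 1000.0, 1e-5]:
    print(f"{x:>8g}: E4M3 -> {round_to_fp8(x, *E4M3):<10g} "
          f"E5M2 -> {round_to_fp8(x, *E5M2):g}")
```

E4M3 resolves 0.1 more finely (0.1015625 versus 0.09375). However, it saturates 1000 to 448 and flushes 1e-5 to zero, while E5M2 still represents both approximately.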

For a more comprehensive tutorial, check out our Quickstart Notebook.

.. overview-end-marker-do-not-remove

Installation
============

System…
