FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling
Published 3/5/2026
Authors
Ted Zadouri (Princeton University, Together AI), Markus Hoehnerbach (Meta), Jay Shah (Colfax Research), Timmy Liu (NVIDIA), Vijay Thakkar (Meta, Georgia Tech), Tri Dao (Princeton University, Together AI)
Presenting FlashAttention-4: [ Paper ] [ Code ]
Modern accelerators like Blackwell GPUs continue the trend of asymmetric hardware scaling, where tensor core throughput grows far faster than other resources such as shared memory bandwidth, special function units (SFUs) for transcendental operations like the exponential, and general-purpose integer and floating-point ALUs. From the Hopper H100 to the Blackwell B200, for instance, BF16 tensor core throughput increases from 1 to 2.25 PFLOPs/s, while both the SFU count and shared memory bandwidth remain unchanged. This scaling asymmetry has profound implications for optimizing complex kernels like attention for the Blackwell architecture.
At its core, attention comprises two GEMMs ($S = Q \cdot K^T$ and $O = P \cdot V$) with a softmax in between; in practice, it also involves substantial plumbing and bookkeeping: data movement, synchronization, layout transforms, element-wise ops, scheduling, masking, etc. A naive view of attention might be that the speed of the GEMMs completely controls kernel performance and that these other components can be disregarded, at least to first order. However, a “feeds and speeds” analysis for B200 shows the opposite: the main performance bottleneck lies not in how fast the tensor cores can do MMA, but rather (a) in the SFUs that compute the softmax exponential during the FWD computation, and (b) in the shared-memory traffic during the BWD computation.
In this blog post, we present FlashAttention-4, an algorithm and kernel co-design that maximizes overlap between matmul and these other resource bottlenecks. On B200 with BF16, it reaches up to 1605 TFLOPs/s (71% utilization), up to 1.3× faster than cuDNN version 9.13 and 2.7× faster than Triton.
Our main algorithmic and kernel co-design ideas are as follows:
- New pipelining for maximum overlap: New forward and backward software pipelines that exploit Blackwell's fully asynchronous MMA and larger tile sizes, overlapping tensor cores, softmax exponential, and memory operations.
- Forward (FWD) pass: A software emulation of the exponential function, implemented via polynomial approximation on FMA units to mitigate the exponential bottleneck, plus conditional online softmax rescaling.
- Backward (BWD) pass: Storing intermediate results in tensor memory to relieve shared-memory traffic, combined with Blackwell's new 2-CTA MMA mode to reduce shared memory traffic further and cut atomic reductions in half, plus a deterministic execution mode for reproducible training.
- Scheduling: A new tile scheduler to mitigate load imbalance from causal masking and variable sequence lengths.
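As a rough illustration of the polynomial-based exponential mentioned in the forward-pass item, the sketch below approximates $2^x$ with range reduction followed by a Horner-form polynomial, where each Horner step corresponds to one fused multiply-add. The function names and the degree-5 Taylor coefficients are assumptions made for illustration; the post does not list the degree or coefficients the kernel uses, and a tuned kernel would more likely use minimax coefficients.

```python
import math

# Taylor coefficients of 2**f = exp(f * ln 2), lowest order first.
# Assumption: chosen for simplicity, not taken from the FlashAttention-4 kernel.
LN2 = math.log(2.0)
COEFFS = [LN2**k / math.factorial(k) for k in range(6)]  # degree 5


def exp2_poly(x: float) -> float:
    """Approximate 2**x with range reduction plus a Horner polynomial."""
    n = math.floor(x + 0.5)        # nearest integer, so f lies in [-0.5, 0.5)
    f = x - n
    acc = COEFFS[-1]
    for c in reversed(COEFFS[:-1]):
        acc = acc * f + c          # one multiply-add per step (an FMA on a GPU)
    return math.ldexp(acc, n)      # scale by 2**n through the exponent


def max_rel_error(lo: float = -20.0, hi: float = 20.0, steps: int = 4001) -> float:
    """Largest relative error of exp2_poly on an evenly spaced grid."""
    worst = 0.0
    for i in range(steps):
        x = lo + (hi - lo) * i / (steps - 1)
        exact = 2.0 ** x
        worst = max(worst, abs(exp2_poly(x) - exact) / exact)
    return worst
```

Under these assumptions, `max_rel_error()` stays below `1e-5` on the sampled range, which is far finer than the precision of a BF16 value. The point of such a scheme is that the polynomial runs on the plentiful FMA units rather than on the scarce exponential unit.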
New hardware features on Blackwell
- Tensor memory (TMEM): On B200, each of the 148 SMs has 256 KB of TMEM, an on-chip scratchpad wired into the tensor cores for warp-synchronous intermediate storage.
- Fully asynchronous 5th-gen tensor cores: tcgen05.mma is asynchronous and accumulates in TMEM. For BF16 and FP16, the largest single-CTA UMMA tile is 128×256×16, which is about 2× larger than the largest Hopper WGMMA atom. UMMA is launched by a single thread, easing register pressure and making larger tiles and deeper pipelines practical without the spilling pain points of Hopper warpgroup MMA. This also makes warp specialization more viable, with some warps moving tiles while others issue MMA, so that matrix multiply-accumulate overlaps with softmax and memory traffic. tcgen05.mma can also source operand A from TMEM.
- 2-CTA MMA: Blackwell can execute one UMMA across a CTA pair in the same cluster, spanning the TMEM of both peer CTAs. One thread in the leader CTA launches the MMA, but both CTAs must stay active while it is in flight. This scales the MMA tile dimension up to 256×256×16 by splitting M and N across the pair, reducing redundant traffic and lowering the per-CTA footprint. The CTA group size, 1 or 2, must remain constant across TMEM and tensor core operations within a kernel.
Feeds and Speeds
For M=N=D=128
Feeds on B200 (per SM):
- Tensor Cores (BF16): $8192$ ops/cycle
- Exponential unit: $16$ ops/cycle
- Shared Memory traffic: $128$ bytes/cycle
Speeds (clock cycles per tile):
- Forward (2 MMAs + MN exp): Tensor Cores: $1024$, Exp: $1024$, SMEM: $768$
- Backward (5 MMAs + MN exp), 1-CTA: Tensor Cores: $2560$, Exp: $1024$, SMEM: $3328$
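The tensor core and exponential cycle counts above can be reproduced from the per-SM feeds with a little arithmetic. The sketch below assumes each MMA costs $2 \cdot M \cdot N \cdot D$ ops (one multiply and one add per multiply-accumulate), which is the convention that matches the quoted numbers. The SMEM cycle counts depend on traffic details not given in this excerpt, so they are copied from the text rather than derived. All function and constant names are hypothetical.

```python
# Per-SM feeds quoted in the text for B200.
TENSOR_OPS_PER_CYCLE = 8192   # BF16 tensor core ops per cycle
EXP_OPS_PER_CYCLE = 16        # exponentials per cycle

# SMEM cycles per tile, copied from the text (not derived here).
SMEM_CYCLES = {"forward": 768, "backward": 3328}


def tile_cycles(pass_name: str, num_mmas: int, m: int = 128, n: int = 128, d: int = 128) -> dict:
    """Clock cycles per tile spent on each resource for one pass."""
    mma_ops = num_mmas * 2 * m * n * d      # assumption: 2 ops per multiply-accumulate
    exp_ops = m * n                         # one exponential per score entry
    return {
        "tensor_cores": mma_ops // TENSOR_OPS_PER_CYCLE,
        "exp": exp_ops // EXP_OPS_PER_CYCLE,
        "smem": SMEM_CYCLES[pass_name],
    }


def bottlenecks(cycles: dict) -> list:
    """Names of the resources with the largest cycle count, sorted alphabetically."""
    worst = max(cycles.values())
    return sorted(name for name, value in cycles.items() if value == worst)


forward = tile_cycles("forward", num_mmas=2)
backward = tile_cycles("backward", num_mmas=5)
```

With these inputs, `bottlenecks(forward)` returns `['exp', 'tensor_cores']` (both at 1024 cycles) and `bottlenecks(backward)` returns `['smem']` (3328 cycles against 2560 for the tensor cores).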
Takeaway: The forward pass is bottlenecked by compute and the exponential; the backward pass is bottlenecked by shared memory bandwidth. So we overlap softmax with MMA in the forward pass and reduce shared memory traffic in the backward pass.
Forward pass: New softmax pipelining with conditional rescaling
The forward pass has two matmuls, $QK^T$ and $PV$. On Blackwell, tensor cores got much faster, but the exponential unit (MUFU.EX2) did not. So softmax is no longer “just the thing between the two matmuls”; it is a bottleneck that must be carefully pipelined.
The FWD pass in short:
- Ping-pong schedule, 2× Q and 2× O tiles per CTA: maximizes overlap between MMA and softmax.
- 2× softmax warpgroups: per-tile softmax, synchronized so that the two warpgroups' exponential computations do not overlap.
- Software emulation of $2^x$: splits the exponential work between the hardware MUFU and a software emulation on the FMA units.
- Store P in TMEM in stages: mitigates register pressure.
- Correction warpgroup: a designated "correction" warpgroup performs the rescaling, which removes it from the critical path.
- Online softmax (conditional) rescaling:…
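To make the idea of conditional rescaling concrete, here is a minimal single-row sketch of online softmax in which the accumulator is rescaled only when the running maximum jumps by more than a threshold. The excerpt ends before the condition is described, so this rule, the `threshold` default, and all names below are assumptions for illustration rather than the authors' stated criterion. The result stays exact because the numerator and denominator always share the same reference value; skipping a rescale only lets intermediate terms grow as large as `exp(threshold)`.

```python
import math


def attention_row_reference(scores, values):
    """Plain softmax-weighted average of the value rows for one query."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def attention_row_conditional(scores, values, block=4, threshold=8.0):
    """Blockwise online softmax that rescales only on large max jumps.

    Returns (output_row, number_of_rescales). The count includes the
    initial setup on the first block.
    """
    dim = len(values[0])
    ref = -math.inf            # reference max used inside exp(); may lag the true max
    denom = 0.0
    acc = [0.0] * dim
    rescales = 0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        blk_max = max(s_blk)
        if blk_max - ref > threshold:          # always true on the first block
            scale = math.exp(ref - blk_max)    # exp(-inf) is 0.0 on the first block
            denom *= scale
            acc = [a * scale for a in acc]
            ref = blk_max
            rescales += 1
        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - ref)              # at most exp(threshold) when rescaling is skipped
            denom += p
            for j in range(dim):
                acc[j] += p * v[j]
    return [a / denom for a in acc], rescales


scores = [0.1 * i for i in range(32)]
values = [[float(i), 1.0] for i in range(32)]
out_cond, n_cond = attention_row_conditional(scores, values)
out_always, n_always = attention_row_conditional(scores, values, threshold=0.0)
```

On this slowly increasing input, the conditional version rescales once (the initial setup) while `threshold=0.0`, which rescales whenever the max grows, does so on all eight blocks; both agree with `attention_row_reference` up to floating-point rounding.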