Repo · Amazon (Nova) · published May 27, 2026

amazon-science/dualkv-flash-attn-for-rl

Python

Description: Implementation of DualKV: Shared-Prompt Flash Attention for Efficient RL Training with Large Rollouts and Long Contexts

Language: Python

License: NOASSERTION

Stars: 2

Forks: 3

Open issues: 2

Created: 2026-05-27T17:38:58Z

Pushed: 2026-06-05T18:11:44Z

Default branch: main

Fork: no

Archived: no

README:

DualKV: Shared-Prompt Flash-Attention for RL Training

Code release for *"DualKV: Shared-Prompt Flash-Attention Kernels for Efficient Policy Updates in RL Training"*.

DualKV deduplicates shared prompts in GRPO/DAPO training — instead of computing attention over N*(P+R) tokens, it computes over P + N*R, yielding up to 6x kernel speedup and 2x end-to-end throughput on long-context RL workloads. This release includes the custom flash-attention kernels, veRL integration (with Ulysses Sequence Parallelism support), and scripts to reproduce all paper experiments.
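To make the token-count claim concrete, the sketch below counts the tokens attention runs over for one prompt group under each layout. The function name `token_counts` and the response length R = 1024 are illustrative assumptions, not taken from the repository; N = 28 and P = 32768 match one of the kernel benchmark configurations. The printed ratio is a token-count ratio only and should not be read as a predicted kernel speedup.

```python
def token_counts(n_responses: int, prompt_len: int, response_len: int) -> tuple[int, int]:
    """Return (duplicated, deduplicated) token counts for one prompt group."""
    duplicated = n_responses * (prompt_len + response_len)    # N * (P + R)
    deduplicated = prompt_len + n_responses * response_len    # P + N * R
    return duplicated, deduplicated


# R = 1024 is an assumed response length chosen for illustration.
duplicated, deduplicated = token_counts(n_responses=28, prompt_len=32768, response_len=1024)
print(f"duplicated:   {duplicated} tokens")
print(f"deduplicated: {deduplicated} tokens")
print(f"ratio:        {duplicated / deduplicated:.1f}x")
```

With N = 1 the two counts coincide; the saving grows with N and with P relative to R.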

Repository Structure

├── flash-attention/ # FlashAttention-2 (commit 41b2ef6) with DualKV kernels applied
├── verl/ # veRL v0.7.0 with DualKV integration applied
├── experiments/ # Benchmarks, training scripts, reward functions
├── LICENSE # CC-BY-NC-4.0
└── THIRD_PARTY_LICENSES

Key implementation files:

  • Forward kernel: flash-attention/csrc/flash_attn/src/flash_fwd_kernel_dualkv_training.h
  • Backward kernel: flash-attention/csrc/flash_attn/src/flash_bwd_kernel_dualkv_training.h
  • Python interface: flash-attention/flash_attn/flash_attn_interface.py (search for dualkv)
  • veRL actor integration: verl/verl/workers/actor/dp_actor.py (search for _dualkv)
  • Attention monkey-patch + SP: verl/verl/models/transformers/monkey_patch.py (DualKV + Ulysses all-to-all)
  • SP correctness test: experiments/test_dualkv_sp_correctness.py
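As a rough mental model for what the kernels listed above compute, here is a pure-Python reference that is not taken from the repository. It assumes causal attention in which the prompt is attended once and each response token reads the shared prompt keys/values plus the causal prefix of its own response, and it checks that this matches the naive approach of running attention over N full copies of prompt + response. All function names are hypothetical; the real CUDA kernels are tiled and fused and may differ in many details.

```python
import math
import random


def attend(q, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)                               # subtract max for stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def causal_attention(q, k, v):
    """Token i attends to tokens 0..i of the same sequence."""
    return [attend(q[i], k[: i + 1], v[: i + 1]) for i in range(len(q))]


def shared_prompt_attention(prompt, responses):
    """prompt and each response are (q, k, v) triples of lists of vectors."""
    pq, pk, pv = prompt
    prompt_out = causal_attention(pq, pk, pv)       # computed once, not N times
    response_outs = []
    for rq, rk, rv in responses:
        # Each response token sees the whole shared prompt plus its own prefix.
        out = [attend(rq[i], pk + rk[: i + 1], pv + rv[: i + 1]) for i in range(len(rq))]
        response_outs.append(out)
    return prompt_out, response_outs


def naive_attention(prompt, responses):
    """Baseline: run causal attention over every full prompt + response copy."""
    pq, pk, pv = prompt
    return [causal_attention(pq + rq, pk + rk, pv + rv) for rq, rk, rv in responses]


def random_vectors(rng, length, dim):
    return [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(length)]


def random_qkv(rng, length, dim):
    return tuple(random_vectors(rng, length, dim) for _ in range(3))


def max_abs_diff(a, b):
    return max(abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))


rng = random.Random(0)
P, N, R, D = 6, 3, 4, 8                             # tiny sizes for illustration
prompt = random_qkv(rng, P, D)
responses = [random_qkv(rng, R, D) for _ in range(N)]

prompt_out, response_outs = shared_prompt_attention(prompt, responses)
worst = 0.0
for full, resp in zip(naive_attention(prompt, responses), response_outs):
    worst = max(worst, max_abs_diff(full[:P], prompt_out), max_abs_diff(full[P:], resp))
print(f"max abs difference vs naive: {worst:.2e}")
```

The equivalence holds because, under a causal mask, prompt tokens never see response tokens, so their outputs are identical in every one of the N copies.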

Hardware Requirements

| Experiment | GPUs |
|------------|------|
| Kernel benchmarks (Table 1, Table 2) | 1x H100-80GB |
| Qwen3-8B end-to-end (Table 5, Table 8) | 8x H100-80GB |
| Qwen3-14B end-to-end | 8x H100-80GB |
| DAPO end-to-end (Table 7) | 8x H100-80GB |
| Qwen3-30B-A3B multi-node (Table 3) | 16x H100-80GB (2 nodes) |
| Memory scaling sweep | 1x H100-80GB |

Software Environment

| Package | Version |
|---------|---------|
| Python | 3.12 |
| PyTorch | 2.9.0+cu128 |
| CUDA | 12.8 |
| flash-attn | 2.8.4 (included, with DualKV) |
| veRL | 0.7.0 (included, with DualKV) |
| vLLM | 0.12.0 |
| Ray | 2.55.0 |
| Transformers | 4.57.6 |

Setup

git clone <repo-url> dualkv && cd dualkv
python3 -m venv .venv && source .venv/bin/activate
pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu128

Install Flash Attention (with DualKV kernels)

cd flash-attention
pip install ninja numpy packaging
git clone --depth 1 https://github.com/NVIDIA/cutlass.git csrc/cutlass
pip install -e . --no-build-isolation
cd ..

Verify: python -c "from flash_attn import flash_attn_dualkv_varlen_func; print('OK')"

Install veRL (with DualKV integration)

cd verl
pip install -e .
cd ..

(Optional) Flash Attention 3

Only needed to reproduce the FA3 baseline rows in Table 5 and Table 7:

git clone https://github.com/Dao-AILab/flash-attention.git /tmp/flash-attention-3
cd /tmp/flash-attention-3 && git checkout v3.0.0 && cd hopper && pip install -e .

Verify: python -c "from flash_attn_interface import flash_attn_func; print('FA3 OK')"

(Optional) Prefix Grouper

Only needed to reproduce the Prefix Grouper baseline in Table 2:

pip install git+https://github.com/CASIA-IVA-Lab/PrefixGrouper.git

Remaining Dependencies

pip install vllm==0.12.0 ray==2.55.0 wandb pandas pyarrow
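Once everything is installed, a quick comparison against the pins in the Software Environment table can catch version drift before a long run. The sketch below is not part of the repository; it uses only the standard library, and the dictionary keys are assumed distribution names that may need adjusting (the in-repo editable installs of flash-attn and veRL may also report a different version string).

```python
from importlib import metadata

# Pins copied from the Software Environment table; the keys are assumed
# distribution names and may need adjusting for your install.
EXPECTED = {
    "torch": "2.9.0",
    "flash-attn": "2.8.4",
    "verl": "0.7.0",
    "vllm": "0.12.0",
    "ray": "2.55.0",
    "transformers": "4.57.6",
}


def installed_version(name):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(expected, lookup=installed_version):
    """Return {name: (expected, found)} for packages that do not match."""
    mismatches = {}
    for name, want in expected.items():
        found = lookup(name)
        # Ignore local build tags such as "+cu128" when comparing.
        base = found.split("+")[0] if found else None
        if base != want:
            mismatches[name] = (want, found)
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(EXPECTED)
    for name, (want, found) in problems.items():
        print(f"{name}: expected {want}, found {found or 'not installed'}")
    if not problems:
        print("All pinned packages match.")
```

The `lookup` parameter exists only so the comparison logic can be exercised without touching the real environment.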

Models and Data

WORKDIR=/path/to/your/workdir

# Models
huggingface-cli download Qwen/Qwen3-8B --local-dir ${WORKDIR}/models/Qwen3-8B
huggingface-cli download Qwen/Qwen3-14B --local-dir ${WORKDIR}/models/Qwen3-14B
huggingface-cli download Qwen/Qwen3-30B-A3B --local-dir ${WORKDIR}/models/Qwen3-30B-A3B

# Data
python experiments/preprocess_longreason.py --local_save_dir ${WORKDIR}/data/longreason
python experiments/preprocess_quality.py --local_save_dir ${WORKDIR}/data/quality
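Before launching a training script it can help to confirm that the downloads and preprocessing steps left the expected layout under `WORKDIR`. The helper below is a hypothetical convenience, not part of the repository; it assumes each `--local-dir` and `--local_save_dir` argument above ends up as a directory, which the excerpt does not confirm for the preprocessing scripts.

```python
import os
from pathlib import Path

# Relative paths taken from the download and preprocessing commands above.
EXPECTED_DIRS = [
    "models/Qwen3-8B",
    "models/Qwen3-14B",
    "models/Qwen3-30B-A3B",
    "data/longreason",
    "data/quality",
]


def missing_dirs(workdir, expected=EXPECTED_DIRS):
    """Return the expected sub-directories that do not exist under workdir."""
    root = Path(workdir)
    return [rel for rel in expected if not (root / rel).is_dir()]


if __name__ == "__main__":
    workdir = os.environ.get("WORKDIR", ".")
    missing = missing_dirs(workdir)
    if missing:
        print(f"Missing under {workdir}:")
        for rel in missing:
            print(f"  {rel}")
    else:
        print(f"All expected directories found under {workdir}.")
```

Only the models and datasets needed for the experiment being reproduced have to be present, so a non-empty result is not necessarily an error.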

Reproducing Experiments

Set environment before running any script:

export WORKDIR=/path/to/your/workdir
export WANDB_API_KEY=your_key # optional, scripts fall back to console logging

Notation: mb = micro-batch size (prompt groups per training step), P = prompt length, N = number of responses per prompt, R = response length, SP = Ulysses sequence parallelism degree, DP = data parallelism degree, FA2/FA3 = FlashAttention-2/3.

Table 1: Kernel-Level Benchmarks (1x H100 or A100)

Isolated DualKV vs FA2 attention kernel timing (fwd + bwd), fp16.

CUDA_VISIBLE_DEVICES=0 python experiments/reproduce_table1.py

Expected output (H100-80GB):

N P | FA2 fwd FA2 bwd FA2 f+b | DK fwd DK bwd DK f+b | fwd bwd f+b
28 4096 | 49.4 165.8 215.3 | 34.4 98.7 133.1 | 1.44x 1.68x 1.62x
28 16384 | 425.0 1325.8 1750.8 | 120.1 347.6 467.7 | 3.54x 3.81x 3.74x
16 32768 | 857.7 2645.8 3503.4 | 174.5 504.9 679.4 | 4.91x 5.24x 5.16x
28 32768 | 1500.9 4609.0 6109.9 | 259.8 758.4 1018.2 | 5.78x 6.08x 6.00x
16 65536 | OOM OOM OOM | 454.2 1277.7 1731.8 | inf inf inf
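The last three columns appear to be the FA2 time divided by the DualKV time for forward, backward, and combined passes. The sketch below, which is not the repository's benchmark script, recomputes them from the printed timings under that assumption. Small differences in the last digit are expected because the printed timings are themselves rounded; the time unit is not stated in the excerpt and does not affect the ratios.

```python
import math

# (N, P, FA2 fwd, FA2 bwd, DK fwd, DK bwd, reported fwd/bwd/f+b speedups);
# None marks an OOM entry. Values are copied from the expected output above.
ROWS = [
    (28, 4096, 49.4, 165.8, 34.4, 98.7, (1.44, 1.68, 1.62)),
    (28, 16384, 425.0, 1325.8, 120.1, 347.6, (3.54, 3.81, 3.74)),
    (16, 32768, 857.7, 2645.8, 174.5, 504.9, (4.91, 5.24, 5.16)),
    (28, 32768, 1500.9, 4609.0, 259.8, 758.4, (5.78, 6.08, 6.00)),
    (16, 65536, None, None, 454.2, 1277.7, (math.inf, math.inf, math.inf)),
]


def speedups(fa2_fwd, fa2_bwd, dk_fwd, dk_bwd):
    """Return (fwd, bwd, fwd+bwd) speedups; an OOM baseline counts as inf."""
    if fa2_fwd is None or fa2_bwd is None:
        return (math.inf, math.inf, math.inf)
    return (
        fa2_fwd / dk_fwd,
        fa2_bwd / dk_bwd,
        (fa2_fwd + fa2_bwd) / (dk_fwd + dk_bwd),
    )


for n, p, fa2_fwd, fa2_bwd, dk_fwd, dk_bwd, reported in ROWS:
    computed = speedups(fa2_fwd, fa2_bwd, dk_fwd, dk_bwd)
    text = " ".join(f"{s:.2f}x" for s in computed)
    print(f"N={n:<3} P={p:<6} {text}   (reported: {reported})")
```

The speedup grows with prompt length P, which is consistent with the deduplicated prompt accounting for a larger share of the work.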

Table 2: Single-Layer DualKV vs Prefix Grouper vs FA2 (1x H100)

Single Qwen3-8B decoder layer fwd+bwd with realistic response lengths. Prefix Grouper is self-implemented (no external package needed).

CUDA_VISIBLE_DEVICES=0 python experiments/reproduce_table2.py

Paper Table 2 reports configs: (P=5K, mb=32), (8K, 16), (16K, 8), (32K, 4). The script sweeps the full P x mb grid and marks paper configs with *.

Single-Step Full-Model Benchmark (8x H100)

torchrun --standalone --nproc-per-node 8 experiments/benchmark_qwen3_single_step.py \
--model ${WORKDIR}/models/Qwen3-8B --path both

Table 5: End-to-End GRPO (Qwen3-8B, 8x H100)

| Config | Script |
|--------|--------|
| FA2 mb=4 (baseline) | bash experiments/run_qwen3_8b_longreason_fa2.sh |
| FA3 mb=4 | bash experiments/run_qwen3_8b_longreason_fa3.sh |
| DualKV mb=4 | bash experiments/run_qwen3_8b_longreason_dualkv_mb4.sh |
| DualKV mb=8 | bash experiments/run_qwen3_8b_longreason_dualkv_mb8.sh |

Table 7: End-to-End DAPO (Qwen3-8B, 8x H100)

| Config | Script | |--------|--------| | FA2 mb=4 | `bash…

Excerpt shown — open the source for the full document.
