nebius/kvax


Description: A FlashAttention implementation for JAX with support for efficient document mask computation and context parallelism.

Language: Python

License: Apache-2.0

Stars: 167

Forks: 9

Open issues: 3

Created: 2025-01-10T13:34:47Z

Pushed: 2025-11-11T17:22:27Z

Default branch: main

Fork: no

Archived: no

README:

Kvax: fast and easy-to-use flash attention implementation for JAX

Kvax is an open-source library offering fast and efficient attention operations for the JAX framework. Built with Flash Attention 2 algorithms implemented in the Triton language, it is optimised for high-performance attention computation with document masks and supports context parallelism. Kvax is designed to perform exceptionally well in distributed training scenarios on long sequences using FSDP/HSDP sharding.

More technical details are available in our blog post: https://nebius.com/blog/posts/kvax-open-source-flash-attention-for-jax

Table of Contents:

  • [Key Concepts of Kvax Implementation](#key-concepts-of-kvax-implementation)
  • [Kvax Features](#kvax-features)
  • [Kvax Results](#kvax-results)
  • [How to install](#how-to-install)
  • [How to use](#how-to-use)
  • [Package Description](#package-description)
  • [Benchmarks](#benchmarks)
  • [Limitations](#limitations)
  • [Contributing](#contributing)
  • [Citation](#citation)
  • [License](#license)

Key Concepts of Kvax Implementation

Document Mask Optimisation

When training transformer models on long sequences, a significant amount of compute is spent on attention operations due to the quadratic complexity of the attention algorithm. The Flash Attention algorithm offers hardware-specific optimisations that significantly reduce latency and memory requirements for these operations.

During training on long sequences, dense packing is often used to maximise compute resource utilisation. In this approach, multiple data points are packed into a single sequence while avoiding cross-sequence attention contamination. The main idea is to calculate only the blocks of attention weights that include tokens which should attend to each other while skipping other blocks. Various methods can efficiently handle this, with PyTorch's FlexAttention being one example. Kvax takes a similar approach to achieve high performance in these scenarios.
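The block-skipping idea can be illustrated with a small pure-Python sketch. It is not Kvax's Triton kernel; the function name `needed_blocks` is hypothetical, and the sketch assumes a causal mask in which tokens attend only to earlier tokens with the same segment id.

```python
def needed_blocks(segment_ids, block_size):
    """Return the (q_block, kv_block) tiles that contain at least one
    query/key pair allowed by a causal document mask."""
    n = len(segment_ids)
    num_blocks = (n + block_size - 1) // block_size
    needed = set()
    for qb in range(num_blocks):
        q_range = range(qb * block_size, min((qb + 1) * block_size, n))
        # Causal masking: key blocks to the right of the query block are never needed.
        for kb in range(qb + 1):
            k_range = range(kb * block_size, min((kb + 1) * block_size, n))
            if any(
                k <= q and segment_ids[q] == segment_ids[k]
                for q in q_range
                for k in k_range
            ):
                needed.add((qb, kb))
    return needed


# Three documents of lengths 4, 4 and 8 packed into one 16-token sequence.
segment_ids = [0] * 4 + [1] * 4 + [2] * 8
blocks = needed_blocks(segment_ids, block_size=4)
print(sorted(blocks))  # [(0, 0), (1, 1), (2, 2), (3, 2), (3, 3)]
```

In this toy case only 5 of the 16 tiles need to be computed, compared with 10 for a plain causal mask over a single document.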

Context Parallelism

Using long sequences during training can also lead to high GPU memory consumption for storing layer activations. Context parallelism helps solve this problem, speeding up the computations and reducing the memory required for layer activations.

There are several approaches to implementing context parallelism for transformer architectures, such as RingAttention and the all-gather based method. The all-gather based method, described in the Llama 3 training paper, all-gathers the key and value tensors before the attention computation. This is affordable because GQA makes the key and value tensors much smaller than the query tensor. The method is particularly well suited to document masks, and Kvax uses it in its implementation.
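In the Llama 3 paper, the all-gather method is paired with a sequence split that evens out causal-attention work: the sequence is cut into 2 × CP chunks, and rank i receives chunks i and 2 × CP − 1 − i. The sketch below illustrates that assignment only. The function names are hypothetical and this is not Kvax's code; Kvax exposes its own `permute_tokens_context_parallelism` helper, whose exact token layout may differ.

```python
def chunks_for_rank(rank, cp_size):
    """Chunk indices (out of 2 * cp_size) assigned to one context-parallel rank."""
    return (rank, 2 * cp_size - 1 - rank)


def causal_work(chunk_indices):
    """Chunk-level (query, key) tiles a rank computes under a causal mask.

    Query chunk c attends to key chunks 0..c, i.e. c + 1 tiles
    (the diagonal tile is counted as a full tile for simplicity).
    """
    return sum(c + 1 for c in chunk_indices)


cp_size = 4
assignment = {rank: chunks_for_rank(rank, cp_size) for rank in range(cp_size)}
work = {rank: causal_work(chunks) for rank, chunks in assignment.items()}
print(assignment)  # {0: (0, 7), 1: (1, 6), 2: (2, 5), 3: (3, 4)}
print(work)        # {0: 9, 1: 9, 2: 9, 3: 9}
```

With a naive contiguous split (rank r takes chunks 2r and 2r + 1), the same count gives 3, 7, 11 and 15 tiles, so the last rank would do five times the work of the first.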

Kvax Features

  • Block-wise Attention Masks: Like FlexAttention, our implementation builds the attention mask once per forward-backward pass, reusing it across layers. Our high-performance Triton kernel builds this mask blockwise, and does not require O(seq_len^2) GPU memory.
  • Optimised Memory Storage: Kvax stores attention masks in block-wise format, requiring 3 * 4 * batch_size * seq_len // block_size * 4 bytes (block_size is typically 64 or 128).
  • Skipping Pad Tokens: Kvax skips blocks consisting entirely of padding tokens. See the "How to Use" section for details on defining padding tokens.
  • Context Parallelism: Kvax balances tokens across GPUs to ensure equal attention operation loads, accounting for causal masks. This feature is described in the Llama 3 training paper and fully integrates with document mask optimisations.
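To put the block-wise storage formula from the list above in perspective, the following sketch evaluates it for one long sequence and compares it with a fully materialised mask. The formula is taken as written and evaluated with Python operator precedence. The dense-mask figure assumes one byte per mask entry, and the batch size, sequence length and block size are illustrative values only.

```python
def blockwise_mask_bytes(batch_size, seq_len, block_size):
    """Block-wise mask storage, using the formula from the feature list."""
    return 3 * 4 * batch_size * seq_len // block_size * 4


def dense_mask_bytes(batch_size, seq_len, bytes_per_entry=1):
    """Storage for a full seq_len x seq_len mask (assumed 1 byte per entry)."""
    return batch_size * seq_len * seq_len * bytes_per_entry


batch_size, seq_len, block_size = 1, 131072, 128
print(blockwise_mask_bytes(batch_size, seq_len, block_size))  # 49152 bytes (48 KiB)
print(dense_mask_bytes(batch_size, seq_len))  # 17179869184 bytes (16 GiB)
```

The block-wise cost grows linearly with `seq_len`, while the dense mask grows quadratically.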

Kvax Results

![Comparison of attention implementations with causal masks; forward pass only](assets/attn_doc.png)

![Comparison of attention implementations with causal masks; forward + backward pass](assets/attn_doc_bwd.png)

More details on Kvax benchmarking and its results can be found in the blogpost.

How to install

Install the latest stable release from pip:

pip install kvax

Note: The automatically installed versions of Triton and JAX-Triton might not be compatible. If you encounter an error while running the provided benchmarks, please ensure that you install compatible versions manually. For benchmarking, we used `triton==3.1` and `jax-triton==0.2.0`.
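Before running the benchmarks, it can help to check which versions pip actually resolved. The snippet below is a small standard-library sketch; `installed_versions` is a hypothetical helper, not part of Kvax.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


print(installed_versions(["triton", "jax-triton", "jax"]))
```

If the reported versions differ from the ones used for benchmarking, pinning them explicitly with pip is one way to rule out an incompatibility.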

How to use

First, ensure that the position of every padding token is marked with PADDING_SEGMENT_ID in the query_segment_ids and kv_segment_ids tensors:

from kvax.utils import PADDING_SEGMENT_ID

# In this example, the sequence length is 8, and there are 2 padding tokens.
pad_token_id = 128001
input_ids = [6151, 0, 52043, 710, 374, 1618, pad_token_id, pad_token_id]
query_segment_ids = [0, 0, 0, 0, 0, 0, PADDING_SEGMENT_ID, PADDING_SEGMENT_ID]
kv_segment_ids = [0, 0, 0, 0, 0, 0, PADDING_SEGMENT_ID, PADDING_SEGMENT_ID]
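For densely packed batches, segment ids can be generated from the document lengths. The helper below is a hypothetical sketch and not part of Kvax. It assumes, as is conventional for segment ids, that each packed document gets its own id. The `-1` value is only a placeholder so that the sketch runs without Kvax installed; real code should import `PADDING_SEGMENT_ID` from `kvax.utils`.

```python
# Placeholder for this sketch only; use kvax.utils.PADDING_SEGMENT_ID in real code.
PADDING_SEGMENT_ID = -1


def pack_segment_ids(doc_lengths, seq_len, padding_id=PADDING_SEGMENT_ID):
    """Assign one segment id per packed document and mark the tail as padding."""
    total = sum(doc_lengths)
    if total > seq_len:
        raise ValueError(f"documents need {total} tokens but seq_len is {seq_len}")
    segment_ids = []
    for doc_index, length in enumerate(doc_lengths):
        segment_ids.extend([doc_index] * length)
    # Everything after the last document is padding.
    segment_ids.extend([padding_id] * (seq_len - total))
    return segment_ids


# Two documents of 3 tokens each, packed into a sequence of length 8.
segment_ids = pack_segment_ids([3, 3], seq_len=8)
print(segment_ids)  # [0, 0, 0, 1, 1, 1, -1, -1]
```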

Then, kvax functions can be used in the transformer code:

import flax.linen as nn
from kvax.ops import (
    create_attention_mask,
    flash_attention,
)
from kvax.utils import (
    attention_specs,
    permute_tokens_context_parallelism,
    unpermute_tokens_context_parallelism,
)

class AttentionLayer(nn.Module):
    def __call__(
        self,
        embedding,
        query_positions,
        query_segment_ids,
        kv_positions,
        kv_segment_ids,
        attn_mask,
    ):
        query, key, value = ...
        scale = ...

        # Call the Flash Attention op
        attn_out = flash_attention(
            query=query,
            key=key,
            value=value,
            query_positions=query_positions,
            query_segment_ids=query_segment_ids,
            kv_positions=kv_positions,
            kv_segment_ids=kv_segment_ids,
            mask=attn_mask,
            assume_sequential_positions=self.config.assume_sequential_positions,
            scale=scale,
            # Mesh is defined as a global context
            # mesh=mesh,
        )

        out = ...
        return out

class Transformer(nn.Module):
    ...
    def setup(self):
        self.attn_layers = [AttentionLayer(...) for _ in…

