ByteDance-Seed/Triton-distributed
Python
Description: Distributed Compiler based on Triton for Parallel Systems
Language: Python
License: MIT
Stars: 1457
Forks: 150
Open issues: 46
Created: 2025-04-02T06:57:03Z
Pushed: 2026-04-22T09:57:16Z
Default branch: main
Fork: no
Archived: no
README:
Triton-distributed
Original Triton README | [README in Chinese](README-cn.md)
Triton-distributed is a distributed compiler, based on OpenAI Triton, designed for overlapping computation with communication.
With Triton-distributed, programmers can develop efficient kernels comparable to highly optimized libraries (including Distributed-GEMM and FLUX). Triton-distributed currently targets mainly NVIDIA and AMD GPUs, and it can also be ported to other hardware platforms. Feel free to contact us if you want to use Triton-distributed on your own hardware.
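To make the benefit of computation-communication overlapping concrete, here is a toy timing model. It is only an illustration, not a description of how Triton-distributed schedules work: it assumes the workload is split into tiles, that each tile must finish its communication before its computation starts, and that communication and computation run on independent resources. The function names are hypothetical.

```python
def sequential_time(comm, compute):
    """Total time when every tile communicates, then computes, one after another."""
    return sum(comm) + sum(compute)


def overlapped_time(comm, compute):
    """Total time when tile i+1 communicates while tile i computes.

    Assumption: one communication channel and one compute unit, each
    processing tiles in order.
    """
    comm_done = 0.0      # time at which the current tile's data has arrived
    compute_done = 0.0   # time at which the previous tile's compute finished
    for c, p in zip(comm, compute):
        comm_done += c
        # Compute for this tile starts once its data is there AND the
        # compute unit is free.
        compute_done = max(compute_done, comm_done) + p
    return compute_done


if __name__ == "__main__":
    comm = [1.0] * 4      # per-tile communication time (arbitrary units)
    compute = [2.0] * 4   # per-tile computation time (arbitrary units)
    print("sequential:", sequential_time(comm, compute))
    print("overlapped:", overlapped_time(comm, compute))
```

With four tiles costing 1 unit of communication and 2 units of computation each, the model gives 12 units sequentially and 9 units overlapped. In this model only the first tile's communication stays exposed when computation is the slower stage.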
News
- 12/22/2025 ✨✨✨: Updated EP functions with support for low-latency mode, token saving, and Mega-EP.
- 10/21/2025 🔥🔥🔥: Triton-distributed was presented at Triton Conference 2025; see the talk for details.
- 09/03/2025 ✨✨✨: Introduced the Intra-Kernel Profiler; see the doc for details.
- 08/24/2025 ⚡⚡⚡: Added inference acceleration for ByteDance-Seed/Seed-OSS-36B-Instruct, achieving a 1.33x speedup.
- 08/13/2025 ✨✨✨: Introduced the MegaTritonKernel and provided a Qwen3 TP demo on H20/H800; see the doc for details.
- 08/06/2025 ✨✨✨: Added GEMM+AllReduce support on H800 and MoE operator support on L20; see GEMM+AR Test and MOE Test for details.
- 07/24/2025 🤖🤖🤖: Introduced end-to-end inference acceleration demo with unified support for both NVIDIA and AMD GPUs. See the doc for details.
- 07/11/2025 ✨✨✨: Fast AllReduce implemented with Triton-distributed, see AllReduce Test.
- 07/11/2025 ✨✨✨: Improved MoE operators for tensor parallel. See AG+MoE Test and MoE+RS Test.
- 07/11/2025 ✨✨✨: Triton 3.4 support with NVSHMEM4py (MR). `pip install` is also supported without any need to modify NVSHMEM code.
- 05/12/2025 🚀🚀🚀: Our paper TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives was accepted by MLSys 2025.
Getting started
Install Triton-distributed
Method 1. From source
See [build from source](docs/build.md).
Method 2. Using pip
Prepare PyTorch container
docker run --name triton-dist --ipc=host --network=host --privileged --cap-add=SYS_ADMIN --shm-size=10g --gpus=all -itd nvcr.io/nvidia/pytorch:25.04-py3 /bin/bash
docker exec -it triton-dist /bin/bash
Install Dependencies
pip3 install cuda.core==0.2.0 nvidia-nvshmem-cu12==3.3.9 Cython==0.29.24 nvshmem4py-cu12==0.1.2
pip3 install cuda-python==12.4 setuptools==69.0.0 wheel pybind11
Then, pip install triton-dist.
# Remove triton installed with torch
pip uninstall triton
pip uninstall triton_dist # remove previous triton-dist
rm -rf /usr/local/lib/python3.12/dist-packages/triton
# Install Triton-distributed
VERSION=v0.0.2 # use the latest version
pip install https://github.com/ByteDance-Seed/Triton-distributed/releases/download/${VERSION}/triton_dist-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
How to use Triton-distributed
Triton-distributed provides a set of easy-to-use primitives to support the development of distributed kernels that overlap computation with communication. The primitives are divided into low-level and high-level primitives. We have released the low-level primitives and plan to release the high-level primitives in the future.
[Triton-distributed Primitives](docs/primitives.md)
Using these primitives, users can program communication kernels easily. For example, a low-latency AllToAll (with better latency than DeepEP for inference) is shown below. On 32 H800 GPUs, this example takes 137 us (128 tokens per rank, topk=8, hidden_size=7168, dtype=fp8), while DeepEP takes 182 us (note: DeepEP doesn't use NVLink for inference).
@triton_dist.jit
def all_to_all_kernel(
    data_src,
    data_dst,
    splits_src,
    splits_dst,
    signal,
    splits_cumsum,
    scale_src,
    scale_dst,
    rank: int,
    call_count: int,
    WITH_SCALE: tl.constexpr,
    WORLD_SIZE: tl.constexpr,
    HIDDEN: tl.constexpr,
    MAX_M: tl.constexpr,
    EXPERTS_PER_RANK: tl.constexpr,
    NUM_TOT_EXPERTS: tl.constexpr,
    ELEMENT_SIZE: tl.constexpr = 2,
    SCALE_ELEMENT_SIZE: tl.constexpr = 4,
):
    pid = tl.program_id(0)
    threadidx = tid(axis=0)

    exp_st = pid * EXPERTS_PER_RANK
    exp_ed = exp_st + EXPERTS_PER_RANK

    m_st = tl.load(splits_cumsum + exp_st)
    m_ed = tl.load(splits_cumsum + exp_ed)
    num_rows_cur_block = m_ed - m_st

    src_off = m_st
    dst_off = rank * MAX_M

    split_src_ptr = splits_src + exp_st
    off0 = exp_st + tl.arange(0, EXPERTS_PER_RANK)
    off1 = exp_st + tl.arange(0, EXPERTS_PER_RANK) + 1
    cumsum_sts = tl.load(splits_cumsum + off0)
    cumsum_eds = tl.load(splits_cumsum + off1)
    tl.store(split_src_ptr + tl.arange(0,…
Excerpt shown — open the source for the full document.
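The visible part of the AllToAll kernel is mostly index arithmetic: each program id selects a contiguous group of experts and reads their row range from a cumulative sum of per-expert token counts. The host-side sketch below reproduces only that arithmetic in plain Python so it can be traced by hand. It rests on assumptions read off the excerpt rather than confirmed by it: that `splits_cumsum` is a prefix sum with a leading zero (the kernel indexes it at `exp_ed` and at `off1`), and that one program id handles one group of `EXPERTS_PER_RANK` experts. The function name and the returned dictionary are hypothetical.

```python
from itertools import accumulate


def block_row_ranges(splits, experts_per_rank, max_m, rank):
    """Recompute, on the host, the offsets the kernel excerpt derives per program id."""
    # Assumed layout: entry e is the first row of expert e, entry e + 1 is
    # one past its last row.
    splits_cumsum = [0] + list(accumulate(splits))
    num_blocks = len(splits) // experts_per_rank

    blocks = []
    for pid in range(num_blocks):
        exp_st = pid * experts_per_rank
        exp_ed = exp_st + experts_per_rank
        m_st = splits_cumsum[exp_st]
        m_ed = splits_cumsum[exp_ed]
        blocks.append({
            "pid": pid,
            "src_off": m_st,                # first source row for this block
            "num_rows": m_ed - m_st,        # num_rows_cur_block in the kernel
            "dst_off": rank * max_m,        # row offset reserved for this sender
            # Vector loads at off0 and off1 in the kernel:
            "cumsum_sts": splits_cumsum[exp_st:exp_ed],
            "cumsum_eds": splits_cumsum[exp_st + 1:exp_ed + 1],
        })
    return blocks


if __name__ == "__main__":
    # Four experts, two per rank; token counts per expert are made up.
    for block in block_row_ranges([3, 1, 0, 2], experts_per_rank=2, max_m=8, rank=1):
        print(block)
```

What the kernel does with these values after the truncated `tl.store` call is not visible in the excerpt, so the sketch stops at the offsets.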
Notability
notability 5.0/10: New repo from ByteDance, decent stars.