inclusionAI/asystem-amem
C++
Description: A NCCL extension library, designed to efficiently offload GPU memory allocated by the NCCL communication library.
Language: C++
License: Apache-2.0
Stars: 109
Forks: 11
Open issues: 6
Created: 2025-11-26T09:00:12Z
Pushed: 2025-12-17T08:38:28Z
Default branch: main
Fork: no
Archived: no
README:
AMem NCCL-Plugin: Transparent NCCL GPU Memory Offloading and Restoration
TL;DR Technical Overview
NCCL stands for NVIDIA Collective Communications Library. It is the core communication library for multi-GPU and multi-node distributed deep learning, providing highly efficient collective communication operations such as AllReduce and Broadcast.
AMem NCCL-Plugin is an NCCL extension library developed in-house by Ant Group’s ASystem team. It introduces two memory management APIs, ncclPause() and ncclResume(), to address a critical challenge in reinforcement learning (RL) workflows: GPU memory allocated by the NCCL communication library cannot be offloaded efficiently. Using a lightweight plugin approach, AMem transparently offloads and restores the NCCL memory used by training and inference engines while preserving existing NCCL communication connections (Note 1). These advantages have been validated in RL training for Ring-1T, a trillion-parameter model.
The benefits of AMem NCCL-Plugin are demonstrated in two key aspects:
+ Memory Savings: By identifying and resolving cross-rank GPU memory cross-references within the NCCL communication library, AMem correctly implements transparent memory release and restoration. During transitions between training and inference, it can free over 10 GB of GPU memory per card (Hopper architecture) while maintaining communication group connectivity.
+ Extreme Efficiency: Since communication group connections are preserved, switching between training and inference only requires offloading and restoring NCCL metadata. There is no need to rebuild communication connections (which typically takes seconds). This reduces typical transition latency to under 1 second.
Comparison with Community Solutions on Hopper Architecture GPUs:
| System | Solution | Memory Saved | Per-step Offload/Reload Time |
| --- | --- | --- | --- |
| Slime | Clean NCCL GPU memory by destroying and recreating the training engine's communication group | Inference: No saving (2 GB left); Training: Saves 10 GB+ | Several seconds |
| OpenRLHF | Does not support offloading NCCL GPU memory | Inference: No saving (2 GB left); Training: No saving (10 GB+ left) | 0 s |
| AMem | Offload and restore NCCL GPU memory via Plugin | Inference: Saves 2 GB; Training: Saves 10 GB+ | […] |

[…] (Note 2), without rebuilding NCCL communication groups. The amount of offloadable memory depends on:
+ Cluster scale
+ Number of collective communication groups (Note 3), especially AlltoAll
+ Parallel strategy (typically 3D–5D)
+ CUDA/NCCL version
In large-scale tasks, NCCL memory overhead can reach 10–20 GB per GPU. With AMem, restoration latency is typically under 1 second (Note 4).
 
_Figure 8: AMem NCCL-Plugin nearly fully offloads NCCL memory (left/right: different GPU types)_
_Note 2: CUDA context memory (~800 MB) is **not offloaded**, as it’s shared between training/inference processes._
_Note 3: Common collective communication primitives include: Broadcast, Scatter, Gather, Reduce, AllGather, AllReduce, ReduceScatter, AlltoAll, etc._
_Note 4: First offload is slower (due to CPU pinned buffer allocation); subsequent operations take […]_
[…] First compilation takes ~10 minutes; see README for details.
Build Steps
    # Recommended Docker image: nvcr.io/nvidia/pytorch:25.08-py3
    cd asystem-amem/
    git submodule init
    git submodule update
    ./build.sh
NCCL Memory Statistics (independent of pause/resume): call ncclMemStats()
    AMEM groupID:170 pid:197780 caller_1 allocBytes:3024093184
    AMEM groupID:170 pid:197780 caller_3 allocBytes:201326592
    AMEM groupID:170 pid:197780 caller_7 allocBytes:2818572288
    AMEM groupID:170 pid:197780 total allocBytes:6043992064 (5764 MB)
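To turn this output into numbers a script can check, the lines can be parsed and the per-caller figures summed. The following is a minimal sketch that assumes the line layout is exactly as in the sample above; the regular expression and the helper name `parse_mem_stats` are illustrative and not part of the plugin.

```python
import re
from collections import defaultdict

# Sample lines copied from the ncclMemStats() output shown above.
SAMPLE_LOG = """\
AMEM groupID:170 pid:197780 caller_1 allocBytes:3024093184
AMEM groupID:170 pid:197780 caller_3 allocBytes:201326592
AMEM groupID:170 pid:197780 caller_7 allocBytes:2818572288
AMEM groupID:170 pid:197780 total allocBytes:6043992064 (5764 MB)
"""

# Assumed layout: "AMEM groupID:<n> pid:<n> <caller_k|total> allocBytes:<n>".
LINE_RE = re.compile(
    r"AMEM groupID:(?P<group>\d+) pid:(?P<pid>\d+) "
    r"(?P<label>caller_\d+|total) allocBytes:(?P<bytes>\d+)"
)


def parse_mem_stats(text):
    """Return {(group_id, pid): {"callers": {label: bytes}, "total": bytes or None}}."""
    stats = defaultdict(lambda: {"callers": {}, "total": None})
    for match in LINE_RE.finditer(text):
        key = (int(match["group"]), int(match["pid"]))
        nbytes = int(match["bytes"])
        if match["label"] == "total":
            stats[key]["total"] = nbytes
        else:
            stats[key]["callers"][match["label"]] = nbytes
    return dict(stats)


stats = parse_mem_stats(SAMPLE_LOG)
for (group, pid), entry in stats.items():
    summed = sum(entry["callers"].values())
    print(f"group {group} pid {pid}: callers sum to {summed} bytes "
          f"({summed // 2**20} MiB), reported total {entry['total']}")
```

For the sample, the three caller entries add up to the reported total, and the "5764 MB" figure corresponds to the byte count divided by 2^20.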
Key Environment Variables
    NCCL_CUMEM_ENABLE=1   # Required: enable NCCL CUMEM
    AMEM_ENABLE=1         # Enable NCCL memory offload/restore
    AMEM_GROUPID=xxx      # Assign distinct group IDs for training/inference processes
When integrating with RL frameworks, pass these variables to Ray or the training/inference framework.
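One way to keep the training and inference settings consistent is to build both environments from a single helper. This is a minimal sketch: the helper name `amem_env` and the group ID values 100 and 200 are hypothetical, chosen only to show two distinct IDs.

```python
import os


def amem_env(group_id, base_env=None):
    """Return a copy of base_env with the AMem variables from the README set."""
    env = dict(os.environ if base_env is None else base_env)
    env.update({
        "NCCL_CUMEM_ENABLE": "1",        # required: enable NCCL CUMEM
        "AMEM_ENABLE": "1",              # enable NCCL memory offload/restore
        "AMEM_GROUPID": str(group_id),   # distinct per training/inference process
    })
    return env


# Hypothetical group IDs; any two distinct values serve the illustration.
TRAIN_GROUP_ID = 100
INFER_GROUP_ID = 200

train_env = amem_env(TRAIN_GROUP_ID, base_env={})
infer_env = amem_env(INFER_GROUP_ID, base_env={})
print(train_env)
print(infer_env)
```

The resulting dictionaries can be passed as `env=` to `subprocess.Popen`. With Ray, the same key/value pairs would presumably go into the `env_vars` field of a `runtime_env`; how a given RL framework forwards them is framework-specific.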
Optional Environment Variables
    AMEM_NCCL_OFFLOAD_FREE_TAG=7   # Directly free P2P buffers without CPU offload
    GMM_LOG=3                      # Log level (default: 3/INFO; max: 5)
Unit Testing
Based on nccl-tests, validate dynamic memory offload/restore under typical parallel patterns (AllReduce, AllGather, AlltoAll, etc.).
+ Framework-independent
+ Takes ~10 minutes post-compilation
+ Requires minor modifications: insert calls to ncclPause()/ncclResume()
Original tests: https://github.com/NVIDIA/nccl-tests
    # Run quick tests for NCCL memory offload/resume
    export MPI_HOME=your/openmpi/home
    bash ./run.sh
Test run example:

Framework Integration
AMem NCCL-Plugin does not affect normal NCCL usage but adds new APIs:
+ ncclPause(): Synchronously releases NCCL-allocated GPU memory in the current process.
+ ncclResume(): Synchronously restores all memory previously released by ncclPause().
+ ncclSetGroupID(): Sets a process group ID for the current process.
+ ncclMemStats(): Reports NCCL memory usage and breakdown.
Additional Notes:
+ ncclPause/ncclResume are idempotent (safe for repeated calls).
+ The framework must ensure cross-process synchronization so all ranks complete offload/restore.
+ Supports multiple communication groups per process (e.g., 3D/4D parallelism).
+ If only one task runs at a time (e.g., inference-only or training-only), groupID is unnecessary.
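The call order implied by these notes can be sketched without a GPU. In the simulation below, `FakeAmemLib` is a stand-in that only tracks a paused flag (the real calls move GPU memory), threads play the role of ranks, and `threading.Barrier` stands in for whatever cross-process synchronization the framework provides. All names other than ncclPause and ncclResume are hypothetical.

```python
import threading


class FakeAmemLib:
    """Stand-in for the plugin library; it only tracks a paused flag.

    The real ncclPause()/ncclResume() release and restore GPU memory; this stub does not.
    """

    def __init__(self):
        self.paused = False
        self.transitions = 0  # counts actual state changes

    def ncclPause(self, comm=None):
        if not self.paused:        # repeated calls are no-ops (idempotent)
            self.paused = True
            self.transitions += 1
        return 0                   # 0 is ncclSuccess

    def ncclResume(self, comm=None):
        if self.paused:
            self.paused = False
            self.transitions += 1
        return 0


def run_rank(lib, barrier, log, rank):
    """One rank's switch away from, and back to, its NCCL-using engine."""
    assert lib.ncclPause() == 0
    barrier.wait()                 # every rank has finished offloading
    log.append((rank, "other engine may use the freed memory"))
    barrier.wait()                 # the other engine is done
    assert lib.ncclResume() == 0
    barrier.wait()                 # every rank has finished restoring


WORLD_SIZE = 4  # hypothetical number of ranks, simulated with threads
libs = [FakeAmemLib() for _ in range(WORLD_SIZE)]
barrier = threading.Barrier(WORLD_SIZE)
log = []
threads = [
    threading.Thread(target=run_rank, args=(libs[r], barrier, log, r))
    for r in range(WORLD_SIZE)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(all(not lib.paused for lib in libs), len(log))
```

The barriers after each call reflect the requirement that no rank proceeds until all ranks have completed the offload or the restore.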
PyNCCL Integration
Many upper-layer applications (e.g., SGLang, vLLM) use PyNCCL—a Python wrapper that loads NCCL’s dynamic library and exposes APIs via function handles.
SGLang Example
Modify pynccl and pynccl_wrapper to load the three new function handles. (The ncclComm parameter can be set to NULL.)
# ncclResult_t ncclPause(ncclComm_t comm);
Function("ncclPause", ncclResult_t, [ncclComm_t]),
# ncclResult_t ncclResume(ncclComm_t comm);…
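For readers unfamiliar with this wrapper style, the following minimal sketch shows how `Function(...)` entries of this kind might be bound to a shared library with `ctypes`. It covers only the two signatures quoted above. The `Function` dataclass, `bind_functions`, and `load_amem` are illustrative names modeled on the general pynccl wrapper pattern, not SGLang's actual code, and the library path is supplied by the caller.

```python
import ctypes
from dataclasses import dataclass
from typing import Any, List

ncclResult_t = ctypes.c_int      # NCCL result codes are C enums
ncclComm_t = ctypes.c_void_p     # opaque communicator handle


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]


# Only the two signatures quoted in the README are listed here.
AMEM_FUNCTIONS = [
    # ncclResult_t ncclPause(ncclComm_t comm);
    Function("ncclPause", ncclResult_t, [ncclComm_t]),
    # ncclResult_t ncclResume(ncclComm_t comm);
    Function("ncclResume", ncclResult_t, [ncclComm_t]),
]


def bind_functions(lib, functions):
    """Attach restype/argtypes to each symbol of lib and return the handles by name."""
    bound = {}
    for func in functions:
        handle = getattr(lib, func.name)   # AttributeError if the symbol is missing
        handle.restype = func.restype
        handle.argtypes = func.argtypes
        bound[func.name] = handle
    return bound


def load_amem(so_path):
    """Load the plugin build of NCCL from so_path and bind the AMem entry points."""
    return bind_functions(ctypes.CDLL(so_path), AMEM_FUNCTIONS)
```

With a library built from this repository, usage would look like `funcs = load_amem(path)` followed by `funcs["ncclPause"](None)`; passing `None` for a `c_void_p` argument sends a NULL pointer, which matches the remark that the ncclComm parameter can be NULL.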