OpenBMB/BMTrain
Description: Efficient Training (including pre-training and fine-tuning) for Big Models
Language: Python
License: Apache-2.0
Stars: 624
Forks: 88
Open issues: 10
Created: 2021-12-01T02:58:58Z
Pushed: 2026-04-23T02:43:21Z
Default branch: main
Fork: no
Archived: no
README:
What's New
- 2024/02/26 BMTrain 1.0.0 released. Code refactoring and tensor parallel support. See the details in the [update log](docs/UPDATE_1.0.0.md).
- 2023/08/17 BMTrain 0.2.3 released. See the [update log](docs/UPDATE_0.2.3.md).
- 2022/12/15 BMTrain 0.2.0 released. See the [update log](docs/UPDATE_0.2.0.md).
- 2022/06/14 BMTrain 0.1.7 released. ZeRO-2 optimization is supported!
- 2022/03/30 BMTrain 0.1.2 released. Adapted to OpenPrompt and OpenDelta.
- 2022/03/16 BMTrain 0.1.1 publicly released as the first stable version, which fixes many bugs from the beta version.
- 2022/02/11 BMTrain 0.0.15 publicly released as the first beta version.
Overview
BMTrain is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. It can train models in a distributed manner while keeping the code as simple as stand-alone training.
Documentation
Our documentation provides more information about the package.
Installation
- From pip (recommended): `pip install bmtrain`
- From source code: download the package and run `pip install .`
Installing BMTrain may take anywhere from a few minutes to about ten, because the C/CUDA source code is compiled at install time. We recommend compiling BMTrain directly in the training environment to avoid problems caused by differences between environments.
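Because the build happens in the environment where `pip` runs, it can help to confirm from the training environment itself that the package is visible before launching a job. The following is a minimal sketch using only the Python standard library; the helper name `is_installed` is hypothetical and not part of BMTrain.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if a top-level package can be located on the Python path."""
    # find_spec only locates the package; it does not import or run it.
    return importlib.util.find_spec(package_name) is not None


if __name__ == "__main__":
    print("bmtrain importable:", is_installed("bmtrain"))
```

This only shows that the package can be found. It does not prove that the compiled C/CUDA extension loads correctly on the target GPUs.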
Usage
Step 1: Initialize BMTrain
BMTrain must be initialized at the beginning of your code. Just as PyTorch's distributed module requires calling init_process_group first, BMTrain requires calling init_distributed first.
import bmtrain as bmt

bmt.init_distributed(
    seed=0,
    # ...
)
NOTE: Do not use PyTorch's distributed module and its associated communication functions when using BMTrain.
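Initialization is meant to run inside processes started by a distributed launcher (see Step 4). Below is a minimal sketch of a pre-flight check that could run before `init_distributed`, assuming the launcher exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` as torchrun does. `LaunchEnv` and `read_launch_env` are hypothetical helper names, not BMTrain APIs.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class LaunchEnv:
    rank: int        # global rank of this process
    local_rank: int  # rank of this process on its node
    world_size: int  # total number of processes


def read_launch_env(environ: Mapping[str, str] = os.environ) -> LaunchEnv:
    """Parse the rank variables a launcher such as torchrun exports."""
    required = ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    missing = [key for key in required if key not in environ]
    if missing:
        # Fail early with a readable message instead of a later KeyError.
        raise RuntimeError(f"not started by a distributed launcher; missing {missing}")
    return LaunchEnv(
        rank=int(environ["RANK"]),
        local_rank=int(environ["LOCAL_RANK"]),
        world_size=int(environ["WORLD_SIZE"]),
    )
```

Passing the environment as a mapping keeps the check testable without actually starting a distributed job.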
Step 2: Enable ZeRO Optimization
To enable ZeRO optimization, you need to make some simple replacements to the original model's code.
- torch.nn.Module -> bmtrain.DistributedModule
- torch.nn.Parameter -> bmtrain.DistributedParameter
And wrap the transformer blocks with bmtrain.Block.
Here is an example.
Original
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            SomeTransformerBlock(),
            SomeTransformerBlock(),
            SomeTransformerBlock()
        ])

    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
Replaced
import torch
import bmtrain as bmt

class MyModule(bmt.DistributedModule): # changed here
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here
        self.module_list = torch.nn.ModuleList([
            bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now
            bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now
            bmt.Block(SomeTransformerBlock(), zero_level=3)  # changed here, support 2 and 3 now
        ])

    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
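To give a feel for what the `zero_level` choice trades off, here is a rough estimate of per-GPU memory for model states. It uses the accounting from the ZeRO paper for mixed-precision Adam (2 bytes per parameter, 2 per gradient, 12 for optimizer state) and is not a measurement of BMTrain. That BMTrain's levels 2 and 3 correspond exactly to ZeRO stages 2 and 3 is an assumption, and the function name is hypothetical.

```python
def model_state_bytes_per_gpu(num_params: int, world_size: int, zero_level: int) -> float:
    """Approximate per-GPU bytes for parameters, gradients, and optimizer state."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    param_bytes = 2 * num_params   # fp16 parameters
    grad_bytes = 2 * num_params    # fp16 gradients
    optim_bytes = 12 * num_params  # fp32 master weights, Adam momentum and variance
    if zero_level == 2:
        # Stage 2: gradients and optimizer state are sharded, parameters replicated.
        return param_bytes + (grad_bytes + optim_bytes) / world_size
    if zero_level == 3:
        # Stage 3: parameters are sharded as well.
        return (param_bytes + grad_bytes + optim_bytes) / world_size
    raise ValueError("only zero_level 2 and 3 are modelled here")
```

Under these assumptions, a model with 10 billion parameters on 8 GPUs needs about 37.5 GB per GPU at level 2 and about 20 GB at level 3 (counting 1 GB as 10^9 bytes). Activations, buffers, and fragmentation are ignored.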
Step 3: Enable Communication Optimization
To further reduce the extra overhead of communication and overlap communication with computing time, TransformerBlockList can be used for optimization.
You can enable it by making the following substitutions in the code:
- torch.nn.ModuleList -> bmtrain.TransformerBlockList
- for module in self.module_list: x = module(x, ...) -> x = self.module_list(x, ...)
Original
import torch
import bmtrain as bmt

class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            bmt.Block(SomeTransformerBlock()),
            bmt.Block(SomeTransformerBlock()),
            bmt.Block(SomeTransformerBlock())
        ])

    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
Replaced
import torch
import bmtrain as bmt

class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = bmt.TransformerBlockList([ # changed here
            bmt.Block(SomeTransformerBlock()),
            bmt.Block(SomeTransformerBlock()),
            bmt.Block(SomeTransformerBlock())
        ])

    def forward(self):
        x = self.param
        x = self.module_list(x, 1, 2, 3) # changed here
        return x
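The second substitution changes the calling convention: the list is called once, and the running value is passed from block to block internally. The plain-Python stand-in below illustrates only that convention, assuming the extra arguments are forwarded unchanged to every block. `SequentialBlockList` and `toy_block` are hypothetical names, and nothing here implements BMTrain's communication overlap.

```python
from typing import Any, Callable, Sequence


class SequentialBlockList:
    """Stand-in that threads the first argument through every block in order."""

    def __init__(self, blocks: Sequence[Callable[..., Any]]) -> None:
        self.blocks = list(blocks)

    def __call__(self, x: Any, *args: Any) -> Any:
        for block in self.blocks:
            # Each block receives the running value plus the same extra arguments.
            x = block(x, *args)
        return x


def toy_block(x: float, a: float, b: float, c: float) -> float:
    """Hypothetical block: takes (x, *extra) and returns the new x."""
    return x * a + b - c


if __name__ == "__main__":
    blocks = [toy_block, toy_block, toy_block]

    # Explicit loop, as in the "Original" version.
    looped = 1.0
    for block in blocks:
        looped = block(looped, 1, 2, 3)

    # Single call, as in the "Replaced" version.
    called = SequentialBlockList(blocks)(1.0, 1, 2, 3)
    print(looped == called)
```

Keeping the iteration inside the container is what gives a library room to schedule work around each block, which an explicit Python loop in user code does not expose.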
Step 4: Launch Distributed Training
BMTrain uses the same launch command as the distributed module of PyTorch.
You can choose one of them depending on your version of PyTorch.
- ${MASTER_ADDR} means the IP address of the master node.
- ${MASTER_PORT} means the port of the master node.
- ${NNODES} means the total number of nodes.
- ${GPU_PER_NODE} means the number of GPUs per node.
- ${NODE_RANK} means the rank of this node.
torch.distributed.launch
$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py
torchrun
$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
For more information, please refer to the documentation.
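When jobs are submitted from a script rather than typed by hand, the torchrun command can be assembled programmatically. This is a minimal sketch that reuses exactly the flags shown above. The function name `build_torchrun_command` and the example address and port are hypothetical.

```python
import shlex
from typing import List


def build_torchrun_command(
    nnodes: int,
    gpu_per_node: int,
    master_addr: str,
    master_port: int,
    script: str = "train.py",
) -> List[str]:
    """Return the torchrun argument list for one node of the job."""
    return [
        "torchrun",
        f"--nnodes={nnodes}",
        f"--nproc_per_node={gpu_per_node}",
        "--rdzv_id=1",
        "--rdzv_backend=c10d",
        f"--rdzv_endpoint={master_addr}:{master_port}",
        script,
    ]


if __name__ == "__main__":
    cmd = build_torchrun_command(2, 8, "10.0.0.1", 29500)
    # shlex.join quotes arguments where needed so the line is safe to paste.
    print(shlex.join(cmd))
```

Returning a list rather than a single string lets the result be passed directly to `subprocess.run` without shell quoting concerns.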
Example
We provide an…