QwenLM/online_merging_optimizers
Description: Implementations of online merging optimizers proposed in Online Merging Optimizers for Boosting Rewards and Mitigating Tax in Alignment
Language: Python
License: Apache-2.0
Stars: 83
Forks: 14
Open issues: 2
Created: 2024-05-28T03:57:37Z
Pushed: 2024-06-19T12:52:34Z
Default branch: main
Fork: no
Archived: no
README:
Online Merging Optimizers
*Keming Lu, Bowen Yu, Fei Huang, Yang Fan, Runji Lin, Chang Zhou*
Qwen, Alibaba Inc.
This repository contains the core implementations of the online merging optimizers proposed in [Online Merging Optimizers for Boosting Rewards and Mitigating Tax in Alignment]().
Update
Thanks for your interest in online merging optimizers! We are working on a PR to merge our optimizers into LLaMa-Factory. Stay tuned!
Introduction
Effectively aligning Large Language Models (LLMs) with human-centric values while preventing the degradation of abilities acquired through pre-training and Supervised Fine-tuning (SFT) poses a central challenge in Reinforcement Learning from Human Feedback (RLHF). In this paper, we first find that interpolating RLHF and SFT model parameters can adjust the trade-off between human preference and basic capabilities, thereby reducing the alignment tax at the cost of alignment reward. Inspired by this, we propose integrating the RL policy and SFT models at each optimization step in RLHF to continuously regulate the training direction, introducing the Online Merging Optimizer. Specifically, we merge gradients with the parameter differences between SFT and pretrained models, effectively steering the gradient towards maximizing rewards in the direction of SFT optimization. We demonstrate that our optimizer works well with different LLM families, such as Qwen and LLaMA, across model sizes ranging from 1.8B to 8B, with various RLHF algorithms like DPO and KTO, and with existing model merging methods. It significantly enhances alignment reward while mitigating alignment tax, achieving higher overall performance across 14 benchmarks. A more detailed manuscript is available in the [paper](assets/online_merging_arxiv_review.pdf).
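The interpolation observation above can be illustrated with a minimal sketch. It assumes parameters are stored as plain dicts mapping a parameter name to a flat list of floats (a stand-in for tensors), and the helper name `interpolate_params` is hypothetical; it is not part of this repository.

```python
from typing import Dict, List

# Stand-in for a model state dict: parameter name -> flat list of floats.
Params = Dict[str, List[float]]


def interpolate_params(rlhf: Params, sft: Params, lam: float) -> Params:
    """Return lam * rlhf + (1 - lam) * sft, element-wise, for every parameter.

    lam = 1.0 returns the RLHF policy and lam = 0.0 returns the SFT model; per the
    paper, moving toward the SFT end reduces alignment tax at the cost of reward.
    """
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    merged: Params = {}
    for name, theta_rl in rlhf.items():
        theta_sft = sft[name]  # parameters are matched by name
        merged[name] = [lam * a + (1.0 - lam) * b for a, b in zip(theta_rl, theta_sft)]
    return merged


print(interpolate_params({"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}, 0.5))  # {'w': [2.0, 3.0]}
```

This is offline merging of two finished models; the online merging optimizers instead apply a merge at every optimization step.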
Installation
You can install the package from the source code in this repository:
```bash
git clone https://github.com/QwenLM/online_merging_optimizers
cd online_merging_optimizers
pip install -e .
```
Usage
Applying online merging optimizers
There are two major differences between initializing an online merging optimizer and initializing the original AdamW:
- Building a mapping between parameters and names: when collecting the parameter groups for the optimizer, you also need to build a mapping from parameters to their names, so that the online merging optimizer can match each parameter with its delta parameter. The mapping has the form `(group_idx: int, param_idx: int) -> name: str`, where `group_idx` distinguishes parameter groups (e.g., parameters with or without weight decay) and `param_idx` is the index of a specific parameter within its group.
- Passing the base and reference models along with the arguments for online merging optimizers: before optimization starts, the delta parameters must be initialized by passing the reference and base models to the optimizer's `init_ref_param_diff` method.
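To illustrate how the `(group_idx, param_idx) -> name` mapping lets an optimizer pair each parameter with its delta parameter, here is a minimal sketch. It is not the repository's `init_ref_param_diff`: the helpers `build_param_name_map`, `compute_deltas` and `delta_for`, the name-based no-decay rule, and the use of plain lists instead of tensors are all hypothetical choices made for illustration. It assumes the delta is reference (SFT) minus base (pretrained) weights, as described in the introduction.

```python
from typing import Callable, Dict, List, Tuple

# Stand-in for named tensors: parameter name -> flat list of floats.
NamedParams = Dict[str, List[float]]
NameMap = Dict[Tuple[int, int], str]


def build_param_name_map(param_names: List[str], no_decay: Callable[[str], bool]) -> NameMap:
    """Place each parameter in group 0 (weight decay) or group 1 (no weight decay) and
    record (group_idx, param_idx) -> name in the order parameters are appended."""
    group_sizes = [0, 0]
    name_map: NameMap = {}
    for name in param_names:
        group_idx = 1 if no_decay(name) else 0
        name_map[(group_idx, group_sizes[group_idx])] = name
        group_sizes[group_idx] += 1
    return name_map


def compute_deltas(ref: NamedParams, base: NamedParams) -> NamedParams:
    """Delta parameters: reference (SFT) minus base (pretrained), matched by name."""
    return {name: [r - b for r, b in zip(ref[name], base[name])] for name in ref}


def delta_for(group_idx: int, param_idx: int, name_map: NameMap, deltas: NamedParams) -> List[float]:
    """Return the delta for the parameter at position param_idx within group group_idx."""
    return deltas[name_map[(group_idx, param_idx)]]


names = ["layer.weight", "layer.bias", "norm.weight"]
name_map = build_param_name_map(names, no_decay=lambda n: n.endswith("bias") or "norm" in n)
print(name_map)  # {(0, 0): 'layer.weight', (1, 0): 'layer.bias', (1, 1): 'norm.weight'}

base = {"layer.weight": [1.0, 1.0], "layer.bias": [0.0], "norm.weight": [1.0]}
ref = {"layer.weight": [1.5, 0.5], "layer.bias": [0.25], "norm.weight": [1.0]}
deltas = compute_deltas(ref, base)
print(delta_for(1, 0, name_map, deltas))  # [0.25]
```

Because an optimizer only sees parameters as positions inside `param_groups`, the name map is what connects position `(1, 0)` back to `layer.bias` and hence to its precomputed delta.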
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch
    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
    from llmtuner.hparams import FinetuningArguments


def _create_online_merging_optimizer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    base_model,  # base model (the pretrained model the reference model was trained from)
    ref_model,  # reference model (the initial policy)
    om_mode="ondare",  # "ondare" or "onties"
) -> "torch.optim.Optimizer":
    from online_merging import OnDAREAdamW, OnTIESAdamW

    if om_mode == "ondare":
        optimizer_cls = OnDAREAdamW
    elif om_mode == "onties":
        optimizer_cls = OnTIESAdamW
    else:
        raise ValueError(f"{om_mode} is not supported yet.")

    param_name_map = {}
    param_w_decay = {"params": [], "weight_decay": training_args.weight_decay}
    param_wo_decay = {"params": [], "weight_decay": 0.0}
    for n, p in model.named_parameters():
        if p.requires_grad:
            # In this example every trainable param goes into group 0 (param_w_decay);
            # a param appended to param_wo_decay would be recorded as (1, idx) instead.
            param_w_decay["params"].append(p)
            # build a param_name_map while caching the param group for the optimizer
            param_name_map[(0, len(param_w_decay["params"]) - 1)] = n
    optimizer_grouped_parameters = [param_w_decay, param_wo_decay]

    optimizer_kwargs = {
        "lr": training_args.learning_rate,
        "betas": (training_args.adam_beta1, training_args.adam_beta2),
        "eps": training_args.adam_epsilon,
    }
    # pass the extra arguments required by online merging optimizers
    optimizer_kwargs.update({
        "param_name_map": param_name_map,
        "reserve_p": finetuning_args.reserve_p,
        "alpha": finetuning_args.alpha,
        "use_merge": finetuning_args.use_merge,
        "rescale": finetuning_args.rescale,
        "online_step": finetuning_args.online_step,
    })
    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    # initialize the delta parameters from the reference and base models
    optimizer.init_ref_param_diff(ref_model, base_model)
    # drop this function's reference to the base model; callers holding
    # their own reference must release it too for the memory to be freed
    del base_model
    return optimizer
```

Experiments in our manuscript
Our experiments were run with LLaMa-Factory (llmtuner==0.7.0). You can install the copy of LLaMa-Factory included in this repository with `cd LLaMa-Factory && pip install -e .`.
The script for running DPO/IPO/KTO with online merging optimizers in LLaMa-Factory (LLaMa-Factory/run_dpo_om.sh) is as follows:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export WANDB_DISABLED=true

MODEL_SIZE=7B
BETA=0.1
LR=1e-6
OPTIMIZER=ondare
USE_EMA=True
USE_RESCALE=False
RESERVE_P=0.01
ALPHA=1e-7
DATASET=ultrafeedback_binarized
POSTFIX=
MODEL_PATH="reference model path"  # set your reference model
BASE_MODEL_PATH="base model path"  # set your base model
# DROP_RATE and SHRINK_BASE are not set in this excerpt, so they expand to empty strings
OUTPUT_DIR=saves/Qwen-${MODEL_SIZE}/dpo/${DATASET}_${OPTIMIZER}_${DROP_RATE}_ema_${USE_EMA}_rescale_${USE_RESCALE}_shrink_${SHRINK_BASE}_alpha_${ALPHA}_beta_${BETA}_lr_${LR}$POSTFIX

accelerate launch \
    --config_file examples/accelerate/fsdp_config.yaml \
    src/train_bash.py \
    --stage dpo \
    --do_train \
    --model_name_or_path "$MODEL_PATH" \
    --dataset $DATASET \
    --dataset_dir data \
    --template qwen \
    --dpo_beta $BETA \
    --dpo_loss sigmoid \
    --ref_model "$MODEL_PATH" \
    --finetuning_type full \
    --output_dir "$OUTPUT_DIR" \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 2048 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 100 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy no \
    --learning_rate $LR \
    --num_train_epochs 2.0 \
    --ddp_timeout 180000000 \
    --plot_loss \
    …
```
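The `OPTIMIZER=ondare` mode and the `RESERVE_P`/`ALPHA` settings in the script above are easier to reason about with a toy example. The sketch below is a hypothetical simplification, not the paper's exact update rule or the internals of `OnDAREAdamW`: it assumes each per-step optimizer update is sparsified DARE-style (keep each entry with probability `reserve_p`, rescale survivors by `1 / reserve_p`) and that `alpha` scales the SFT delta added back into the step. The function names are illustrative, and the `USE_EMA`/`USE_RESCALE` options and the `onties` (TIES-style) mode are not modeled.

```python
import random
from typing import List, Optional


def dare(update: List[float], reserve_p: float, rng: random.Random) -> List[float]:
    """DARE (drop and rescale): keep each entry with probability reserve_p and divide
    the survivors by reserve_p, so the result matches `update` in expectation."""
    if not 0.0 < reserve_p <= 1.0:
        raise ValueError("reserve_p must lie in (0, 1]")
    return [u / reserve_p if rng.random() < reserve_p else 0.0 for u in update]


def online_dare_step(
    param: List[float],
    update: List[float],  # assumption: already includes the learning rate and sign
    sft_delta: List[float],  # assumption: reference (SFT) minus base (pretrained) weights
    reserve_p: float,
    alpha: float,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Hypothetical OnDARE-style step: sparsify the update with DARE, add the
    alpha-scaled SFT delta, and apply the merged update to the parameter."""
    if rng is None:
        rng = random.Random(0)
    sparse = dare(update, reserve_p, rng)
    return [p + s + alpha * d for p, s, d in zip(param, sparse, sft_delta)]


# reserve_p=1.0 keeps every entry, so only the alpha * delta term changes the plain step.
print(online_dare_step([1.0, 2.0], [0.5, -0.5], [2.0, 4.0], reserve_p=1.0, alpha=0.25))  # [2.0, 2.5]

# With RESERVE_P=0.01, roughly 1% of entries survive, each scaled up by 1 / 0.01 = 100.
kept = dare([1.0] * 10_000, reserve_p=0.01, rng=random.Random(0))
print(sum(x != 0.0 for x in kept))
```

Rescaling by `1 / reserve_p` keeps the sparsified update unbiased, which is why a very small `reserve_p` can still move parameters at roughly the intended rate on average.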