Repo: Qwen (Alibaba Cloud), published May 28, 2024

QwenLM/online_merging_optimizers

Python

Description: Implementations of online merging optimizers proposed by Online Merging Optimizers for Boosting Rewards and Mitigating Tax in Alignment

Language: Python

License: Apache-2.0

Stars: 83

Forks: 14

Open issues: 2

Created: 2024-05-28T03:57:37Z

Pushed: 2024-06-19T12:52:34Z

Default branch: main

Fork: no

Archived: no

README:

Online Merging Optimizers

*Keming Lu, Bowen Yu, Fei Huang, Yang Fan, Runji Lin, Chang Zhou*

Qwen, Alibaba Inc.

This repository contains the core implementations of the online merging optimizers proposed in [Online Merging Optimizers for Boosting Rewards and Mitigating Tax in Alignment]().

Update

Thanks for your interest in online merging optimizers! We are working on a PR to merge our optimizers into LLaMA-Factory. Stay tuned!

Introduction

Effectively aligning Large Language Models (LLMs) with human-centric values while preventing the degradation of abilities acquired through pre-training and Supervised Fine-Tuning (SFT) poses a central challenge in Reinforcement Learning from Human Feedback (RLHF). In this paper, we first find that interpolating RLHF and SFT model parameters can adjust the trade-off between human preference and basic capabilities, thereby reducing the alignment tax at the cost of alignment reward. Inspired by this, we propose integrating the RL policy and SFT models at each optimization step in RLHF to continuously regulate the training direction, introducing the Online Merging Optimizer. Specifically, we merge gradients with the parameter differences between the SFT and pretrained models, effectively steering the gradient towards maximizing rewards in the direction of SFT optimization. We demonstrate that our optimizer works well with different LLM families, such as Qwen and LLaMA, across model sizes ranging from 1.8B to 8B, with various RLHF algorithms such as DPO and KTO, and with existing model merging methods. It significantly enhances alignment reward while mitigating alignment tax, achieving higher overall performance across 14 benchmarks. A more detailed manuscript is available in the [paper](assets/online_merging_arxiv_review.pdf).
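The per-step merging idea above can be illustrated with a toy, pure-Python sketch of one OnDARE-style step. This is not the repository's implementation: it assumes, for illustration only, that the optimizer's proposed update is sparsified DARE-style (each entry kept with probability reserve_p and optionally rescaled by 1/reserve_p) and then shifted by alpha times the SFT-minus-pretrained delta. The names reserve_p, alpha, and rescale mirror the optimizer arguments used later in this README, but their exact semantics here are assumptions; dare_sparsify and online_merge_step are hypothetical helpers.

```python
import random
from typing import List


def dare_sparsify(vec: List[float], reserve_p: float, rescale: bool,
                  rng: random.Random) -> List[float]:
    """DARE-style drop: keep each entry with probability reserve_p, zero the rest.

    With rescale=True, kept entries are divided by reserve_p so the expected
    value of every entry is unchanged.
    """
    if not 0.0 < reserve_p <= 1.0:
        raise ValueError("reserve_p must be in (0, 1]")
    scale = 1.0 / reserve_p if rescale else 1.0
    return [x * scale if rng.random() < reserve_p else 0.0 for x in vec]


def online_merge_step(params: List[float], update: List[float],
                      ref_delta: List[float], reserve_p: float, alpha: float,
                      rescale: bool, rng: random.Random) -> List[float]:
    """Hypothetical merge step: sparsified update plus alpha * (SFT - pretrained) delta."""
    sparse_update = dare_sparsify(update, reserve_p, rescale, rng)
    merged = [u + alpha * d for u, d in zip(sparse_update, ref_delta)]
    return [p + m for p, m in zip(params, merged)]


rng = random.Random(0)
new_params = online_merge_step(
    params=[1.0, 2.0], update=[0.1, 0.2], ref_delta=[1.0, 1.0],
    reserve_p=1.0, alpha=0.5, rescale=True, rng=rng,
)
```

With reserve_p=1.0 nothing is dropped, so new_params is approximately [1.6, 2.7]. Smaller values (the experiment script in this README sets RESERVE_P=0.01) keep only a small random subset of update entries at each step.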

Installation

You can install the package from the source code in this repository:

git clone https://github.com/QwenLM/online_merging_optimizers
cd online_merging_optimizers
pip install -e .

Usage

Applying online merging optimizers

There are two major differences between initializing online merging optimizers and initializing the original AdamW:

  • Building a mapping between parameters and names: When collecting the parameter groups for the optimizer, a mapping between parameters and their names needs to be collected as well, so that the online merging optimizers can match each parameter with its delta parameter. The mapping is Tuple[int, int] (group_idx, param_idx) -> name: str, where group_idx distinguishes different param groups (such as params with or without weight decay) and param_idx is the index of a specific param within its group.
  • Passing the base and reference models along with the arguments for online merging optimizers: Before optimization starts, online merging optimizers require the delta parameters to be initialized by passing the reference and base models to the init_ref_param_diff method.
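The bookkeeping described in these two points can be sketched in plain Python, independent of PyTorch. In this hypothetical sketch, FakeParam stands in for a model parameter, build_groups records the (group_idx, param_idx) -> name mapping while filling the parameter groups, and param_diff computes the per-name difference between reference and base parameters, which is conceptually what init_ref_param_diff has to cache. Routing biases and normalization weights into the no-decay group is an assumed convention, not something this README specifies.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class FakeParam:
    """Hypothetical stand-in for a torch parameter."""
    values: List[float]
    requires_grad: bool = True


def build_groups(named_params, weight_decay: float):
    """Fill decay / no-decay groups and map (group_idx, param_idx) -> name."""
    groups = [
        {"params": [], "weight_decay": weight_decay},  # group 0: with decay
        {"params": [], "weight_decay": 0.0},           # group 1: without decay
    ]
    name_map: Dict[Tuple[int, int], str] = {}
    for name, p in named_params:
        if not p.requires_grad:
            continue  # frozen params are not optimized, so they get no entry
        # Assumption: biases and norm weights skip weight decay.
        group_idx = 1 if (name.endswith("bias") or "norm" in name) else 0
        groups[group_idx]["params"].append(p)
        name_map[(group_idx, len(groups[group_idx]["params"]) - 1)] = name
    return groups, name_map


def param_diff(ref: Dict[str, List[float]],
               base: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Per-name delta: reference (SFT) parameters minus base (pretrained) ones."""
    return {n: [r - b for r, b in zip(ref[n], base[n])] for n in ref}


named = [
    ("layer.weight", FakeParam([1.0, 2.0])),
    ("layer.bias", FakeParam([0.5])),
    ("frozen.weight", FakeParam([3.0], requires_grad=False)),
]
groups, name_map = build_groups(named, weight_decay=0.01)
diffs = param_diff(
    ref={"layer.weight": [1.0, 2.0], "layer.bias": [0.5]},
    base={"layer.weight": [0.5, 2.0], "layer.bias": [0.0]},
)
# A param's position in the groups is enough to look up its delta by name.
bias_delta = diffs[name_map[(1, 0)]]
```

Here name_map is {(0, 0): "layer.weight", (1, 0): "layer.bias"}, the frozen parameter is skipped, and bias_delta is [0.5]. The full README example for LLaMA-Factory follows.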
def _create_online_merging_optimizer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    base_model,  # base model (pretrained model that the reference model was trained from)
    ref_model,  # reference model (initial policy)
    om_mode='ondare',  # 'ondare' or 'onties'
) -> "torch.optim.Optimizer":
    from online_merging import OnDAREAdamW, OnTIESAdamW

    # pick the online merging variant
    if om_mode == 'ondare':
        optimizer_cls = OnDAREAdamW
    elif om_mode == 'onties':
        optimizer_cls = OnTIESAdamW
    else:
        raise ValueError(f"{om_mode} is not supported yet.")

    param_name_map = {}
    param_w_decay = {"params": [], "weight_decay": training_args.weight_decay}
    param_wo_decay = {"params": [], "weight_decay": 0.0}
    for n, p in model.named_parameters():
        if p.requires_grad:
            # in this simplified example every trainable param goes into group 0
            param_w_decay["params"].append(p)
            # build a param_name_map when caching the param group for the optimizer:
            # (group_idx, param_idx) -> parameter name
            param_name_map[(0, len(param_w_decay["params"]) - 1)] = n

    optimizer_grouped_parameters = [param_w_decay, param_wo_decay]

    optimizer_kwargs = {
        "lr": training_args.learning_rate,
        "betas": (training_args.adam_beta1, training_args.adam_beta2),
        "eps": training_args.adam_epsilon,
    }

    # pass the extra arguments required by online merging optimizers
    optimizer_kwargs.update({
        "param_name_map": param_name_map,
        "reserve_p": finetuning_args.reserve_p,
        "alpha": finetuning_args.alpha,
        "use_merge": finetuning_args.use_merge,
        "rescale": finetuning_args.rescale,
        "online_step": finetuning_args.online_step,
    })
    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    # initialize the delta parameters from the reference and base models before training
    optimizer.init_ref_param_diff(ref_model, base_model)
    # only drops this function's reference; the caller must release its own copy
    del base_model
    return optimizer
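The om_mode='onties' option selects OnTIESAdamW, which by its name builds on TIES-Merging. As a hedged illustration, the sketch below implements the standard TIES-Merging operator (trim each vector to its largest-magnitude entries, elect a per-coordinate sign from the sum, then average only the entries that agree with that sign) on two toy vectors standing in for an optimizer update and the reference-minus-base delta. trim_top_k and ties_merge are hypothetical helpers; how OnTIESAdamW actually weights and combines these quantities is not shown in this README.

```python
from typing import List


def trim_top_k(vec: List[float], keep_ratio: float) -> List[float]:
    """Keep the keep_ratio fraction of entries with the largest magnitude."""
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError("keep_ratio must be in (0, 1]")
    k = max(1, int(round(keep_ratio * len(vec))))
    # sorted() is stable, so equal magnitudes are kept in index order
    keep = set(sorted(range(len(vec)), key=lambda i: -abs(vec[i]))[:k])
    return [x if i in keep else 0.0 for i, x in enumerate(vec)]


def ties_merge(vectors: List[List[float]], keep_ratio: float) -> List[float]:
    """TIES-Merging: trim, elect sign per coordinate, disjoint mean."""
    trimmed = [trim_top_k(v, keep_ratio) for v in vectors]
    merged = []
    for coords in zip(*trimmed):
        total = sum(coords)
        if total == 0.0:
            merged.append(0.0)  # no elected sign: nothing to merge
            continue
        positive = total > 0.0
        # only nonzero entries whose sign matches the elected sign are averaged
        agree = [c for c in coords if c != 0.0 and (c > 0.0) == positive]
        merged.append(sum(agree) / len(agree))
    return merged


update = [0.5, -0.2, 0.05]
ref_delta = [-0.45, -0.4, 0.6]
merged = ties_merge([update, ref_delta], keep_ratio=2 / 3)
```

Here merged is [0.5, -0.2, 0.6]: trimming drops the small entries 0.05 and -0.4, and in the first coordinate the elected sign is positive, so the conflicting -0.45 is discarded rather than averaged in.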

Experiments in our manuscript

Our experiments were run with LLaMA-Factory (llmtuner==0.7.0). To install the copy of LLaMA-Factory included in this repository, run: cd LLaMa-Factory && pip install -e .

The script for running DPO/IPO/KTO with online merging optimizers in LLaMA-Factory (LLaMa-Factory/run_dpo_om.sh) is as follows:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export WANDB_DISABLED=true

MODEL_SIZE=7B
BETA=0.1
LR=1e-6
OPTIMIZER=ondare
USE_EMA=True
USE_RESCALE=False
RESERVE_P=0.01
ALPHA=1e-7
DATASET=ultrafeedback_binarized
POSTFIX=

MODEL_PATH="reference model path" # set your reference model
BASE_MODEL_PATH="base model path" # set your base model

OUTPUT_DIR=saves/Qwen-${MODEL_SIZE}/dpo/${DATASET}_${OPTIMIZER}_${DROP_RATE}_ema_${USE_EMA}_rescale_${USE_RESCALE}_shrink_${SHRINK_BASE}_alpha_${ALPHA}_beta_${BETA}_lr_${LR}$POSTFIX

accelerate launch \
--config_file examples/accelerate/fsdp_config.yaml \
src/train_bash.py \
--stage dpo \
--do_train \
--model_name_or_path $MODEL_PATH \
--dataset $DATASET \
--dataset_dir data \
--template qwen \
--dpo_beta $BETA \
--dpo_loss sigmoid \
--ref_model $MODEL_PATH \
--finetuning_type full \
--output_dir $OUTPUT_DIR \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 2048 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 100 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy no \
--learning_rate $LR \
--num_train_epochs 2.0 \
--ddp_timeout 180000000 \
--plot_loss \…
