Asynchronous RLHF: Faster and More Efficient Online DPO #2278

Open
wants to merge 11 commits into main
146 changes: 146 additions & 0 deletions examples/scripts/dpo_online_async.py
@@ -0,0 +1,146 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage for a 4 GPU setup:

accelerate launch --num_processes 3 examples/scripts/dpo_online_async.py \
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-1b-tldr-online-dpo \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 16 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0
"""

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl import (
AsyncOnlineDPOConfig,
AsyncOnlineDPOTrainer,
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_quantization_config,
)
from trl.trainer.sync_online_dpo_trainer import SyncOnlineDPOTrainer


if __name__ == "__main__":
parser = TrlParser((ScriptArguments, AsyncOnlineDPOConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

# Make sure the SFT model path points to the same base model as the policy
training_args.sft_model_path = model_config.model_name_or_path

torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)

model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
)

ref_model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
)

reward_model = AutoModelForSequenceClassification.from_pretrained(
training_args.reward_model_path,
num_labels=1,
trust_remote_code=model_config.trust_remote_code,
**model_kwargs,
)

tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
padding_side="left",
trust_remote_code=model_config.trust_remote_code,
)
if tokenizer.pad_token_id is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

################
# Dataset
################
train_dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split).select(range(1000))
eval_dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_test_split)

def prepare_dataset(dataset, tokenizer):
"""pre-tokenize the dataset before training; only collate during training"""

def tokenize(element):
outputs = tokenizer(
element["prompt"],
padding=False,
)
return {"input_ids": outputs["input_ids"]}

return dataset.map(
tokenize,
batched=True,
remove_columns=dataset.column_names,
num_proc=training_args.dataset_num_proc,
)

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, tokenizer)
eval_dataset = prepare_dataset(eval_dataset, tokenizer)

if training_args.sync_fallback:
TrainerCls = SyncOnlineDPOTrainer
else:
TrainerCls = AsyncOnlineDPOTrainer

trainer = TrainerCls(
config=training_args,
processing_class=tokenizer,
policy=model,
ref_policy=ref_model,
reward_model=reward_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
# generation_config = GenerationConfig(
# max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
# )
# completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
# trainer.add_callback(completions_callback)
trainer.train()

# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
1 change: 1 addition & 0 deletions setup.py
@@ -85,6 +85,7 @@
# Windows support is partially supported with DeepSpeed https://github.com/microsoft/DeepSpeed/tree/master#windows
"deepspeed": ["deepspeed>=0.14.4; sys_platform != 'win32'"],
"diffusers": ["diffusers>=0.18.0"],
"async": ["vllm >= 0.6.0"],
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
"liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
4 changes: 4 additions & 0 deletions trl/__init__.py
@@ -49,6 +49,8 @@
"trainer": [
"AlignPropConfig",
"AlignPropTrainer",
"AsyncOnlineDPOConfig",
"AsyncOnlineDPOTrainer",
"AllTrueJudge",
"BaseBinaryJudge",
"BaseJudge",
@@ -138,6 +140,8 @@
from .trainer import (
AlignPropConfig,
AlignPropTrainer,
AsyncOnlineDPOConfig,
AsyncOnlineDPOTrainer,
AllTrueJudge,
BaseBinaryJudge,
BaseJudge,
4 changes: 4 additions & 0 deletions trl/trainer/__init__.py
@@ -21,6 +21,8 @@
_import_structure = {
"alignprop_config": ["AlignPropConfig"],
"alignprop_trainer": ["AlignPropTrainer"],
"async_online_dpo_config": ["AsyncOnlineDPOConfig"],
"async_online_dpo_trainer": ["AsyncOnlineDPOTrainer"],
"base": ["BaseTrainer"],
"bco_config": ["BCOConfig"],
"bco_trainer": ["BCOTrainer"],
@@ -85,6 +87,8 @@
if TYPE_CHECKING:
from .alignprop_config import AlignPropConfig
from .alignprop_trainer import AlignPropTrainer
from .async_online_dpo_config import AsyncOnlineDPOConfig
from .async_online_dpo_trainer import AsyncOnlineDPOTrainer
from .base import BaseTrainer
from .bco_config import BCOConfig
from .bco_trainer import BCOTrainer
69 changes: 69 additions & 0 deletions trl/trainer/async_online_dpo_config.py
@@ -0,0 +1,69 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass
from typing import Literal, Optional

from ..trainer.utils import OnPolicyConfig


@dataclass
class AsyncOnlineDPOConfig(OnPolicyConfig):
r"""
Configuration class for the [`AsyncOnlineDPOTrainer`].

Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.

Parameters:
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
Name of this experiment.
num_ppo_epochs (`int`, *optional*, defaults to `1`):
Number of optimization epochs to perform on the same generated minibatch.
learning_rate (`float`, *optional*, defaults to `5e-7`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
reward_model_path (`str` or `None`, *optional*, defaults to `None`):
Path to the reward model.
beta (`float`, defaults to `0.1`):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
the [paper](https://huggingface.co/papers/2310.12036).
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
Type of loss to use. Possible values are:

- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.

sync_fallback (`bool`, *optional*, defaults to `False`):
Whether to fall back to synchronous training, i.e. use the `SyncOnlineDPOTrainer` instead of the
`AsyncOnlineDPOTrainer`.
vllm_device (`str` or `None`, *optional*, defaults to `None`):
Device on which to place vLLM generation. If `None`, defaults to `accelerate.num_processes + 1`.
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
Fraction of GPU memory that vLLM is allowed to reserve. Reduce this value if the execution graph
takes up too much space.

"""

exp_name: str = os.path.basename(__file__)[: -len(".py")]
num_ppo_epochs: int = 1
learning_rate: float = 5e-7
reward_model_path: Optional[str] = None
beta: float = 0.1
loss_type: Literal["sigmoid", "ipo"] = "sigmoid"

sync_fallback: bool = False
vllm_device: Optional[str] = None
vllm_gpu_memory_utilization: float = 0.9
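
For reference, the two `loss_type` values documented above correspond to the standard DPO and IPO objectives. The sketch below is illustrative only (the `preference_loss` name and its arguments are placeholders, not taken from this PR's trainer); it shows how the two losses are typically computed from summed policy and reference log-probabilities:

import torch.nn.functional as F


def preference_loss(policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps,
                    beta=0.1, loss_type="sigmoid"):
    # Difference of the policy/reference log-ratios for the chosen vs. rejected completion.
    logits = (policy_chosen_logps - ref_chosen_logps) - (policy_rejected_logps - ref_rejected_logps)
    if loss_type == "sigmoid":
        # Original DPO loss: -log sigmoid(beta * logits).
        return -F.logsigmoid(beta * logits)
    if loss_type == "ipo":
        # IPO loss: beta plays the role of the regularizer tau, targeting a margin of 1 / (2 * beta).
        return (logits - 1 / (2 * beta)) ** 2
    raise ValueError(f"Unknown loss_type: {loss_type}")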