From f4af389729e04e4a96fc1cec77492cd29c3b220c Mon Sep 17 00:00:00 2001 From: Michael Noukhovitch Date: Wed, 23 Oct 2024 11:55:38 -0400 Subject: [PATCH 1/9] initial online dpo trainer as ppo trainer --- trl/trainer/async_online_dpo_config.py | 66 +++ trl/trainer/async_online_dpo_trainer.py | 713 ++++++++++++++++++++++++ 2 files changed, 779 insertions(+) create mode 100644 trl/trainer/async_online_dpo_config.py create mode 100644 trl/trainer/async_online_dpo_trainer.py diff --git a/trl/trainer/async_online_dpo_config.py b/trl/trainer/async_online_dpo_config.py new file mode 100644 index 0000000000..fb8c11f8c1 --- /dev/null +++ b/trl/trainer/async_online_dpo_config.py @@ -0,0 +1,66 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from trl.trainer import OnlineDPOConfig + + +@dataclass +class AsyncOnlineDPOConfig(OnlineDPOConfig): + r""" + Configuration class for the [`AsyncOnlineDPOTrainer`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + learning_rate (`float`, *optional*, defaults to `5e-7`): + Initial learning rate for [`AdamW`] optimizer. The default value replaces that of + [`~transformers.TrainingArguments`]. + reward_model_path (`Optional[str]`, *optional*, defaults to `None`): + Path to the reward model. + max_new_tokens (`int`, *optional*, defaults to `64`): + Maximum number of tokens to generate per completion. + temperature (`float`, *optional*, defaults to `0.9`): + Temperature for sampling. The higher the temperature, the more random the completions. + missing_eos_penalty (`Optional[float]`, *optional*, defaults to `None`): + Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage + to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive + value. + beta (`float` or `list[float]`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher β means less deviation from the + reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in + the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is + selected for each new epoch and the last β is used for the rest of the epochs. + loss_type (`str`, *optional*, defaults to `"sigmoid"`): + Type of loss to use. Possible values are: + + - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. + - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. + + dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. 
+ disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + vllm_device (`str`, *optional*, defaults to `None`): + device to put the vllm generation on, defaults to accelerate.num_processes + 1" + vllm_gpu_memory_utilization (`float`, defaults to 0.9) + the percentage of the GPU's memory for vllm to reserve, reduce if exection graph takes too much space + + """ + + vllm_device: str | None = None + vllm_gpu_memory_utilization: float = 0.9 diff --git a/trl/trainer/async_online_dpo_trainer.py b/trl/trainer/async_online_dpo_trainer.py new file mode 100644 index 0000000000..e491b0622a --- /dev/null +++ b/trl/trainer/async_online_dpo_trainer.py @@ -0,0 +1,713 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import math +import os +import textwrap +import time +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import broadcast, gather_object +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + BaseImageProcessor, + DataCollatorWithPadding, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainerCallback, + TrainerControl, + is_wandb_available, +) +from transformers.integrations import get_reporting_integration_callbacks +from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK +from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback + +from ..core import masked_mean, masked_whiten +from ..models.utils import unwrap_model_for_generation +from ..trainer.utils import ( + OnlineTrainerState, + batch_generation, + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + get_reward, + prepare_deepspeed, + print_rich_table, + truncate_response, +) +from .ppo_config import PPOConfig +from .utils import generate_model_card + + +if is_wandb_available(): + import wandb + + +INVALID_LOGPROB = 1.0 + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, value_model) -> None: + super().__init__() + self.policy = policy + self.value_model = value_model + self.critic_backbone = getattr(value_model, value_model.base_model_prefix) + + def forward(self, **kwargs): + output = self.critic_backbone( + **kwargs, + ) + logits = self.value_model.score(output.hidden_states[-1]) + return self.policy(**kwargs), logits + + +class PPOTrainer(Trainer): + _tag_names = ["trl", "ppo"] + + def __init__( + self, + config: PPOConfig, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, 
FeatureExtractionMixin, ProcessorMixin] + ], + policy: nn.Module, + ref_policy: nn.Module, + reward_model: nn.Module, + train_dataset: Dataset, + value_model: Optional[nn.Module] = None, + data_collator: Optional[DataCollatorWithPadding] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + # less commonly used + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + callbacks: Optional[List[TrainerCallback]] = None, + ) -> None: + if ref_policy is policy: + raise ValueError( + "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the " + "same as `policy`, you must mass a copy of it, or `None` if you use peft." + ) + + self.args = config + args = config + self.processing_class = processing_class + self.policy = policy + + self.policy.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + + self.ref_policy = ref_policy + self.reward_model = reward_model + self.train_dataset = train_dataset + self.train_dataset_len = len(train_dataset) + self.value_model = value_model + self.data_collator = data_collator + self.eval_dataset = eval_dataset + self.optimizer, self.lr_scheduler = optimizers + + ######### + # calculate various batch sizes + ######### + if args.total_episodes is None: # allow the users to define episodes in terms of epochs. + args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + self.accelerator = accelerator + args.world_size = accelerator.num_processes + args.local_batch_size = ( + args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches + ) + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div( + args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" + ) + args.local_mini_batch_size = exact_div( + args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" + ) + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.num_total_batches = math.ceil( + args.total_episodes / args.batch_size + ) # we may train for more than `total_episodes` + time_tensor = torch.tensor(int(time.time()), device=accelerator.device) + time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes + args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" + self.local_seed = args.seed + accelerator.process_index * 100003 # Prime + if args.num_sample_generations > 0: + self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) + self.local_dataloader_batch_size = args.local_batch_size + + ######### + # setup model, optimizer, and others + ######### + for module in [policy, ref_policy, value_model, reward_model]: + disable_dropout_in_model(module) + if args.stop_token and args.stop_token == "eos": + args.stop_token_id = processing_class.eos_token_id + self.model = PolicyAndValueWrapper(policy, 
value_model) + self.model.config = policy.config # needed for pushing to hub + self.create_optimizer_and_scheduler( + num_training_steps=args.num_total_batches + ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level + + ######### + ### trainer specifics + ######### + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + self.control = TrainerControl() + self.state = OnlineTrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], + ) + self.current_flos = 0 + self.hp_search_backend = None + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + ######### + ### setup dataloader + ######### + self.dataloader = DataLoader( + self.train_dataset, + batch_size=self.local_dataloader_batch_size, + shuffle=True, + collate_fn=DataCollatorWithPadding(self.processing_class), + drop_last=True, # needed; otherwise the last batch will be of ragged shape + ) + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) + torch.manual_seed(self.local_seed) # reset the local seed again + + self.eval_dataloader = DataLoader( + self.eval_dataset, + batch_size=args.per_device_eval_batch_size, + collate_fn=DataCollatorWithPadding(self.processing_class), + drop_last=True, + ) # no need to shuffle eval dataset + self.eval_dataloader = accelerator.prepare(self.eval_dataloader) + + if self.is_deepspeed_enabled: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + self.ref_policy = prepare_deepspeed( + self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + else: + self.ref_policy = self.ref_policy.to(self.accelerator.device) + self.reward_model = self.reward_model.to(self.accelerator.device) + + def get_train_dataloader(self) -> DataLoader: + return self.dataloader + + def get_eval_dataloader(self) -> DataLoader: + return self.eval_dataloader + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + backup_model = self.model + self.model = self.model.policy # save only the policy + + if self.is_deepspeed_enabled: + backup_deepspeed = self.deepspeed + self.deepspeed = self.model + + super().save_model(output_dir, _internal_call) + + self.model = backup_model + + if 
self.is_deepspeed_enabled: + self.deepspeed = backup_deepspeed + + def train(self): + args = self.args + accelerator = self.accelerator + optimizer = self.optimizer + model = self.model + ref_policy = self.ref_policy + reward_model = self.reward_model + processing_class = self.processing_class + dataloader = self.dataloader + device = accelerator.device + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + generation_config = GenerationConfig( + max_new_tokens=args.response_length, + temperature=(args.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + accelerator.print("===training policy===") + start_time = time.time() + stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 + self.state.max_steps = args.num_total_batches * args.num_mini_batches + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model + self.model_wrapped = self.model + + for update in range(1, args.num_total_batches + 1): + self.state.episode += 1 * args.batch_size + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + sequence_lengths = [] + values = [] + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + query_responses, logitss = batch_generation( + unwrapped_model.policy, + queries, + args.local_rollout_forward_batch_size, + processing_class.pad_token_id, + generation_config, + ) + + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + logits = logitss[i : i + args.local_rollout_forward_batch_size] + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del logits, all_logprob + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_response, 
processing_class.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, processing_class.pad_token_id, response + ) + + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 + unwrapped_value_model = accelerator.unwrap_model(model).value_model + full_value, _, _ = get_reward( + unwrapped_value_model, query_response, processing_class.pad_token_id, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + _, score, _ = get_reward( + reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + values = torch.cat(values, 0) + del (logprob, ref_logprob, full_value, value, score, unwrapped_model) + torch.cuda.empty_cache() + gc.collect() + + # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id + # Completions not passing that filter will receive a lower score. + contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1) + if self.args.missing_eos_penalty is not None: + scores[~contain_eos_token] -= self.args.missing_eos_penalty + # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -args.kl_coef * kl + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[[actual_start, actual_end]] += scores + + # 5. 
whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + torch.cuda.empty_cache() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.num_ppo_epochs): + b_inds = np.random.permutation(args.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + with accelerator.accumulate(model): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.temperature + 1e-7 + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = torch.masked_fill( + new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB + ) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.cliprange_value, + mb_values + args.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds]) + vf_clipfrac = masked_mean( + (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds] + ) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + loss = pg_loss + args.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + with torch.no_grad(): + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + ) + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * 
(logprobs_diff**2).mean() + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + pg_clipfrac + ) + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + vf_clipfrac + ) + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # del everything and empty cache + # fmt: off + del ( + output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped, + vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, + pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, + mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, + ) + # fmt: on + torch.cuda.empty_cache() + with torch.no_grad(): + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + rlhf_reward = mean_non_score_reward + scores.mean() + eps = int(self.state.episode / (time.time() - start_time)) + metrics = {} + metrics["eps"] = eps + metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() + metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() + metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() + metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() + metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() + metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item() + metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item() + metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item() + metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item() + metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item() + metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item() + metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item() + metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() + metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + metrics["episode"] = self.state.episode + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log + self.state.global_step += 1 + self.log(metrics) + + self.lr_scheduler.step() + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward + torch.cuda.empty_cache() + gc.collect() + + if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + self.generate_completions(sampling=True) + torch.cuda.empty_cache() + del ( + query_responses, + responses, + postprocessed_responses, + logprobs, + 
ref_logprobs, + values, + sequence_lengths, + contain_eos_token, + sequence_lengths_p1, + response_idxs, + padding_mask, + padding_mask_p1, + rewards, + actual_start, + actual_end, + advantages, + returns, + ) + torch.cuda.empty_cache() + + # HF trainer specifics + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def generate_completions(self, sampling: bool = False): + args = self.args + processing_class = self.processing_class + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + table = defaultdict(list) + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + query_response, _ = batch_generation( + unwrapped_model.policy, + query, + query.shape[0], + processing_class.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, processing_class.pad_token_id, response + ) + table["query"].extend( + gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) + ) + table["model response"].extend( + gather_object(processing_class.batch_decode(postprocessed_response)) + ) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + + if sampling: + break + df = pd.DataFrame(table) + + if self.accelerator.is_main_process: + print_rich_table(df.iloc[0 : 0 + 5]) + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or [] + if isinstance(tags, str): + tags = [tags] + + if hasattr(self.model.config, "unsloth_version"): + tags.append("unsloth") + + citation = textwrap.dedent("""\ + @article{mziegler2019fine-tuning, + title = {{Fine-Tuning Language Models from Human Preferences}}, + author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. 
Christiano and Geoffrey Irving}, + year = 2019, + eprint = {arXiv:1909.08593} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="PPO", + trainer_citation=citation, + paper_title="Fine-Tuning Language Models from Human Preferences", + paper_id="1909.08593", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) From dfdfb2f78fc4651ccc2a28a461a21bdda23f44c7 Mon Sep 17 00:00:00 2001 From: Michael Noukhovitch Date: Wed, 23 Oct 2024 15:11:30 -0400 Subject: [PATCH 2/9] sync online dpo trainer --- trl/trainer/async_online_dpo_config.py | 31 +- ..._trainer.py => sync_online_dpo_trainer.py} | 358 ++++++++++-------- trl/trainer/utils.py | 7 + 3 files changed, 225 insertions(+), 171 deletions(-) rename trl/trainer/{async_online_dpo_trainer.py => sync_online_dpo_trainer.py} (66%) diff --git a/trl/trainer/async_online_dpo_config.py b/trl/trainer/async_online_dpo_config.py index fb8c11f8c1..2c73a9e194 100644 --- a/trl/trainer/async_online_dpo_config.py +++ b/trl/trainer/async_online_dpo_config.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import List, Literal +import os -from trl.trainer import OnlineDPOConfig +from ..trainer.utils import OnPolicyConfig @dataclass -class AsyncOnlineDPOConfig(OnlineDPOConfig): +class AsyncOnlineDPOConfig(OnPolicyConfig): r""" Configuration class for the [`AsyncOnlineDPOTrainer`]. @@ -27,19 +29,15 @@ class AsyncOnlineDPOConfig(OnlineDPOConfig): command line. Parameters: + exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`): + Name of this experiment. + num_ppo_epochs (`int`, *optional*, defaults to `1`): + Number of updates to train on the same minibatch learning_rate (`float`, *optional*, defaults to `5e-7`): Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`]. reward_model_path (`Optional[str]`, *optional*, defaults to `None`): Path to the reward model. - max_new_tokens (`int`, *optional*, defaults to `64`): - Maximum number of tokens to generate per completion. - temperature (`float`, *optional*, defaults to `0.9`): - Temperature for sampling. The higher the temperature, the more random the completions. - missing_eos_penalty (`Optional[float]`, *optional*, defaults to `None`): - Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage - to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive - value. beta (`float` or `list[float]`, *optional*, defaults to `0.1`): Parameter controlling the deviation from the reference model. Higher β means less deviation from the reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in @@ -51,10 +49,6 @@ class AsyncOnlineDPOConfig(OnlineDPOConfig): - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. - dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): - Number of processes to use for processing the dataset. 
- disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. vllm_device (`str`, *optional*, defaults to `None`): device to put the vllm generation on, defaults to accelerate.num_processes + 1" vllm_gpu_memory_utilization (`float`, defaults to 0.9) @@ -62,5 +56,12 @@ class AsyncOnlineDPOConfig(OnlineDPOConfig): """ + exp_name: str = os.path.basename(__file__)[: -len(".py")] + num_ppo_epochs: int = 1 + learning_rate: float = 5e-7 + reward_model_path: str = None + beta: List[float] = field(default_factory=lambda: [0.1]) + loss_type: Literal["sigmoid", "ipo"] = "sigmoid" + vllm_device: str | None = None vllm_gpu_memory_utilization: float = 0.9 diff --git a/trl/trainer/async_online_dpo_trainer.py b/trl/trainer/sync_online_dpo_trainer.py similarity index 66% rename from trl/trainer/async_online_dpo_trainer.py rename to trl/trainer/sync_online_dpo_trainer.py index e491b0622a..e279d0b028 100644 --- a/trl/trainer/async_online_dpo_trainer.py +++ b/trl/trainer/sync_online_dpo_trainer.py @@ -15,7 +15,6 @@ import gc import math import os -import textwrap import time from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -45,7 +44,6 @@ from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback -from ..core import masked_mean, masked_whiten from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( OnlineTrainerState, @@ -59,7 +57,7 @@ print_rich_table, truncate_response, ) -from .ppo_config import PPOConfig +from .async_online_dpo_config import AsyncOnlineDPOConfig from .utils import generate_model_card @@ -70,29 +68,12 @@ INVALID_LOGPROB = 1.0 -# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 -# we did this we can do a single `model = accelerator.prepare(model)` -class PolicyAndValueWrapper(nn.Module): - def __init__(self, policy, value_model) -> None: - super().__init__() - self.policy = policy - self.value_model = value_model - self.critic_backbone = getattr(value_model, value_model.base_model_prefix) - - def forward(self, **kwargs): - output = self.critic_backbone( - **kwargs, - ) - logits = self.value_model.score(output.hidden_states[-1]) - return self.policy(**kwargs), logits - - -class PPOTrainer(Trainer): - _tag_names = ["trl", "ppo"] +class SyncOnlineDPOTrainer(Trainer): + _tag_names = ["trl", "online-dpo"] def __init__( self, - config: PPOConfig, + config: AsyncOnlineDPOConfig, processing_class: Optional[ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] ], @@ -100,7 +81,6 @@ def __init__( ref_policy: nn.Module, reward_model: nn.Module, train_dataset: Dataset, - value_model: Optional[nn.Module] = None, data_collator: Optional[DataCollatorWithPadding] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, # less commonly used @@ -127,7 +107,6 @@ def __init__( self.reward_model = reward_model self.train_dataset = train_dataset self.train_dataset_len = len(train_dataset) - self.value_model = value_model self.data_collator = data_collator self.eval_dataset = eval_dataset self.optimizer, self.lr_scheduler = optimizers @@ -166,17 +145,18 @@ def __init__( self.local_seed = args.seed + accelerator.process_index * 100003 # Prime if args.num_sample_generations > 0: self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) - 
self.local_dataloader_batch_size = args.local_batch_size + self.local_dataloader_batch_size = exact_div( + args.local_batch_size, 2, "`local_batch_size` must be a multiple of 2" + ) # Online DPO logic: needed because Online DPO repeats the same prompt 2 times ######### # setup model, optimizer, and others ######### - for module in [policy, ref_policy, value_model, reward_model]: + for module in [policy, ref_policy, reward_model]: disable_dropout_in_model(module) if args.stop_token and args.stop_token == "eos": args.stop_token_id = processing_class.eos_token_id - self.model = PolicyAndValueWrapper(policy, value_model) - self.model.config = policy.config # needed for pushing to hub + self.model = policy self.create_optimizer_and_scheduler( num_training_steps=args.num_total_batches ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level @@ -298,11 +278,15 @@ def repeat_generator(): stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) approxkl_stats = torch.zeros(stats_shape, device=device) pg_clipfrac_stats = torch.zeros(stats_shape, device=device) - pg_loss_stats = torch.zeros(stats_shape, device=device) - vf_loss_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropy_stats = torch.zeros(stats_shape, device=device) - ratio_stats = torch.zeros(stats_shape, device=device) + loss_stats = torch.zeros(stats_shape, device=device) + chosen_reward_stats = torch.zeros(stats_shape, device=device) + chosen_logprobs_stats = torch.zeros(stats_shape, device=device) + chosen_ref_logprobs_stats = torch.zeros(stats_shape, device=device) + rejected_reward_stats = torch.zeros(stats_shape, device=device) + rejected_logprobs_stats = torch.zeros(stats_shape, device=device) + rejected_ref_logprobs_stats = torch.zeros(stats_shape, device=device) + # entropy_stats = torch.zeros(stats_shape, device=device) + # kl_stats = torch.zeros(stats_shape, device=device) model.train() # trainer state initialization @@ -338,6 +322,7 @@ def repeat_generator(): data = next(iter_dataloader) with torch.no_grad(): queries = data["input_ids"].to(device) + queries = queries.repeat(2, 1) context_length = queries.shape[1] responses = [] postprocessed_responses = [] @@ -383,11 +368,6 @@ def repeat_generator(): # Response Processing 2. 
run reward model on the truncated responses postprocessed_query_response = torch.cat((query, postprocessed_response), 1) sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 - unwrapped_value_model = accelerator.unwrap_model(model).value_model - full_value, _, _ = get_reward( - unwrapped_value_model, query_response, processing_class.pad_token_id, context_length - ) - value = full_value[:, context_length - 1 : -1].squeeze(-1) _, score, _ = get_reward( reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length ) @@ -398,15 +378,13 @@ def repeat_generator(): ref_logprobs.append(ref_logprob) sequence_lengths.append(sequence_length) scores.append(score) - values.append(value) responses = torch.cat(responses, 0) postprocessed_responses = torch.cat(postprocessed_responses, 0) logprobs = torch.cat(logprobs, 0) ref_logprobs = torch.cat(ref_logprobs, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) - values = torch.cat(values, 0) - del (logprob, ref_logprob, full_value, value, score, unwrapped_model) + del (logprob, ref_logprob, score, unwrapped_model) torch.cuda.empty_cache() gc.collect() @@ -422,36 +400,24 @@ def repeat_generator(): padding_mask = response_idxs > sequence_lengths.unsqueeze(1) logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - sequence_lengths_p1 = sequence_lengths + 1 - padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) - values = torch.masked_fill(values, padding_mask_p1, 0) # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = -args.kl_coef * kl - rewards = non_score_reward.clone() - actual_start = torch.arange(rewards.size(0), device=rewards.device) - actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) - rewards[[actual_start, actual_end]] += scores - - # 5. whiten rewards - if args.whiten_rewards: - rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) - rewards = torch.masked_fill(rewards, padding_mask_p1, 0) - - # 6. 
compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = responses.shape[1] - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.gamma * args.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = masked_whiten(advantages, ~padding_mask) - advantages = torch.masked_fill(advantages, padding_mask, 0) + rlhf_reward = scores + num_examples = scores.size(0) // 2 + scores_reshaped = scores.reshape(2, num_examples).t() + + # Get the max scores and their local indices + chosen_scores, chosen_local_indices = torch.max(scores_reshaped, dim=1) + + # Get the min scores and their local indices + rejected_scores, rejected_local_indices = torch.min(scores_reshaped, dim=1) + scores_margin = chosen_scores - rejected_scores + + # Calculate the global indices + chosen_indices = chosen_local_indices * num_examples + torch.arange(num_examples, device=scores.device) + rejected_indices = rejected_local_indices * num_examples + torch.arange( + num_examples, device=scores.device + ) torch.cuda.empty_cache() # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch @@ -463,102 +429,189 @@ def repeat_generator(): mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] gradient_accumulation_idx = 0 for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): - with accelerator.accumulate(model): - micro_batch_end = micro_batch_start + args.per_device_train_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_advantage = advantages[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - mb_return = returns[micro_batch_inds] - mb_values = values[micro_batch_inds] - - output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.temperature + 1e-7 - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - new_logprobs = torch.masked_fill( - new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + + ## chosen + chosen_mb_inds = chosen_indices[micro_batch_inds] + chosen_responses = responses[chosen_mb_inds] + + ## rejected + rejected_mb_inds = rejected_indices[micro_batch_inds] + rejected_responses = responses[rejected_mb_inds] + + concat_mb_inds = torch.cat((chosen_mb_inds, rejected_mb_inds), dim=0) + concat_query_responses = query_responses[concat_mb_inds] + num_examples = chosen_mb_inds.shape[0] + with torch.no_grad(): + concat_ref_output = forward( + ref_policy, concat_query_responses, processing_class.pad_token_id + ) + chosen_ref_logits = concat_ref_output.logits[:num_examples] + rejected_ref_logits = concat_ref_output.logits[num_examples:] + + chosen_ref_logits = chosen_ref_logits[:, context_length - 1 : -1] + chosen_ref_logits /= args.temperature + 1e-7 + chosen_ref_all_logprobs = F.log_softmax(chosen_ref_logits, dim=-1) + chosen_ref_logprobs = torch.gather( + chosen_ref_all_logprobs, 2, 
chosen_responses.unsqueeze(-1) + ).squeeze(-1) + chosen_ref_logprobs = torch.masked_fill( + chosen_ref_logprobs, padding_mask[chosen_mb_inds], INVALID_LOGPROB + ) + chosen_ref_logprobs_sum = (chosen_ref_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) + + rejected_ref_logits = rejected_ref_logits[:, context_length - 1 : -1] + rejected_ref_logits /= args.temperature + 1e-7 + rejected_ref_all_logprobs = F.log_softmax(rejected_ref_logits, dim=-1) + rejected_ref_logprobs = torch.gather( + rejected_ref_all_logprobs, 2, rejected_responses.unsqueeze(-1) + ).squeeze(-1) + rejected_ref_logprobs = torch.masked_fill( + rejected_ref_logprobs, padding_mask[rejected_mb_inds], INVALID_LOGPROB ) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0) - vpredclipped = torch.clamp( - vpred, - mb_values - args.cliprange_value, - mb_values + args.cliprange_value, + rejected_ref_logprobs_sum = (rejected_ref_logprobs * ~padding_mask[rejected_mb_inds]).sum( + 1 ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss_max = torch.max(vf_losses1, vf_losses2) - vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds]) - vf_clipfrac = masked_mean( - (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds] + + ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum + + with accelerator.accumulate(model): + concat_output = forward(model, concat_query_responses, processing_class.pad_token_id) + chosen_logits = concat_output.logits[:num_examples] + rejected_logits = concat_output.logits[num_examples:] + + # chosen + chosen_logits = chosen_logits[:, context_length - 1 : -1] + chosen_logits /= args.temperature + 1e-7 + chosen_all_logprobs = F.log_softmax(chosen_logits, dim=-1) + chosen_logprobs = torch.gather( + chosen_all_logprobs, 2, chosen_responses.unsqueeze(-1) + ).squeeze(-1) + chosen_logprobs = torch.masked_fill( + chosen_logprobs, padding_mask[chosen_mb_inds], INVALID_LOGPROB ) - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) - pg_loss_max = torch.max(pg_losses, pg_losses2) - pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) - loss = pg_loss + args.vf_coef * vf_loss + # chosen_ref_logprobs = ref_logprobs[chosen_mb_inds] + chosen_logprobs_sum = (chosen_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) + # chosen_ref_logprobs_sum = (chosen_ref_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) + + # rejected + rejected_logits = rejected_logits[:, context_length - 1 : -1] + rejected_logits /= args.temperature + 1e-7 + rejected_all_logprobs = F.log_softmax(rejected_logits, dim=-1) + rejected_logprobs = torch.gather( + rejected_all_logprobs, 2, rejected_responses.unsqueeze(-1) + ).squeeze(-1) + rejected_logprobs = torch.masked_fill( + rejected_logprobs, padding_mask[rejected_mb_inds], INVALID_LOGPROB + ) + # rejected_ref_logprobs = ref_logprobs[rejected_mb_inds] + rejected_logprobs_sum = (rejected_logprobs * ~padding_mask[rejected_mb_inds]).sum(1) + # rejected_ref_logprobs_sum = (rejected_ref_logprobs * ~padding_mask[rejected_mb_inds]).sum( + # 1 + # ) + + pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum + + logits = pi_logratios - ref_logratios + + if self.loss_type == "sigmoid": + losses = -F.logsigmoid(self.beta * logits) + elif self.loss_type == "ipo": + 
losses = (logits - 1 / (2 * self.beta)) ** 2 + else: + raise NotImplementedError(f"invalid loss type {self.loss_type}") + + chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum).detach() + rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum).detach() + + loss = losses.mean() accelerator.backward(loss) optimizer.step() optimizer.zero_grad() + with torch.no_grad(): - pg_clipfrac = masked_mean( - (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss.detach() + chosen_reward_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + chosen_rewards.detach() + ) + chosen_logprobs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + chosen_logprobs_sum.detach() + ) + chosen_ref_logprobs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + chosen_ref_logprobs_sum.detach() ) - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - pg_clipfrac + rejected_reward_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + rejected_rewards.detach() ) - pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - vf_clipfrac + rejected_logprobs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + rejected_logprobs_sum.detach() ) - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + # entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = -logprobs.sum( + # 1 + # ).mean() gradient_accumulation_idx += 1 minibatch_idx += 1 # del everything and empty cache # fmt: off del ( - output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped, - vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, - pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, - mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, + logits, loss, + concat_output, concat_query_responses, + chosen_logits, rejected_logits, + chosen_logprobs, rejected_logprobs, + chosen_responses, rejected_responses, + chosen_all_logprobs, rejected_all_logprobs, + concat_ref_output, + chosen_ref_logits, rejected_ref_logits, + chosen_ref_logprobs, rejected_ref_logprobs, + chosen_ref_all_logprobs, rejected_ref_all_logprobs, ) # fmt: on torch.cuda.empty_cache() with torch.no_grad(): - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - rlhf_reward = mean_non_score_reward + scores.mean() eps = int(self.state.episode / (time.time() - start_time)) metrics = {} metrics["eps"] = eps - metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() - metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() - metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() - 
metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item() metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item() - metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item() - metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item() - metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item() - metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item() - metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item() - metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item() + metrics["loss/policy_avg"] = self.accelerator.gather(loss_stats).mean().item() metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() metrics["lr"] = self.lr_scheduler.get_last_lr()[0] metrics["episode"] = self.state.episode + + # dpo metrics + metrics["logps/chosen"] = self.accelerator.gather(chosen_logprobs_stats.mean()).mean().item() + metrics["logps/rejected"] = self.accelerator.gather(rejected_logprobs_stats.mean()).mean().item() + kl = ( + (chosen_logprobs_stats - chosen_ref_logprobs_stats) + + (rejected_logprobs_stats - rejected_ref_logprobs_stats) + ) / 2 + mean_kl = kl.mean() + self.stats["objective/kl"].append(self.accelerator.gather(mean_kl).mean().item()) + non_score_reward = (-self.beta * kl).sum(1) + mean_non_score_reward = non_score_reward.mean() + self.stats["objective/non_score_reward"].append( + self.accelerator.gather(mean_non_score_reward).mean().item() + ) + rlhf_reward = scores + non_score_reward + self.stats["objective/rlhf_reward"].append(self.accelerator.gather(rlhf_reward).mean().item()) + logprobs_sum = (chosen_logprobs_stats + rejected_logprobs_stats) / 2 + mean_entropy = -logprobs_sum.mean() + self.stats["objective/entropy"].append(self.accelerator.gather(mean_entropy).mean().item()) + self.stats["objective/scores_margin"].append( + self.accelerator.gather(scores_margin.mean()).mean().item() + ) + self.stats["rewards/chosen"].append(self.accelerator.gather(chosen_reward_stats.mean()).mean().item()) + self.stats["rewards/rejected"].append( + self.accelerator.gather(rejected_reward_stats.mean()).mean().item() + ) + margin = chosen_reward_stats - rejected_reward_stats + self.stats["rewards/margins"].append(self.accelerator.gather(margin.mean()).mean().item()) + accuracy = margin > 0 + self.stats["rewards/accuracies"].append(self.accelerator.gather(accuracy.float().mean()).mean().item()) + self.stats["beta"].append(self.beta) + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log self.state.global_step += 1 self.log(metrics) @@ -584,15 +637,8 @@ def repeat_generator(): values, sequence_lengths, contain_eos_token, - sequence_lengths_p1, response_idxs, padding_mask, - padding_mask_p1, - rewards, - actual_start, - actual_end, - advantages, - returns, ) torch.cuda.empty_cache() @@ -689,13 +735,13 @@ def create_model_card( if hasattr(self.model.config, "unsloth_version"): tags.append("unsloth") - citation = textwrap.dedent("""\ - @article{mziegler2019fine-tuning, - title = {{Fine-Tuning Language Models from Human Preferences}}, - author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. 
Christiano and Geoffrey Irving}, - year = 2019, - eprint = {arXiv:1909.08593} - }""") + # citation = textwrap.dedent("""\ + # @article{mziegler2019fine-tuning, + # title = {{Fine-Tuning Language Models from Human Preferences}}, + # author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving}, + # year = 2019, + # eprint = {arXiv:1909.08593} + # }""") model_card = generate_model_card( base_model=base_model, @@ -704,10 +750,10 @@ def create_model_card( dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, - trainer_name="PPO", - trainer_citation=citation, - paper_title="Fine-Tuning Language Models from Human Preferences", - paper_id="1909.08593", + trainer_name="AsyncOnlineDPO", + # trainer_citation=citation, + # paper_title="Fine-Tuning Language Models from Human Preferences", + # paper_id="1909.08593", ) model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 96dc8fba24..22532d8ae2 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -46,15 +46,22 @@ is_torch_npu_available, is_torch_xpu_available, ) +from transformers.import_utils import _is_package_available from ..import_utils import is_unsloth_available from ..trainer.model_config import ModelConfig +_vllm_available = _is_package_available("vllm") + if is_peft_available(): from peft import LoraConfig, PeftConfig +def is_vllm_available(): + return _vllm_available + + class AdaptiveKLController: """ Adaptive KL controller described in the paper: From 3bc05d0ccf89708f94f4cd20152648f646e1a0a5 Mon Sep 17 00:00:00 2001 From: Michael Noukhovitch Date: Wed, 23 Oct 2024 17:28:18 -0400 Subject: [PATCH 3/9] async utils working --- trl/trainer/async_online_dpo_trainer.py | 845 ++++++++++++++++++++++++ trl/vllm_utils.py | 150 +++++ 2 files changed, 995 insertions(+) create mode 100644 trl/trainer/async_online_dpo_trainer.py create mode 100644 trl/vllm_utils.py diff --git a/trl/trainer/async_online_dpo_trainer.py b/trl/trainer/async_online_dpo_trainer.py new file mode 100644 index 0000000000..9202b2fa59 --- /dev/null +++ b/trl/trainer/async_online_dpo_trainer.py @@ -0,0 +1,845 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import gc +import logging +import math +import os +import queue +import threading +import time +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import broadcast, gather_object +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + BaseImageProcessor, + DataCollatorWithPadding, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainerCallback, + TrainerControl, + is_wandb_available, +) +from transformers.integrations import get_reporting_integration_callbacks +from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK +from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback + +from ..models.utils import unwrap_model_for_generation +from ..trainer.utils import ( + OnlineTrainerState, + batch_generation, + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + get_reward, + is_vllm_available, + prepare_deepspeed, + print_rich_table, + truncate_response, +) +from .async_online_dpo_config import AsyncOnlineDPOConfig +from .utils import generate_model_card + + +if is_wandb_available(): + import wandb + +if is_vllm_available(): + from vllm import LLM, SamplingParams + + from ..vllm_utils import vllm_single_gpu_patch + + +INVALID_LOGPROB = 1.0 +logger = logging.getLogger(__name__) + + +class AsyncOnlineDPOTrainer(Trainer): + _tag_names = ["trl", "online-dpo", "async"] + + def __init__( + self, + config: AsyncOnlineDPOConfig, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ], + policy: nn.Module, + ref_policy: nn.Module, + reward_model: nn.Module, + train_dataset: Dataset, + data_collator: Optional[DataCollatorWithPadding] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + # less commonly used + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + callbacks: Optional[List[TrainerCallback]] = None, + ) -> None: + if not is_vllm_available(): + raise ImportError("`vllm` library is required for AsyncOnlineDPOTrainer, please install vllm") + + if ref_policy is policy: + raise ValueError( + "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the " + "same as `policy`, you must mass a copy of it, or `None` if you use peft." + ) + + self.args = config + args = config + self.processing_class = processing_class + self.policy = policy + + self.policy.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + + self.ref_policy = ref_policy + self.reward_model = reward_model + self.train_dataset = train_dataset + self.train_dataset_len = len(train_dataset) + self.data_collator = data_collator + self.eval_dataset = eval_dataset + self.optimizer, self.lr_scheduler = optimizers + + ######### + # calculate various batch sizes + ######### + if args.total_episodes is None: # allow the users to define episodes in terms of epochs. 
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + self.accelerator = accelerator + args.world_size = accelerator.num_processes + args.local_batch_size = ( + args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches + ) + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div( + args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" + ) + args.local_mini_batch_size = exact_div( + args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" + ) + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.num_total_batches = math.ceil( + args.total_episodes / args.batch_size + ) # we may train for more than `total_episodes` + time_tensor = torch.tensor(int(time.time()), device=accelerator.device) + time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes + args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" + self.local_seed = args.seed + accelerator.process_index * 100003 # Prime + if args.num_sample_generations > 0: + self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) + self.local_dataloader_batch_size = exact_div( + args.local_batch_size, 2, "`local_batch_size` must be a multiple of 2" + ) # Online DPO logic: needed because Online DPO repeats the same prompt 2 times + + ######### + # setup model, optimizer, and others + ######### + for module in [policy, ref_policy, reward_model]: + disable_dropout_in_model(module) + if args.stop_token and args.stop_token == "eos": + args.stop_token_id = processing_class.eos_token_id + self.model = policy + self.create_optimizer_and_scheduler( + num_training_steps=args.num_total_batches + ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level + + ######### + ### trainer specifics + ######### + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + self.control = TrainerControl() + self.state = OnlineTrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], + ) + self.current_flos = 0 + self.hp_search_backend = None + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, 
exist_ok=True) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + ######### + ### setup dataloader + ######### + self.dataloader = DataLoader( + self.train_dataset, + batch_size=self.local_dataloader_batch_size, + shuffle=True, + collate_fn=DataCollatorWithPadding(self.processing_class), + drop_last=True, # needed; otherwise the last batch will be of ragged shape + ) + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) + torch.manual_seed(self.local_seed) # reset the local seed again + + self.eval_dataloader = DataLoader( + self.eval_dataset, + batch_size=args.per_device_eval_batch_size, + collate_fn=DataCollatorWithPadding(self.processing_class), + drop_last=True, + ) # no need to shuffle eval dataset + self.eval_dataloader = accelerator.prepare(self.eval_dataloader) + + if self.is_deepspeed_enabled: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + self.ref_policy = prepare_deepspeed( + self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + else: + self.ref_policy = self.ref_policy.to(self.accelerator.device) + self.reward_model = self.reward_model.to(self.accelerator.device) + + def get_train_dataloader(self) -> DataLoader: + return self.dataloader + + def get_eval_dataloader(self) -> DataLoader: + return self.eval_dataloader + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + backup_model = self.model + self.model = self.model.policy # save only the policy + + if self.is_deepspeed_enabled: + backup_deepspeed = self.deepspeed + self.deepspeed = self.model + + super().save_model(output_dir, _internal_call) + + self.model = backup_model + + if self.is_deepspeed_enabled: + self.deepspeed = backup_deepspeed + + def train(self): + args = self.args + accelerator = self.accelerator + optimizer = self.optimizer + model = self.model + ref_policy = self.ref_policy + reward_model = self.reward_model + processing_class = self.processing_class + dataloader = self.dataloader + device = accelerator.device + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + + accelerator.print("===training policy===") + start_time = time.time() + stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + loss_stats = torch.zeros(stats_shape, device=device) + chosen_reward_stats = torch.zeros(stats_shape, device=device) + chosen_logprobs_stats = torch.zeros(stats_shape, device=device) + chosen_ref_logprobs_stats = torch.zeros(stats_shape, device=device) + rejected_reward_stats = torch.zeros(stats_shape, device=device) + rejected_logprobs_stats = torch.zeros(stats_shape, device=device) + rejected_ref_logprobs_stats = torch.zeros(stats_shape, device=device) + model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 + self.state.max_steps = args.num_total_batches * args.num_mini_batches + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if 
args.logging_steps < 1: + self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model + self.model_wrapped = self.model + + if accelerator.is_main_process: + if args.fp16: + vllm_dtype = torch.float16 + elif args.bf16: + vllm_dtype = torch.bfloat16 + else: + vllm_dtype = torch.float32 + vllm_device = args.vllm_device or f"cuda:{accelerator.num_processes}" + response_ids_Q = queue.Queue(maxsize=1) + param_prompt_Q = queue.Queue(maxsize=1) + thread = threading.Thread( + target=vllm_generate, + args=( + args.sft_model_path, + vllm_device, + args.vllm_gpu_memory_utilization, + vllm_dtype, + response_ids_Q, + param_prompt_Q, + args.temperature, + args.response_length, + ), + ) + thread.start() + + data = next(iter_dataloader) + next_queries = data["input_ids"].to(device) + next_queries = next_queries.repeat(args.rloo_k, 1) + g_queries_list = gather_object(next_queries.tolist()) + if accelerator.is_main_process: + g_queries_list = [ + [inneritem for inneritem in item if inneritem != processing_class.pad_token_id] + for item in g_queries_list + ] # remove padding + param_prompt_Q.put((None, g_queries_list)) + + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + for update in range(1, args.num_total_batches + 1): + queries = next_queries + self.state.episode += 1 * args.batch_size + data = next(iter_dataloader) + vllm_responses = torch.zeros( + (args.batch_size * args.rloo_k, args.response_length), + device=accelerator.device, + dtype=torch.long, + ) + with torch.no_grad(): + next_queries = data["input_ids"].to(device) + next_queries = next_queries.repeat(args.rloo_k, 1) + + if self.args.sync: + queries = next_queries + + # with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + g_queries_list = gather_object(next_queries.tolist()) + if accelerator.is_main_process: + g_queries_list = [ + [inneritem for inneritem in item if inneritem != processing_class.pad_token_id] + for item in g_queries_list + ] # remove padding + + # send next queries to be generated + model_named_parameters = accelerator._get_named_parameters(model) + param_prompt_Q.put((model_named_parameters.items(), g_queries_list)) + + # get response for previous queries + g_response_ids = response_ids_Q.get() + + DUMMY_PAD_TOKEN = 0 # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out + g_padded_response_ids = [ + list(response) + [DUMMY_PAD_TOKEN] * (args.response_length - len(response)) + for response in g_response_ids + ] + g_padded_response_ids = torch.tensor(g_padded_response_ids, device=device) + vllm_responses[:] = g_padded_response_ids + + broadcast(vllm_responses, 0) + local_vllm_responses = vllm_responses[ + accelerator.local_process_index * queries.shape[0] : (accelerator.local_process_index + 1) + * queries.shape[0] + ] + + context_length = queries.shape[1] + query_responses = torch.cat((queries, local_vllm_responses), 1) + responses = [] + 
postprocessed_responses = [] + scores = [] + sequence_lengths = [] + values = [] + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, processing_class.pad_token_id, response + ) + + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 + _, score, _ = get_reward( + reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + sequence_lengths.append(sequence_length) + scores.append(score) + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + del score + torch.cuda.empty_cache() + gc.collect() + + # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id + # Completions not passing that filter will receive a lower score. + contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1) + if self.args.missing_eos_penalty is not None: + scores[~contain_eos_token] -= self.args.missing_eos_penalty + # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + + # 4. 
compute rewards + rlhf_reward = scores + num_examples = scores.size(0) // 2 + scores_reshaped = scores.reshape(2, num_examples).t() + + # Get the max scores and their local indices + chosen_scores, chosen_local_indices = torch.max(scores_reshaped, dim=1) + + # Get the min scores and their local indices + rejected_scores, rejected_local_indices = torch.min(scores_reshaped, dim=1) + scores_margin = chosen_scores - rejected_scores + + # Calculate the global indices + chosen_indices = chosen_local_indices * num_examples + torch.arange(num_examples, device=scores.device) + rejected_indices = rejected_local_indices * num_examples + torch.arange( + num_examples, device=scores.device + ) + torch.cuda.empty_cache() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.num_ppo_epochs): + b_inds = np.random.permutation(args.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + + ## chosen + chosen_mb_inds = chosen_indices[micro_batch_inds] + chosen_responses = responses[chosen_mb_inds] + + ## rejected + rejected_mb_inds = rejected_indices[micro_batch_inds] + rejected_responses = responses[rejected_mb_inds] + + concat_mb_inds = torch.cat((chosen_mb_inds, rejected_mb_inds), dim=0) + concat_query_responses = query_responses[concat_mb_inds] + num_examples = chosen_mb_inds.shape[0] + + # reference logprobs + with torch.no_grad(): + concat_ref_output = forward( + ref_policy, concat_query_responses, processing_class.pad_token_id + ) + chosen_ref_logits = concat_ref_output.logits[:num_examples] + rejected_ref_logits = concat_ref_output.logits[num_examples:] + + chosen_ref_logits = chosen_ref_logits[:, context_length - 1 : -1] + chosen_ref_logits /= args.temperature + 1e-7 + chosen_ref_all_logprobs = F.log_softmax(chosen_ref_logits, dim=-1) + chosen_ref_logprobs = torch.gather( + chosen_ref_all_logprobs, 2, chosen_responses.unsqueeze(-1) + ).squeeze(-1) + chosen_ref_logprobs = torch.masked_fill( + chosen_ref_logprobs, padding_mask[chosen_mb_inds], INVALID_LOGPROB + ) + chosen_ref_logprobs_sum = (chosen_ref_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) + + rejected_ref_logits = rejected_ref_logits[:, context_length - 1 : -1] + rejected_ref_logits /= args.temperature + 1e-7 + rejected_ref_all_logprobs = F.log_softmax(rejected_ref_logits, dim=-1) + rejected_ref_logprobs = torch.gather( + rejected_ref_all_logprobs, 2, rejected_responses.unsqueeze(-1) + ).squeeze(-1) + rejected_ref_logprobs = torch.masked_fill( + rejected_ref_logprobs, padding_mask[rejected_mb_inds], INVALID_LOGPROB + ) + rejected_ref_logprobs_sum = (rejected_ref_logprobs * ~padding_mask[rejected_mb_inds]).sum( + 1 + ) + + ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum + + with accelerator.accumulate(model): + concat_output = forward(model, concat_query_responses, processing_class.pad_token_id) + chosen_logits = concat_output.logits[:num_examples] + rejected_logits = concat_output.logits[num_examples:] + + # chosen + chosen_logits = chosen_logits[:, context_length - 1 : -1] + chosen_logits /= 
args.temperature + 1e-7 + chosen_all_logprobs = F.log_softmax(chosen_logits, dim=-1) + chosen_logprobs = torch.gather( + chosen_all_logprobs, 2, chosen_responses.unsqueeze(-1) + ).squeeze(-1) + chosen_logprobs = torch.masked_fill( + chosen_logprobs, padding_mask[chosen_mb_inds], INVALID_LOGPROB + ) + chosen_logprobs_sum = (chosen_logprobs * ~padding_mask[chosen_mb_inds]).sum(1) + + # rejected + rejected_logits = rejected_logits[:, context_length - 1 : -1] + rejected_logits /= args.temperature + 1e-7 + rejected_all_logprobs = F.log_softmax(rejected_logits, dim=-1) + rejected_logprobs = torch.gather( + rejected_all_logprobs, 2, rejected_responses.unsqueeze(-1) + ).squeeze(-1) + rejected_logprobs = torch.masked_fill( + rejected_logprobs, padding_mask[rejected_mb_inds], INVALID_LOGPROB + ) + rejected_logprobs_sum = (rejected_logprobs * ~padding_mask[rejected_mb_inds]).sum(1) + + pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum + + logits = pi_logratios - ref_logratios + + if self.loss_type == "sigmoid": + losses = -F.logsigmoid(self.beta * logits) + elif self.loss_type == "ipo": + losses = (logits - 1 / (2 * self.beta)) ** 2 + else: + raise NotImplementedError(f"invalid loss type {self.loss_type}") + + chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum).detach() + rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum).detach() + + loss = losses.mean() + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + + with torch.no_grad(): + loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss.detach() + chosen_reward_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + chosen_rewards.mean().detach() + ) + chosen_logprobs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + chosen_logprobs_sum.mean().detach() + ) + chosen_ref_logprobs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + chosen_ref_logprobs_sum.mean().detach() + ) + rejected_reward_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + rejected_rewards.mean().detach() + ) + rejected_logprobs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + rejected_logprobs_sum.mean().detach() + ) + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # del everything and empty cache + # fmt: off + del ( + logits, loss, + concat_output, concat_query_responses, + chosen_logits, rejected_logits, + chosen_logprobs, rejected_logprobs, + chosen_responses, rejected_responses, + chosen_all_logprobs, rejected_all_logprobs, + concat_ref_output, + chosen_ref_logits, rejected_ref_logits, + chosen_ref_logprobs, rejected_ref_logprobs, + chosen_ref_all_logprobs, rejected_ref_all_logprobs, + ) + # fmt: on + torch.cuda.empty_cache() + with torch.no_grad(): + eps = int(self.state.episode / (time.time() - start_time)) + metrics = {} + metrics["eps"] = eps + metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather(loss_stats).mean().item() + metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() + metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + metrics["episode"] = self.state.episode + + # dpo metrics + metrics["logps/chosen"] = self.accelerator.gather(chosen_logprobs_stats.mean()).mean().item() + metrics["logps/rejected"] = self.accelerator.gather(rejected_logprobs_stats.mean()).mean().item() + kl = ( + (chosen_logprobs_stats - chosen_ref_logprobs_stats) + + 
(rejected_logprobs_stats - rejected_ref_logprobs_stats) + ) / 2 + mean_kl = kl.mean() + self.stats["objective/kl"].append(self.accelerator.gather(mean_kl).mean().item()) + non_score_reward = (-self.beta * kl).sum(1) + mean_non_score_reward = non_score_reward.mean() + self.stats["objective/non_score_reward"].append( + self.accelerator.gather(mean_non_score_reward).mean().item() + ) + rlhf_reward = scores + non_score_reward + self.stats["objective/rlhf_reward"].append(self.accelerator.gather(rlhf_reward).mean().item()) + logprobs_sum = (chosen_logprobs_stats + rejected_logprobs_stats) / 2 + mean_entropy = -logprobs_sum.mean() + self.stats["objective/entropy"].append(self.accelerator.gather(mean_entropy).mean().item()) + self.stats["objective/scores_margin"].append( + self.accelerator.gather(scores_margin.mean()).mean().item() + ) + self.stats["rewards/chosen"].append(self.accelerator.gather(chosen_reward_stats.mean()).mean().item()) + self.stats["rewards/rejected"].append( + self.accelerator.gather(rejected_reward_stats.mean()).mean().item() + ) + margin = chosen_reward_stats - rejected_reward_stats + self.stats["rewards/margins"].append(self.accelerator.gather(margin.mean()).mean().item()) + accuracy = margin > 0 + self.stats["rewards/accuracies"].append(self.accelerator.gather(accuracy.float().mean()).mean().item()) + self.stats["beta"].append(self.beta) + + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log + self.state.global_step += 1 + self.log(metrics) + + self.lr_scheduler.step() + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward + torch.cuda.empty_cache() + gc.collect() + + if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + self.generate_completions(sampling=True) + torch.cuda.empty_cache() + del ( + query_responses, + responses, + postprocessed_responses, + values, + sequence_lengths, + contain_eos_token, + response_idxs, + padding_mask, + ) + torch.cuda.empty_cache() + + # HF trainer specifics + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def generate_completions(self, sampling: bool = False): + args = self.args + processing_class = self.processing_class + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + table = defaultdict(list) + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + query_response, _ = batch_generation( + unwrapped_model.policy, + query, + query.shape[0], + processing_class.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, processing_class.pad_token_id, response + ) + table["query"].extend( + 
gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) + ) + table["model response"].extend( + gather_object(processing_class.batch_decode(postprocessed_response)) + ) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + + if sampling: + break + df = pd.DataFrame(table) + + if self.accelerator.is_main_process: + print_rich_table(df.iloc[0 : 0 + 5]) + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or [] + if isinstance(tags, str): + tags = [tags] + + if hasattr(self.model.config, "unsloth_version"): + tags.append("unsloth") + + # citation = textwrap.dedent("""\ + # @article{mziegler2019fine-tuning, + # title = {{Fine-Tuning Language Models from Human Preferences}}, + # author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. 
Christiano and Geoffrey Irving}, + # year = 2019, + # eprint = {arXiv:1909.08593} + # }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="AsyncOnlineDPO", + # trainer_citation=citation, + # paper_title="Fine-Tuning Language Models from Human Preferences", + # paper_id="1909.08593", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) + + +def vllm_generate( + model_name_or_path: str, + vllm_device: str, + vllm_gpu_memory_utilization: float, + vllm_dtype: str, + response_ids_Q: queue.Queue, + param_prompt_Q: queue.Queue, + temperature: float, + response_length: int, +): + vllm_single_gpu_patch() + generation_config = SamplingParams( + temperature=(temperature + 1e-7), + top_p=1.0, + max_tokens=response_length, + include_stop_str_in_output=True, + ) + + llm = LLM( + model=model_name_or_path, + revision="main", + tokenizer_revision="main", + tensor_parallel_size=1, + device=vllm_device, + dtype=vllm_dtype, + gpu_memory_utilization=vllm_gpu_memory_utilization, + ) + logger.info(f"🔥🔥🔥 vllm loaded in {vllm_dtype}") + llmp = llm.llm_engine.model_executor.driver_worker.model_runner.model + i = 0 + while True: + i += 1 + model_named_parameters, g_queries_list = param_prompt_Q.get() + if model_named_parameters is None and g_queries_list is None: + logger.info( + "vllm thread received model params and queries = None, this indicates the end of training so exiting vllm thread" + ) + break + + if i > 2: + llmp.load_weights(model_named_parameters) + + outputs = llm.generate(prompt_token_ids=g_queries_list, sampling_params=generation_config, use_tqdm=False) + response_token_ids = [] + for output in outputs: + response_token_ids.append(output.outputs[0].token_ids) + + response_ids_Q.put(response_token_ids) diff --git a/trl/vllm_utils.py b/trl/vllm_utils.py new file mode 100644 index 0000000000..882e88546d --- /dev/null +++ b/trl/vllm_utils.py @@ -0,0 +1,150 @@ +# Taken and modified from https://github.com/allenai/openinstruct +# Taken and modified from https://github.com/huggingface/trl +# Copyright 2024 The AllenAI Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file basically allows us to place vLLM's driver worker in a specified +GPU. For example. you can try the following. 
+```python +from transformers import AutoTokenizer +from vllm import SamplingParams +from open_instruct.vllm_utils import SingleGPULLM +tok = AutoTokenizer.from_pretrained("facebook/opt-125m") +tok.chat_template = ( + "{% for message in messages %}" + "{{'\n\n' if not loop.first else ''}}" + "{{message['role']|capitalize + ': ' +message['content']}}" + "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" + "{% endfor %}" +) +prompts = [ + {"role": "user", "content": "Compose a speech about the need for more affordable dental care."}, +] +prompt_ids = tok.apply_chat_template(prompts, add_generation_prompt=True) +sampling_params = SamplingParams(temperature=0.001, top_p=1.0, max_tokens=1024, include_stop_str_in_output=True) +llm = SingleGPULLM(model="facebook/opt-125m", tensor_parallel_size=1, device="cuda:1") +llmp = llm.llm_engine.model_executor.driver_worker.model_runner.model +print(f"🔥🔥🔥 vllm lives in {llmp.lm_head.weight.device}") +print("prepare to generate") +outputs = llm.generate(prompt_token_ids=[prompt_ids], sampling_params=sampling_params) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +``` +""" + +from typing import List, Optional + +import torch +import vllm +from vllm.distributed.parallel_state import ( + GroupCoordinator, + get_world_group, + init_model_parallel_group, +) +from vllm.executor.gpu_executor import GPUExecutor + + +def custom_initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + world_size: int = 1 # SingleGPULLM logic: only use a single GPU + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + # Build the tensor model-parallel groups. 
+ num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + # global _TP + assert vllm.distributed.parallel_state._TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + vllm.distributed.parallel_state._TP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, use_message_queue_broadcaster=True + ) + + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + # global _PP + assert vllm.distributed.parallel_state._PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + vllm.distributed.parallel_state._PP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False + ) + + +def init_world_group(ranks: List[int], local_rank: int, backend: str) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[[0]], # SingleGPULLM logic: only use a single GPU + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=False, + use_custom_allreduce=False, + use_tpu_communicator=False, + ) + + +def _init_executor(self) -> None: + """Initialize the worker and load the model.""" + assert self.parallel_config.world_size == 1, "GPUExecutor only supports single GPU." + + self.driver_worker = self._create_worker(local_rank=self.device_config.device.index) + self.driver_worker.init_device() + self.driver_worker.load_model() + + +# monkey patch the function +def vllm_single_gpu_patch(): + vllm.distributed.parallel_state.init_world_group = init_world_group + vllm.distributed.parallel_state.initialize_model_parallel = custom_initialize_model_parallel + GPUExecutor._init_executor = _init_executor From 44a2de5f90545845db192f987d7ba82e7250f5ae Mon Sep 17 00:00:00 2001 From: Michael Noukhovitch Date: Wed, 23 Oct 2024 17:54:02 -0400 Subject: [PATCH 4/9] script and imports --- examples/scripts/dpo_online_async.py | 112 +++++++++++++++++++++++++ trl/__init__.py | 2 + trl/trainer/async_online_dpo_config.py | 2 +- trl/trainer/utils.py | 4 +- 4 files changed, 117 insertions(+), 3 deletions(-) create mode 100644 examples/scripts/dpo_online_async.py diff --git a/examples/scripts/dpo_online_async.py b/examples/scripts/dpo_online_async.py new file mode 100644 index 0000000000..5948b3c8c8 --- /dev/null +++ b/examples/scripts/dpo_online_async.py @@ -0,0 +1,112 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage for a 4 GPU setup: + +accelerate launch --num_processes 3 examples/scripts/dpo_online_async.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-tldr-online-dpo-async \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 16 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 +""" + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig + +from trl import ( + AsyncOnlineDPOConfig, + AsyncOnlineDPOTrainer, + LogCompletionsCallback, + ModelConfig, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, AsyncOnlineDPOConfig, ModelConfig)) + script_args, training_args, model_config = parser.parse_args_and_config() + script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + + training_args.sft_model_path = model_config.model_name_or_path + + torch_dtype = ( + model_config.torch_dtype + if model_config.torch_dtype in ["auto", None] + else getattr(torch, model_config.torch_dtype) + ) + quantization_config = get_quantization_config(model_config) + model_kwargs = dict( + revision=model_config.model_revision, + attn_implementation=model_config.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + ) + + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, + num_labels=1, + trust_remote_code=model_config.trust_remote_code, + **model_kwargs, + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, + padding_side="left", + trust_remote_code=model_config.trust_remote_code, + **model_kwargs, + ) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + dataset = load_dataset(script_args.dataset_name) + + trainer = AsyncOnlineDPOTrainer( + model=model, + reward_model=reward_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], + processing_class=tokenizer, + peft_config=get_peft_config(model_config), + ) + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/trl/__init__.py b/trl/__init__.py index 405991e652..afc9d1f0c4 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -50,6 +50,8 @@ "trainer": [ "AlignPropConfig", "AlignPropTrainer", + "AsyncOnlineDPOConfig", + "AsyncOnlineDPOTrainer", "BaseJudge", 
"BasePairwiseJudge", "BaseRankJudge", diff --git a/trl/trainer/async_online_dpo_config.py b/trl/trainer/async_online_dpo_config.py index 2c73a9e194..ccf4d435ee 100644 --- a/trl/trainer/async_online_dpo_config.py +++ b/trl/trainer/async_online_dpo_config.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from dataclasses import dataclass, field from typing import List, Literal -import os from ..trainer.utils import OnPolicyConfig diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 22532d8ae2..adf093db13 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1008,7 +1008,7 @@ class OnPolicyConfig(TrainingArguments): Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive value. - sft_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): + sft_model_path (`str`, *optional*, defaults to None): Path to the SFT model. world_size (`Optional[int]`, *optional*, defaults to `None`): Number of processes (GPUs) to use for the training. @@ -1039,7 +1039,7 @@ class OnPolicyConfig(TrainingArguments): stop_token_id: Optional[int] = None temperature: float = 0.7 missing_eos_penalty: Optional[float] = None - sft_model_path: str = "EleutherAI/pythia-160m" + sft_model_path: Optional[str] = None world_size: Optional[int] = None num_total_batches: Optional[int] = None micro_batch_size: Optional[int] = None From f8b1b52eb588a55d38b96e1e284468fb9aca125f Mon Sep 17 00:00:00 2001 From: Michael Noukhovitch Date: Thu, 24 Oct 2024 16:23:06 -0400 Subject: [PATCH 5/9] is it working --- trl/trainer/async_online_dpo_config.py | 9 ++-- trl/trainer/sync_online_dpo_trainer.py | 69 +++++++++++--------------- 2 files changed, 32 insertions(+), 46 deletions(-) diff --git a/trl/trainer/async_online_dpo_config.py b/trl/trainer/async_online_dpo_config.py index ccf4d435ee..2804d68929 100644 --- a/trl/trainer/async_online_dpo_config.py +++ b/trl/trainer/async_online_dpo_config.py @@ -36,13 +36,12 @@ class AsyncOnlineDPOConfig(OnPolicyConfig): learning_rate (`float`, *optional*, defaults to `5e-7`): Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`]. - reward_model_path (`Optional[str]`, *optional*, defaults to `None`): + reward_model_path (`str`, defaults to `None`): Path to the reward model. - beta (`float` or `list[float]`, *optional*, defaults to `0.1`): + beta (`float`, defaults to `0.1`): Parameter controlling the deviation from the reference model. Higher β means less deviation from the reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in - the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is - selected for each new epoch and the last β is used for the rest of the epochs. + the [paper](https://huggingface.co/papers/2310.12036). loss_type (`str`, *optional*, defaults to `"sigmoid"`): Type of loss to use. 
Possible values are: @@ -60,7 +59,7 @@ class AsyncOnlineDPOConfig(OnPolicyConfig): num_ppo_epochs: int = 1 learning_rate: float = 5e-7 reward_model_path: str = None - beta: List[float] = field(default_factory=lambda: [0.1]) + beta: float = 0.1 loss_type: Literal["sigmoid", "ipo"] = "sigmoid" vllm_device: str | None = None diff --git a/trl/trainer/sync_online_dpo_trainer.py b/trl/trainer/sync_online_dpo_trainer.py index e279d0b028..27997c35d5 100644 --- a/trl/trainer/sync_online_dpo_trainer.py +++ b/trl/trainer/sync_online_dpo_trainer.py @@ -110,6 +110,7 @@ def __init__( self.data_collator = data_collator self.eval_dataset = eval_dataset self.optimizer, self.lr_scheduler = optimizers + self.beta = config.beta ######### # calculate various batch sizes @@ -130,10 +131,6 @@ def __init__( args.local_mini_batch_size = exact_div( args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" ) - if args.whiten_rewards: - assert ( - args.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" # `per_rank_rollout_batch_size` is our `args.local_batch_size` # `per_rank_minibatch_size` is our `args.local_mini_batch_size` args.num_total_batches = math.ceil( @@ -145,9 +142,10 @@ def __init__( self.local_seed = args.seed + accelerator.process_index * 100003 # Prime if args.num_sample_generations > 0: self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) - self.local_dataloader_batch_size = exact_div( - args.local_batch_size, 2, "`local_batch_size` must be a multiple of 2" - ) # Online DPO logic: needed because Online DPO repeats the same prompt 2 times + + # To be similar to online_dpo_trainer.py, our batch size refers to the number of prompts + # This is unlike rloo_trainer.py where batch size is prompts * rloo_k (or in our case 2) + self.local_dataloader_batch_size = args.local_batch_size ######### # setup model, optimizer, and others @@ -236,7 +234,7 @@ def get_eval_dataloader(self) -> DataLoader: def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): backup_model = self.model - self.model = self.model.policy # save only the policy + self.model = self.model # save only the policy if self.is_deepspeed_enabled: backup_deepspeed = self.deepspeed @@ -333,7 +331,7 @@ def repeat_generator(): values = [] with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: query_responses, logitss = batch_generation( - unwrapped_model.policy, + unwrapped_model, queries, args.local_rollout_forward_batch_size, processing_class.pad_token_id, @@ -402,7 +400,6 @@ def repeat_generator(): ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) # 4. 
compute rewards - rlhf_reward = scores num_examples = scores.size(0) // 2 scores_reshaped = scores.reshape(2, num_examples).t() @@ -515,9 +512,9 @@ def repeat_generator(): logits = pi_logratios - ref_logratios - if self.loss_type == "sigmoid": + if self.args.loss_type == "sigmoid": losses = -F.logsigmoid(self.beta * logits) - elif self.loss_type == "ipo": + elif self.args.loss_type == "ipo": losses = (logits - 1 / (2 * self.beta)) ** 2 else: raise NotImplementedError(f"invalid loss type {self.loss_type}") @@ -533,23 +530,20 @@ def repeat_generator(): with torch.no_grad(): loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = loss.detach() chosen_reward_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - chosen_rewards.detach() + chosen_rewards.mean().detach() ) chosen_logprobs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - chosen_logprobs_sum.detach() + chosen_logprobs_sum.mean().detach() ) chosen_ref_logprobs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - chosen_ref_logprobs_sum.detach() + chosen_ref_logprobs_sum.mean().detach() ) rejected_reward_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - rejected_rewards.detach() + rejected_rewards.mean().detach() ) rejected_logprobs_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - rejected_logprobs_sum.detach() + rejected_logprobs_sum.mean().detach() ) - # entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = -logprobs.sum( - # 1 - # ).mean() gradient_accumulation_idx += 1 minibatch_idx += 1 # del everything and empty cache @@ -588,29 +582,22 @@ def repeat_generator(): + (rejected_logprobs_stats - rejected_ref_logprobs_stats) ) / 2 mean_kl = kl.mean() - self.stats["objective/kl"].append(self.accelerator.gather(mean_kl).mean().item()) - non_score_reward = (-self.beta * kl).sum(1) - mean_non_score_reward = non_score_reward.mean() - self.stats["objective/non_score_reward"].append( - self.accelerator.gather(mean_non_score_reward).mean().item() - ) - rlhf_reward = scores + non_score_reward - self.stats["objective/rlhf_reward"].append(self.accelerator.gather(rlhf_reward).mean().item()) + metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() + mean_non_score_reward = (-self.beta * kl).mean() + metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() + mean_rlhf_reward = scores.mean() + mean_non_score_reward + metrics["objective/rlhf_reward"] = self.accelerator.gather(mean_rlhf_reward).mean().item() logprobs_sum = (chosen_logprobs_stats + rejected_logprobs_stats) / 2 mean_entropy = -logprobs_sum.mean() - self.stats["objective/entropy"].append(self.accelerator.gather(mean_entropy).mean().item()) - self.stats["objective/scores_margin"].append( - self.accelerator.gather(scores_margin.mean()).mean().item() - ) - self.stats["rewards/chosen"].append(self.accelerator.gather(chosen_reward_stats.mean()).mean().item()) - self.stats["rewards/rejected"].append( - self.accelerator.gather(rejected_reward_stats.mean()).mean().item() - ) + metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() + metrics["objective/scores_margin"] = self.accelerator.gather(scores_margin.mean()).mean().item() + metrics["rewards/chosen"] = self.accelerator.gather(chosen_reward_stats.mean()).mean().item() + metrics["rewards/rejected"] = self.accelerator.gather(rejected_reward_stats.mean()).mean().item() margin = chosen_reward_stats - rejected_reward_stats - 
self.stats["rewards/margins"].append(self.accelerator.gather(margin.mean()).mean().item()) + metrics["rewards/margins"] = self.accelerator.gather(margin.mean()).mean().item() accuracy = margin > 0 - self.stats["rewards/accuracies"].append(self.accelerator.gather(accuracy.float().mean()).mean().item()) - self.stats["beta"].append(self.beta) + metrics["rewards/accuracies"] = self.accelerator.gather(accuracy.float().mean()).mean().item() + metrics["beta"] = self.beta self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log self.state.global_step += 1 @@ -621,7 +608,7 @@ def repeat_generator(): if self.control.should_save: self._save_checkpoint(model, trial=None, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, mean_rlhf_reward torch.cuda.empty_cache() gc.collect() @@ -666,7 +653,7 @@ def generate_completions(self, sampling: bool = False): with torch.no_grad(): context_length = query.shape[1] query_response, _ = batch_generation( - unwrapped_model.policy, + unwrapped_model, query, query.shape[0], processing_class.pad_token_id, From afdca5c476431f9a696698c824f75160f8f3e711 Mon Sep 17 00:00:00 2001 From: Michael Noukhovitch Date: Thu, 24 Oct 2024 16:23:56 -0400 Subject: [PATCH 6/9] sync running --- examples/scripts/dpo_online_async.py | 67 ++++++++++++++++++++-------- setup.py | 1 + trl/__init__.py | 2 + trl/trainer/__init__.py | 4 ++ trl/trainer/utils.py | 2 +- 5 files changed, 57 insertions(+), 19 deletions(-) diff --git a/examples/scripts/dpo_online_async.py b/examples/scripts/dpo_online_async.py index 5948b3c8c8..86beba8792 100644 --- a/examples/scripts/dpo_online_async.py +++ b/examples/scripts/dpo_online_async.py @@ -19,7 +19,7 @@ --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ --dataset_name trl-lib/tldr \ --learning_rate 5.0e-7 \ - --output_dir pythia-1b-tldr-online-dpo-async \ + --output_dir pythia-1b-tldr-online-dpo \ --per_device_train_batch_size 8 \ --gradient_accumulation_steps 16 \ --warmup_ratio 0.1 \ @@ -27,6 +27,7 @@ """ import torch +from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig @@ -41,7 +42,7 @@ get_peft_config, get_quantization_config, ) -from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE +from trl.trainer.sync_online_dpo_trainer import SyncOnlineDPOTrainer if __name__ == "__main__": @@ -49,6 +50,7 @@ script_args, training_args, model_config = parser.parse_args_and_config() script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + # make sure using same base model training_args.sft_model_path = model_config.model_name_or_path torch_dtype = ( @@ -70,6 +72,10 @@ model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + ) + reward_model = AutoModelForSequenceClassification.from_pretrained( training_args.reward_model_path, num_labels=1, @@ -83,27 +89,52 @@ trust_remote_code=model_config.trust_remote_code, **model_kwargs, ) - if tokenizer.chat_template is None: - tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE if tokenizer.pad_token_id is None: - tokenizer.pad_token = tokenizer.eos_token + 
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-    dataset = load_dataset(script_args.dataset_name)
+    ################
+    # Dataset
+    ################
+    train_dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split).select(range(1000))
+    eval_dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_test_split)
-    trainer = AsyncOnlineDPOTrainer(
-        model=model,
-        reward_model=reward_model,
-        args=training_args,
-        train_dataset=dataset[script_args.dataset_train_split],
-        eval_dataset=dataset[script_args.dataset_test_split],
+    def prepare_dataset(dataset, tokenizer):
+        """pre-tokenize the dataset before training; only collate during training"""
+
+        def tokenize(element):
+            outputs = tokenizer(
+                element["prompt"],
+                padding=False,
+            )
+            return {"input_ids": outputs["input_ids"]}
+
+        return dataset.map(
+            tokenize,
+            batched=True,
+            remove_columns=dataset.column_names,
+            num_proc=training_args.dataset_num_proc,
+        )
+
+    # Compute that only on the main process for faster data processing.
+    # see: https://github.com/huggingface/trl/pull/1255
+    with PartialState().local_main_process_first():
+        train_dataset = prepare_dataset(train_dataset, tokenizer)
+        eval_dataset = prepare_dataset(eval_dataset, tokenizer)
+
+    trainer = SyncOnlineDPOTrainer(
+        config=training_args,
         processing_class=tokenizer,
-        peft_config=get_peft_config(model_config),
-    )
-    generation_config = GenerationConfig(
-        max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
+        policy=model,
+        ref_policy=ref_model,
+        reward_model=reward_model,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
     )
-    completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
-    trainer.add_callback(completions_callback)
+    # generation_config = GenerationConfig(
+    #     max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
+    # )
+    # completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
+    # trainer.add_callback(completions_callback)
     trainer.train()
     # Save and push to hub
diff --git a/setup.py b/setup.py
index ffeb402edf..0693efb8df 100644
--- a/setup.py
+++ b/setup.py
@@ -103,6 +103,7 @@
     "deepspeed": ["deepspeed>=0.14.4"],
     "quantization": ["bitsandbytes<=0.41.1"],
     "llm_judge": ["openai>=1.23.2", "llm-blender>=0.0.2"],
+    "async": ["vllm >= 0.6.0"],
 }
 EXTRAS["dev"] = []
 for reqs in EXTRAS.values():
diff --git a/trl/__init__.py b/trl/__init__.py
index afc9d1f0c4..b21b3b5ef4 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -148,6 +148,8 @@
 from .trainer import (
     AlignPropConfig,
     AlignPropTrainer,
+    AsyncOnlineDPOConfig,
+    AsyncOnlineDPOTrainer,
     BaseJudge,
     BasePairwiseJudge,
     BaseRankJudge,
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index f0eba412c6..c0f0da276d 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -21,6 +21,8 @@
 _import_structure = {
     "alignprop_config": ["AlignPropConfig"],
     "alignprop_trainer": ["AlignPropTrainer"],
+    "async_online_dpo_config": ["AsyncOnlineDPOConfig"],
+    "async_online_dpo_trainer": ["AsyncOnlineDPOTrainer"],
     "base": ["BaseTrainer"],
     "bco_config": ["BCOConfig"],
     "bco_trainer": ["BCOTrainer"],
@@ -85,6 +87,8 @@
 if TYPE_CHECKING:
     from .alignprop_config import AlignPropConfig
     from .alignprop_trainer import AlignPropTrainer
+    from .async_online_dpo_config import AsyncOnlineDPOConfig
+    from .async_online_dpo_trainer import AsyncOnlineDPOTrainer
     from .base import BaseTrainer
     from .bco_config import BCOConfig
     from .bco_trainer import BCOTrainer
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index adf093db13..dd35b76895 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -46,7 +46,7 @@
     is_torch_npu_available,
     is_torch_xpu_available,
 )
-from transformers.import_utils import _is_package_available
+from transformers.utils.import_utils import _is_package_available
 from ..import_utils import is_unsloth_available
 from ..trainer.model_config import ModelConfig

From 97abfab24e6c3271bb4d57075b42081335d54fd4 Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Thu, 24 Oct 2024 16:37:55 -0400
Subject: [PATCH 7/9] async running
---
 examples/scripts/dpo_online_async.py    |  2 +-
 trl/trainer/async_online_dpo_trainer.py | 64 ++++++++++---------------
 2 files changed, 26 insertions(+), 40 deletions(-)

diff --git a/examples/scripts/dpo_online_async.py b/examples/scripts/dpo_online_async.py
index 86beba8792..f259dcd744 100644
--- a/examples/scripts/dpo_online_async.py
+++ b/examples/scripts/dpo_online_async.py
@@ -121,7 +121,7 @@ def tokenize(element):
     train_dataset = prepare_dataset(train_dataset, tokenizer)
     eval_dataset = prepare_dataset(eval_dataset, tokenizer)
-    trainer = SyncOnlineDPOTrainer(
+    trainer = AsyncOnlineDPOTrainer(
         config=training_args,
         processing_class=tokenizer,
         policy=model,
diff --git a/trl/trainer/async_online_dpo_trainer.py b/trl/trainer/async_online_dpo_trainer.py
index 9202b2fa59..ff33cad95b 100644
--- a/trl/trainer/async_online_dpo_trainer.py
+++ b/trl/trainer/async_online_dpo_trainer.py
@@ -109,6 +109,7 @@ def __init__(
         args = config
         self.processing_class = processing_class
         self.policy = policy
+        self.beta = config.beta
         self.policy.generation_config.eos_token_id = (
             None  # disable `pad_token_id` and `eos_token_id` because we just want to
@@ -142,10 +143,6 @@ def __init__(
         args.local_mini_batch_size = exact_div(
             args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
         )
-        if args.whiten_rewards:
-            assert (
-                args.local_mini_batch_size >= 8
-            ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
         # `per_rank_rollout_batch_size` is our `args.local_batch_size`
         # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
         args.num_total_batches = math.ceil(
@@ -157,9 +154,9 @@ def __init__(
         self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
         if args.num_sample_generations > 0:
             self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
-        self.local_dataloader_batch_size = exact_div(
-            args.local_batch_size, 2, "`local_batch_size` must be a multiple of 2"
-        )  # Online DPO logic: needed because Online DPO repeats the same prompt 2 times
+        # To be similar to online_dpo_trainer.py, our batch size refers to the number of prompts
+        # This is unlike rloo_trainer.py where batch size is prompts * rloo_k (or in our case 2)
+        self.local_dataloader_batch_size = args.local_batch_size
         #########
         # setup model, optimizer, and others
@@ -248,7 +245,7 @@ def get_eval_dataloader(self) -> DataLoader:
     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         backup_model = self.model
-        self.model = self.model.policy  # save only the policy
+        self.model = self.model  # save only the policy
         if self.is_deepspeed_enabled:
             backup_deepspeed = self.deepspeed
@@ -344,7 +341,7 @@ def repeat_generator():
                 data = next(iter_dataloader)
                 next_queries = data["input_ids"].to(device)
-                next_queries = next_queries.repeat(args.rloo_k, 1)
+                next_queries = next_queries.repeat(2, 1)
                 g_queries_list = gather_object(next_queries.tolist())
                 if accelerator.is_main_process:
                     g_queries_list = [
@@ -359,16 +356,13 @@ def repeat_generator():
             self.state.episode += 1 * args.batch_size
             data = next(iter_dataloader)
             vllm_responses = torch.zeros(
-                (args.batch_size * args.rloo_k, args.response_length),
+                (args.batch_size * 2, args.response_length),
                 device=accelerator.device,
                 dtype=torch.long,
             )
             with torch.no_grad():
                 next_queries = data["input_ids"].to(device)
-                next_queries = next_queries.repeat(args.rloo_k, 1)
-
-                if self.args.sync:
-                    queries = next_queries
+                next_queries = next_queries.repeat(2, 1)
                 # with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                 g_queries_list = gather_object(next_queries.tolist())
@@ -449,7 +443,6 @@ def repeat_generator():
             padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
             # 4. compute rewards
-            rlhf_reward = scores
             num_examples = scores.size(0) // 2
             scores_reshaped = scores.reshape(2, num_examples).t()
@@ -558,12 +551,12 @@ def repeat_generator():
                             logits = pi_logratios - ref_logratios
-                            if self.loss_type == "sigmoid":
+                            if self.args.loss_type == "sigmoid":
                                 losses = -F.logsigmoid(self.beta * logits)
-                            elif self.loss_type == "ipo":
+                            elif self.args.loss_type == "ipo":
                                 losses = (logits - 1 / (2 * self.beta)) ** 2
                             else:
-                                raise NotImplementedError(f"invalid loss type {self.loss_type}")
+                                raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
                             chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum).detach()
                             rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum).detach()
@@ -626,29 +619,22 @@ def repeat_generator():
                     + (rejected_logprobs_stats - rejected_ref_logprobs_stats)
                 ) / 2
                 mean_kl = kl.mean()
-                self.stats["objective/kl"].append(self.accelerator.gather(mean_kl).mean().item())
-                non_score_reward = (-self.beta * kl).sum(1)
-                mean_non_score_reward = non_score_reward.mean()
-                self.stats["objective/non_score_reward"].append(
-                    self.accelerator.gather(mean_non_score_reward).mean().item()
-                )
-                rlhf_reward = scores + non_score_reward
-                self.stats["objective/rlhf_reward"].append(self.accelerator.gather(rlhf_reward).mean().item())
+                metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item()
+                mean_non_score_reward = (-self.beta * kl).mean()
+                metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item()
+                mean_rlhf_reward = scores.mean() + mean_non_score_reward
+                metrics["objective/rlhf_reward"] = self.accelerator.gather(mean_rlhf_reward).mean().item()
                 logprobs_sum = (chosen_logprobs_stats + rejected_logprobs_stats) / 2
                 mean_entropy = -logprobs_sum.mean()
-                self.stats["objective/entropy"].append(self.accelerator.gather(mean_entropy).mean().item())
-                self.stats["objective/scores_margin"].append(
-                    self.accelerator.gather(scores_margin.mean()).mean().item()
-                )
-                self.stats["rewards/chosen"].append(self.accelerator.gather(chosen_reward_stats.mean()).mean().item())
-                self.stats["rewards/rejected"].append(
-                    self.accelerator.gather(rejected_reward_stats.mean()).mean().item()
-                )
+                metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item()
+                metrics["objective/scores_margin"] = self.accelerator.gather(scores_margin.mean()).mean().item()
+                metrics["rewards/chosen"] = self.accelerator.gather(chosen_reward_stats.mean()).mean().item()
+                metrics["rewards/rejected"] = self.accelerator.gather(rejected_reward_stats.mean()).mean().item()
                 margin = chosen_reward_stats - rejected_reward_stats
-                self.stats["rewards/margins"].append(self.accelerator.gather(margin.mean()).mean().item())
+                metrics["rewards/margins"] = self.accelerator.gather(margin.mean()).mean().item()
                 accuracy = margin > 0
-                self.stats["rewards/accuracies"].append(self.accelerator.gather(accuracy.float().mean()).mean().item())
-                self.stats["beta"].append(self.beta)
+                metrics["rewards/accuracies"] = self.accelerator.gather(accuracy.float().mean()).mean().item()
+                metrics["beta"] = self.beta
                 self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
                 self.state.global_step += 1
@@ -659,7 +645,7 @@ def repeat_generator():
             if self.control.should_save:
                 self._save_checkpoint(model, trial=None, metrics=metrics)
                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
-            del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
+            del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, mean_rlhf_reward
             torch.cuda.empty_cache()
             gc.collect()
@@ -702,7 +688,7 @@ def generate_completions(self, sampling: bool = False):
             with torch.no_grad():
                 context_length = query.shape[1]
                 query_response, _ = batch_generation(
-                    unwrapped_model.policy,
+                    unwrapped_model,
                     query,
                     query.shape[0],
                     processing_class.pad_token_id,

From 2dde1169d68c4cd29ba77b45faad23cb9ef4ad56 Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Thu, 24 Oct 2024 16:45:23 -0400
Subject: [PATCH 8/9] all running and sync fallback
---
 examples/scripts/dpo_online_async.py   | 8 ++++++--
 trl/trainer/async_online_dpo_config.py | 7 +++++--
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/examples/scripts/dpo_online_async.py b/examples/scripts/dpo_online_async.py
index f259dcd744..1c8d8f69fa 100644
--- a/examples/scripts/dpo_online_async.py
+++ b/examples/scripts/dpo_online_async.py
@@ -39,7 +39,6 @@
     ScriptArguments,
     TrlParser,
     get_kbit_device_map,
-    get_peft_config,
     get_quantization_config,
 )
 from trl.trainer.sync_online_dpo_trainer import SyncOnlineDPOTrainer
@@ -121,7 +120,12 @@ def tokenize(element):
     train_dataset = prepare_dataset(train_dataset, tokenizer)
     eval_dataset = prepare_dataset(eval_dataset, tokenizer)
-    trainer = AsyncOnlineDPOTrainer(
+    if training_args.sync_fallback is True:
+        TrainerCls = SyncOnlineDPOTrainer
+    else:
+        TrainerCls = AsyncOnlineDPOTrainer
+
+    trainer = TrainerCls(
         config=training_args,
         processing_class=tokenizer,
         policy=model,
diff --git a/trl/trainer/async_online_dpo_config.py b/trl/trainer/async_online_dpo_config.py
index 2804d68929..c6811b432c 100644
--- a/trl/trainer/async_online_dpo_config.py
+++ b/trl/trainer/async_online_dpo_config.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 import os
-from dataclasses import dataclass, field
-from typing import List, Literal
+from dataclasses import dataclass
+from typing import Literal
 from ..trainer.utils import OnPolicyConfig
@@ -48,6 +48,8 @@ class AsyncOnlineDPOConfig(OnPolicyConfig):
             - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
             - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
+        sync_fallback (`bool`, defaults to `False`):
+            Whether to fall back from asynchronous training; switches from `async_online_dpo_trainer` to `sync_online_dpo_trainer`.
         vllm_device (`str`, *optional*, defaults to `None`):
             device to put the vllm generation on, defaults to accelerate.num_processes + 1"
         vllm_gpu_memory_utilization (`float`, defaults to 0.9)
@@ -62,5 +64,6 @@ class AsyncOnlineDPOConfig(OnPolicyConfig):
     beta: float = 0.1
     loss_type: Literal["sigmoid", "ipo"] = "sigmoid"
+    sync_fallback: bool = False
     vllm_device: str | None = None
     vllm_gpu_memory_utilization: float = 0.9

From 421f94224f66d1dfbfd8b5e7ac0c3cdd8ac4fc32 Mon Sep 17 00:00:00 2001
From: Michael Noukhovitch
Date: Tue, 29 Oct 2024 18:48:14 -0400
Subject: [PATCH 9/9] precommit and ruff
---
 examples/scripts/dpo_online_async.py | 3 +--
 trl/trainer/utils.py                 | 1 +
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/scripts/dpo_online_async.py b/examples/scripts/dpo_online_async.py
index 1c8d8f69fa..ddd33cd4ea 100644
--- a/examples/scripts/dpo_online_async.py
+++ b/examples/scripts/dpo_online_async.py
@@ -29,12 +29,11 @@
 import torch
 from accelerate import PartialState
 from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 from trl import (
     AsyncOnlineDPOConfig,
     AsyncOnlineDPOTrainer,
-    LogCompletionsCallback,
     ModelConfig,
     ScriptArguments,
     TrlParser,
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index dd35b76895..aae2964612 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -51,6 +51,7 @@
 from ..import_utils import is_unsloth_available
 from ..trainer.model_config import ModelConfig
+
 _vllm_available = _is_package_available("vllm")
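
Usage sketch (not part of the patch series above): with these commits applied, the new config and trainers are expected to be wired together roughly as in examples/scripts/dpo_online_async.py. The condensed Python sketch below only restates what the diffs show; the SFT checkpoint name and the dataset split names are illustrative assumptions, and the tokenization step mirrors the prepare_dataset helper from the example script rather than reproducing it exactly.

# Minimal sketch assuming this branch is installed; checkpoint and split names are assumptions.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl import AsyncOnlineDPOConfig, AsyncOnlineDPOTrainer
from trl.trainer.sync_online_dpo_trainer import SyncOnlineDPOTrainer

sft_model_path = "trl-lib/pythia-1b-deduped-tldr-sft"  # assumption: any SFT causal LM checkpoint
reward_model_path = "trl-lib/pythia-1b-deduped-tldr-rm"  # reward model used in the example command

training_args = AsyncOnlineDPOConfig(
    output_dir="pythia-1b-tldr-online-dpo",
    sync_fallback=False,              # True selects SyncOnlineDPOTrainer instead of the async trainer
    vllm_device=None,                 # None places vLLM generation on accelerate.num_processes + 1
    vllm_gpu_memory_utilization=0.9,  # reduce if vLLM needs less of the generation GPU
)

tokenizer = AutoTokenizer.from_pretrained(sft_model_path, padding_side="left")
if tokenizer.pad_token_id is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

policy = AutoModelForCausalLM.from_pretrained(sft_model_path)
ref_policy = AutoModelForCausalLM.from_pretrained(sft_model_path)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_path, num_labels=1)

def tokenize(batch):
    # pre-tokenize prompts, as prepare_dataset does in the example script
    return {"input_ids": tokenizer(batch["prompt"], padding=False)["input_ids"]}

train_dataset = load_dataset("trl-lib/tldr", split="train")
train_dataset = train_dataset.map(tokenize, batched=True, remove_columns=train_dataset.column_names)
eval_dataset = load_dataset("trl-lib/tldr", split="validation")  # split name is an assumption
eval_dataset = eval_dataset.map(tokenize, batched=True, remove_columns=eval_dataset.column_names)

# sync_fallback switches between the two trainers, mirroring the example script
TrainerCls = SyncOnlineDPOTrainer if training_args.sync_fallback else AsyncOnlineDPOTrainer
trainer = TrainerCls(
    config=training_args,
    processing_class=tokenizer,
    policy=policy,
    ref_policy=ref_policy,
    reward_model=reward_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()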