added initial TPO implementation #1965

Open: sahsaeedi wants to merge 27 commits into base: main

Commits (27):
ae1ab50 added initial TPO implementation (sahsaeedi, Aug 24, 2024)
26a6daa Merge branch 'main' into tpo (sahsaeedi, Aug 24, 2024)
637d55a Merge branch 'main' into tpo (sahsaeedi, Aug 25, 2024)
39d84ff Merge branch 'main' into tpo (sahsaeedi, Aug 26, 2024)
54dcea8 fixed the address in the utils.py (sahsaeedi, Aug 26, 2024)
c37364d Moved custom function from utils to tpo_trainer (sahsaeedi, Aug 29, 2024)
caa54c2 Merging origin (sahsaeedi, Aug 29, 2024)
47b125f Merge branch 'main' into tpo (sahsaeedi, Aug 29, 2024)
4d532b7 Merge branch 'main' into tpo (sahsaeedi, Aug 31, 2024)
e166542 Merge branch 'main' into tpo (sahsaeedi, Sep 3, 2024)
6066843 Merge branch 'main' into tpo (sahsaeedi, Sep 4, 2024)
b37687c Merge branch 'main' into tpo (sahsaeedi, Sep 6, 2024)
d8e5e67 Merge branch 'main' into tpo (sahsaeedi, Sep 8, 2024)
d50c0c8 Merge branch 'main' into tpo (sahsaeedi, Sep 9, 2024)
7185e7b Merge branch 'main' into tpo (sahsaeedi, Sep 10, 2024)
9eeba98 Merge branch 'main' into tpo (sahsaeedi, Sep 12, 2024)
ad3ae91 Merge branch 'main' into tpo (sahsaeedi, Sep 13, 2024)
3d61f43 Merge branch 'main' into tpo (sahsaeedi, Sep 15, 2024)
ddbb8ff Merge branch 'main' into tpo (sahsaeedi, Sep 18, 2024)
a1eba9a Merge branch 'main' into tpo (sahsaeedi, Sep 18, 2024)
21b6136 Merge branch 'main' into tpo (sahsaeedi, Sep 18, 2024)
cf45eb8 Merge branch 'main' into tpo (sahsaeedi, Sep 23, 2024)
a5de315 Merge branch 'main' into tpo (sahsaeedi, Sep 25, 2024)
9d752e2 Merge branch 'main' into tpo (sahsaeedi, Oct 1, 2024)
0cd3619 Merge branch 'main' into tpo (sahsaeedi, Oct 3, 2024)
bc88bc2 Merge branch 'main' into tpo (sahsaeedi, Oct 4, 2024)
f74a156 Merge branch 'main' into tpo (sahsaeedi, Oct 6, 2024)
197 changes: 197 additions & 0 deletions tests/test_tpo_trainer.py
@@ -0,0 +1,197 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import torch
from datasets import Dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

from trl import TPOConfig, TPOTrainer

from .testing_utils import require_peft


class TPOTrainerTester(unittest.TestCase):
    def setUp(self):
        self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # get t5 as seq2seq example:
        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)

    def _init_dummy_dataset(self):
        # fmt: off
        dummy_dataset_dict = {
            "prompt": [
                "hello",
                "how are you",
                "What is your name?",
                "What is your name?",
                "Which is the best programming language?",
                "Which is the best programming language?",
                "Which is the best programming language?",
                "[INST] How is the stock price? [/INST]",
                "[INST] How is the stock price? [/INST] ",
            ],
            "reference": [
                "Hello! It's great to see you here!",
                "I'm doing fantastic, thank you for asking! It's a beautiful day, and I'm feeling energized and ready to tackle whatever comes my way. How about you?",
                "My name is Mary. It's nice to meet you!",
                "My name is Mary. It's nice to meet you!",
                "Python is often considered the best programming language due to its readability, versatility, and strong community support.",
                "Python is often considered the best programming language due to its readability, versatility, and strong community support.",
                "Python is often considered the best programming language due to its readability, versatility, and strong community support.",
                "The stock price has increased by 5% today, reaching an all-time high of $150 per share.",
                "The stock price has increased by 5% today, reaching an all-time high of $150 per share.",
            ],
            "chosen": [
                "hi nice to meet you",
                "I am fine",
                "My name is Mary",
                "My name is Mary",
                "Python",
                "Python",
                "Python",
                "$46 as of 10am EST",
                "46 as of 10am EST",
            ],
            "rejected": [
                "leave me alone",
                "I am not fine",
                "Whats it to you?",
                "I dont have a name",
                "Javascript",
                "C++",
                "Java",
                " $46 as of 10am EST",
                " 46 as of 10am EST",
            ],
        }
        # fmt: on
        return Dataset.from_dict(dummy_dataset_dict)

    @parameterized.expand(
        [
            ["gpt2", "sigmoid"],
            ["t5", "hinge"],
            ["gpt2", "ipo"],
            ["t5", "ipo"],
            ["gpt2", "simpo"],
            ["t5", "simpo"],
        ]
    )
    def test_tpo_trainer(self, name, loss_type):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                remove_unused_columns=False,
                gradient_accumulation_steps=1,
                learning_rate=9e-1,
                # eval_strategy="steps",
                is_three_preference=True,
                beta=0.1,
                loss_type=loss_type,
                tpo_alpha=1.0,
                report_to="none",
            )

            dummy_dataset = self._init_dummy_dataset()

            if name == "gpt2":
                model = self.model
                tokenizer = self.tokenizer
            elif name == "t5":
                model = self.t5_model
                tokenizer = self.t5_tokenizer
                training_args.is_encoder_decoder = True

            trainer = TPOTrainer(
                model=model,
                args=training_args,
                tokenizer=tokenizer,
                train_dataset=dummy_dataset,
                eval_dataset=dummy_dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            assert trainer.state.log_history[-1]["train_loss"] is not None

            # check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                # check the params have changed - ignore 0 biases
                if param.sum() != 0:
                    assert not torch.equal(param, new_param)

    @require_peft
    def test_tpo_trainer_with_lora(self):
        from peft import LoraConfig

        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                remove_unused_columns=False,
                gradient_accumulation_steps=4,
                learning_rate=9e-1,
                is_three_preference=True,
                # eval_strategy="steps",
                beta=0.1,
                tpo_alpha=1.0,
                report_to="none",
            )

            dummy_dataset = self._init_dummy_dataset()

            trainer = TPOTrainer(
                model=self.model,
                args=training_args,
                tokenizer=self.tokenizer,
                train_dataset=dummy_dataset,
                eval_dataset=dummy_dataset,
                peft_config=lora_config,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            assert trainer.state.log_history[-1]["train_loss"] is not None

            # check the params have changed
            for n, param in previous_trainable_params.items():
                if "lora" in n:
                    new_param = trainer.model.get_parameter(n)
                    # check the params have changed - ignore 0 biases
                    if param.sum() != 0:
                        assert not torch.equal(param, new_param)
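
The tests above double as usage documentation: each record in the dataset carries a gold "reference" answer alongside the usual chosen/rejected pair. For a quicker read, here is a minimal, self-contained sketch of the same flow. It reuses the dummy test model from the setup above; the output directory name is a placeholder, and the snippet is an illustration rather than part of the PR's diff.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import TPOConfig, TPOTrainer

model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Three-preference records: a gold "reference" answer plus a chosen/rejected pair.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is your name?"],
        "reference": ["My name is Mary. It's nice to meet you!"],
        "chosen": ["My name is Mary"],
        "rejected": ["I dont have a name"],
    }
)

training_args = TPOConfig(
    output_dir="tpo-output",  # placeholder path
    per_device_train_batch_size=1,
    max_steps=3,
    remove_unused_columns=False,
    is_three_preference=True,  # the dataset provides the "reference" column
    beta=0.1,
    tpo_alpha=1.0,
    loss_type="sigmoid",
    report_to="none",
)

trainer = TPOTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()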
92 changes: 92 additions & 0 deletions trl/trainer/tpo_config.py
@@ -0,0 +1,92 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Literal, Optional

from transformers import TrainingArguments


@dataclass
class TPOConfig(TrainingArguments):
    r"""
    TPOConfig collects all training arguments related to the [`TPOTrainer`] class.

    Using [`HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int`, defaults to `None`):
            The maximum length of the sequences in the batch. This argument is required if you want to use the
            default data collator.
        max_prompt_length (`int`, defaults to `None`):
            The maximum length of the prompt. This argument is required if you want to use the default data
            collator.
        max_completion_length (`int`, defaults to `None`):
            The maximum length of the completion sequences.
        max_target_length (`int`, defaults to `None`):
            The maximum length of the target. This argument is required if you want to use the default data
            collator and your model is an encoder-decoder.
        beta (`float`, defaults to `0.1`):
            The beta factor in the TPO loss.
        label_smoothing (`float`, defaults to `0.0`):
            The label smoothing factor.
        loss_type (`str`, defaults to `"sigmoid"`):
            The type of loss to use: `"sigmoid"`, `"hinge"`, `"ipo"`, or `"simpo"`.
        label_pad_token_id (`int`, defaults to `-100`):
            The label pad token id. This argument is required if you want to use the default data collator.
        tpo_alpha (`float`, defaults to `1.0`):
            A hyperparameter that controls the strength of the BC regularizer in TPO training.
        simpo_gamma (`float`, defaults to `0.5`):
            The target reward margin for the SimPO loss, used only when `loss_type="simpo"`.
        padding_value (`int`, defaults to `None`):
            The padding value if it is different from the tokenizer's pad_token_id.
        truncation_mode (`str`, defaults to `"keep_end"`):
            The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want
            to use the default data collator.
        generate_during_eval (`bool`, defaults to `False`):
            Whether to sample and log generations during the evaluation step.
        is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`):
            If no model is provided, we need to know whether the model_init returns an encoder-decoder.
        is_three_preference (`Optional[bool]`, *optional*, defaults to `None`):
            Whether the dataset contains three preferences. If `is_three_preference` is set to `True`, the
            dataset must also have a `reference` column.
        disable_dropout (`bool`, defaults to `True`):
            Whether or not to disable dropouts in `model`.
        model_init_kwargs (`Optional[Dict]`, *optional*):
            Dict of optional kwargs to pass when instantiating the model from a string.
        dataset_num_proc (`Optional[int]`, *optional*):
            The number of workers to use to tokenize the data. Defaults to `None`.
    """

    max_length: Optional[int] = None
    max_prompt_length: Optional[int] = None
    max_completion_length: Optional[int] = None
    max_target_length: Optional[int] = None

    beta: float = 0.1
    label_smoothing: float = 0.0
    loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "sigmoid"
    disable_dropout: bool = True
    tpo_alpha: float = 1.0
    simpo_gamma: float = 0.5

    label_pad_token_id: int = -100
    padding_value: Optional[int] = None
    truncation_mode: str = "keep_end"
    generate_during_eval: bool = False
    is_encoder_decoder: Optional[bool] = None
    is_three_preference: Optional[bool] = None

    model_init_kwargs: Optional[Dict] = None

    dataset_num_proc: Optional[int] = None

    def __post_init__(self):
        if self.loss_type == "kto_pair":
            raise ValueError("Support for kto_pair has been removed in TPOTrainer. Please use KTOTrainer.")
        return super().__post_init__()
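
The trainer module itself (`trl/trainer/tpo_trainer.py`) is not shown in this excerpt, so the following is only a hedged sketch of how the hyperparameters above plausibly combine, modeled on the CPO/SimPO-style losses that `loss_type` names: a reference-free pairwise term on the chosen/rejected log-probabilities scaled by `beta`, plus a behavior-cloning NLL term on the gold `reference` answer weighted by `tpo_alpha`. Treat it as an illustration of what the knobs do, not as the PR's implementation.

import torch
import torch.nn.functional as F


def tpo_loss_sketch(
    chosen_logps: torch.Tensor,    # summed log-probs of the chosen responses
    rejected_logps: torch.Tensor,  # summed log-probs of the rejected responses
    reference_nll: torch.Tensor,   # per-example NLL of the gold "reference" responses
    beta: float = 0.1,
    tpo_alpha: float = 1.0,
    simpo_gamma: float = 0.5,
    label_smoothing: float = 0.0,
    loss_type: str = "sigmoid",
) -> torch.Tensor:
    # Reference-free preference margin between chosen and rejected responses.
    logits = chosen_logps - rejected_logps

    if loss_type == "sigmoid":
        # DPO-style logistic loss, with optional label smoothing.
        pref = (
            -F.logsigmoid(beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-beta * logits) * label_smoothing
        )
    elif loss_type == "hinge":
        pref = torch.relu(1 - beta * logits)
    elif loss_type == "ipo":
        # IPO regresses the margin toward 1 / (2 * beta).
        pref = (logits - 1 / (2 * beta)) ** 2
    elif loss_type == "simpo":
        # SimPO subtracts a target reward margin gamma before the sigmoid.
        pref = -F.logsigmoid(beta * logits - simpo_gamma)
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")

    # BC regularizer: pull the policy toward the gold reference answers.
    return (pref + tpo_alpha * reference_nll).mean()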