Trainer: add predict with generate #32346

Open

wants to merge 18 commits into base: main

Changes from 8 commits
134 changes: 134 additions & 0 deletions run.py
@@ -0,0 +1,134 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, Trainer
from datasets import load_dataset
from peft import LoraConfig
import torch
import time
from torchmetrics.text import SQuAD
from random import randrange
from transformers.utils import logging
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

# timdettmers/openassistant-guanaco
# Stanford/web_questions
eval_dataset = load_dataset("timdettmers/openassistant-guanaco", split="test")


# train_dataset=dataset["train"]
# eval_dataset=dataset["test"]

# eval_dataset = eval_dataset.select(range(256))

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

#model_id = "meta-llama/Llama-2-7b-chat-hf"
model_id = "openai-community/gpt2"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    #quantization_config=quant_config,
    device_map="auto",
    torch_dtype=torch.float16,
    #attn_implementation="flash_attention_2",
)

print(f"Param count: {sum(p.numel() for p in model.parameters())}")

tokenizer = AutoTokenizer.from_pretrained(model_id)

tokenizer.add_special_tokens({"pad_token": "</s>"})
pad_token_id = tokenizer.pad_token_id
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)


gen_config = model.generation_config
gen_config.max_new_tokens = 200
gen_config.use_cache = True


peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model.add_adapter(peft_config)
model.enable_adapters()
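# The collator below serves two roles: in training mode it returns the usual
# right-padded input_ids/attention_mask/labels, while in eval mode it additionally
# returns left-padded generation_input_ids/generation_attention_mask built from the
# prompt only, which the predict-with-generate evaluation path uses to call generate().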


class DataCollatorForGenerationSFT:
    def __init__(self, tokenizer, eval_mode=False):
        self.tokenizer = tokenizer
        self.eval_mode = eval_mode

    def __call__(self, examples):
        texts, texts_eval = [], []
        for example in examples:
            question = example["text"]
            # Keep only the prompt part (everything before the assistant reply) for generation
            text_eval = question.split("### Assistant")[0]
            texts.append(question.strip())
            texts_eval.append(f"{text_eval.strip()}### Assistant:")

        # Make sure we have right padding in train and left padding for the eval/generation parts
        self.tokenizer.padding_side = "right"
        batch = self.tokenizer(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=512)

        if self.eval_mode:
            self.tokenizer.padding_side = "left"
            batch_eval = self.tokenizer(text=texts_eval, return_tensors="pt", padding=True, truncation=True, max_length=512)
            batch['generation_input_ids'] = batch_eval['input_ids']
            batch['generation_attention_mask'] = batch_eval['attention_mask']
        labels = batch["input_ids"].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100  # Ignore index for CE loss
        batch["labels"] = labels
        return batch


def custom_metrics(prediction_dict):
    # Unmask for correct detokenization, because preds are padded to max length with -100
    preds = prediction_dict.predictions
    preds[preds == -100] = pad_token_id
    lbls = prediction_dict.label_ids
    lbls[lbls == -100] = pad_token_id

    # Decode; the real metric computation would go here (placeholder values returned for now)
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    lbls = tokenizer.batch_decode(lbls, skip_special_tokens=True)
    return {"exact_match": 0, "f1_score": 0}
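
# A hedged sketch (not part of the original script) of how the placeholder metric above
# could be filled in using the torchmetrics SQuAD metric imported at the top. The ids
# and the answers structure below are illustrative assumptions.
def squad_metrics_sketch(decoded_preds, decoded_labels):
    squad = SQuAD()
    preds = [{"prediction_text": pred, "id": str(i)} for i, pred in enumerate(decoded_preds)]
    targets = [
        {"answers": {"answer_start": [0], "text": [label]}, "id": str(i)}
        for i, label in enumerate(decoded_labels)
    ]
    scores = squad(preds, targets)  # {"exact_match": tensor(...), "f1": tensor(...)}
    return {"exact_match": scores["exact_match"].item(), "f1_score": scores["f1"].item()}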


training_args = TrainingArguments(
    per_device_train_batch_size=8,
    #per_device_eval_batch_size=128,
    num_train_epochs=20,
    do_train=True,
    do_eval=True,
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500000,  # effectively disables checkpointing for this test
    bf16=True,
    output_dir="test_predict",
    overwrite_output_dir=True,
    optim="adafactor",
    report_to="wandb",
    logging_steps=100000,
    remove_unused_columns=False,
    # predict-with-generate settings exercised by this test
    predict_with_generate=True,
    generation_config=gen_config,
)


trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForGenerationSFT(tokenizer),
    # separate eval collator that also returns the left-padded generation inputs
    eval_data_collator=DataCollatorForGenerationSFT(tokenizer, eval_mode=True),
    train_dataset=eval_dataset,  # the test split is reused for training in this quick check
    eval_dataset=eval_dataset,
    compute_metrics=custom_metrics,
)

trainer.evaluate()
104 changes: 104 additions & 0 deletions seq2seq.py
@@ -0,0 +1,104 @@
from datasets import load_dataset
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments


raw_datasets = load_dataset("kde4", lang1="en", lang2="fr", trust_remote_code=True)
split_datasets = raw_datasets["train"].train_test_split(train_size=0.9, seed=20)

split_datasets["validation"] = split_datasets.pop("test")
split_datasets["train"][1]["translation"]

split_datasets["validation"] = split_datasets["validation"].select(range(256))


model_checkpoint = "google-t5/t5-base" # 11135332352 124439808 783150080 222903552
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

print(f"Param count: {sum(p.numel() for p in model.parameters())}")


tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors="pt")

en_sentence = split_datasets["train"][1]["translation"]["en"]
fr_sentence = split_datasets["train"][1]["translation"]["fr"]

inputs = tokenizer(en_sentence, text_target=fr_sentence)


def preprocess_function(examples):
    inputs = [ex["en"] for ex in examples["translation"]]
    targets = [ex["fr"] for ex in examples["translation"]]
    model_inputs = tokenizer(
        inputs, text_target=targets, max_length=50, truncation=True
    )
    return model_inputs

tokenized_datasets = split_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=split_datasets["train"].column_names,
)


from transformers import DataCollatorForSeq2Seq
import evaluate

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
batch = data_collator([tokenized_datasets["train"][i] for i in range(1, 3)])
batch.keys()

metric = evaluate.load("sacrebleu")


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100s in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}


args = Seq2SeqTrainingArguments(
    "marian-finetuned-kde4-en-to-fr",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=True,
    generation_max_length=250,
)

from transformers import Seq2SeqTrainer

model.generation_config.max_new_tokens = 200
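# Note: both the model-level cap above (max_new_tokens=200) and generation_max_length=250
# in the training arguments are set; with predict_with_generate=True the evaluation loop
# generates sequences, so these settings bound the length of the eval predictions.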

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.evaluate()