diff --git a/examples/pytorch/multimodal_language_modeling.py/run_vlm.py b/examples/pytorch/multimodal_language_modeling.py/run_vlm.py
new file mode 100644
index 00000000000000..04702ae5c9e06d
--- /dev/null
+++ b/examples/pytorch/multimodal_language_modeling.py/run_vlm.py
@@ -0,0 +1,144 @@
+import random
+
+import numpy as np
+import torch
+from datasets import load_dataset
+from Levenshtein import distance as levenshtein_distance
+from peft import LoraConfig
+
+from transformers import (
+    AutoProcessor,
+    BitsAndBytesConfig,
+    Idefics2ForConditionalGeneration,
+    Trainer,
+    TrainingArguments,
+)
+
+
+DEVICE = "cuda:0"
+processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)
+pad_token_id = processor.tokenizer.pad_token_id
+
+lora_config = LoraConfig(
+    r=8,
+    lora_alpha=8,
+    lora_dropout=0.1,
+    target_modules=".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$",
+    use_dora=False,
+    init_lora_weights="gaussian",
+)
+
+bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
+
+model = Idefics2ForConditionalGeneration.from_pretrained(
+    "HuggingFaceM4/idefics2-8b",
+    torch_dtype=torch.float16,
+    quantization_config=bnb_config,
+)
+
+model.add_adapter(lora_config)
+model.enable_adapters()
+
+train_dataset = load_dataset("nielsr/docvqa_1200_examples", split="train")
+train_dataset = train_dataset.remove_columns(["id", "words", "bounding_boxes", "answer"])
+
+eval_dataset = load_dataset("nielsr/docvqa_1200_examples", split="test")
+eval_dataset = eval_dataset.remove_columns(["id", "words", "bounding_boxes", "answer"])
+
+
+class DataCollatorForGeneration:
+    def __init__(self, processor, eval_mode=False):
+        self.processor = processor
+        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
+            processor.tokenizer.additional_special_tokens.index("<image>")
+        ]
+        self.eval_mode = eval_mode
+
+    def __call__(self, examples):
+        texts, texts_eval = [], []
+        images = []
+        for example in examples:
+            image = example["image"]
+            question = example["query"]["en"]
+            answer = random.choice(example["answers"])
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Answer briefly."},
+                        {"type": "image"},
+                        {"type": "text", "text": question},
+                    ],
+                },
+                {"role": "assistant", "content": [{"type": "text", "text": answer}]},
+            ]
+            text = processor.apply_chat_template(messages, add_generation_prompt=False)
+            text_eval = processor.apply_chat_template([messages[0]], add_generation_prompt=True)
+            texts.append(text.strip())
+            texts_eval.append(text_eval.strip())
+            images.append([image])
+
+        # Make sure we have right padding for the train parts and left padding for the eval parts
+        processor.tokenizer.padding_side = "right"
+        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+        if self.eval_mode:
+            processor.tokenizer.padding_side = "left"
+            batch_eval = processor(text=texts_eval, images=images, return_tensors="pt", padding=True)
+            batch["generation_input_ids"] = batch_eval["input_ids"]
+            batch["generation_attention_mask"] = batch_eval["attention_mask"]
+
+        labels = batch["input_ids"].clone()
+        labels[labels == processor.tokenizer.pad_token_id] = self.image_token_id
+        batch["labels"] = labels
+
+        return batch
+
+
+def calculate_levenshtein(prediction_dict):
+    # unmask for correct detokenization, because preds are padded to max length with -100
+    preds = prediction_dict.predictions
+    preds[preds == -100] = pad_token_id
+    lbls = prediction_dict.label_ids
+    lbls[lbls == -100] = pad_token_id
+
+    # Decode and compute the average Levenshtein distance
+    preds = processor.batch_decode(preds)
+    lbls = processor.batch_decode(lbls)
+    levenshtein_avg = np.mean([levenshtein_distance(pred, lbl) for pred, lbl in zip(preds, lbls)])
+    return {"eval_levenshtein": levenshtein_avg}
+
+
+generation_config = model.generation_config
+generation_config.max_length = 200  # generate no more than 200 tokens (this count includes the image tokens)
+
+training_args = TrainingArguments(
+    max_steps=1000,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=8,
+    gradient_accumulation_steps=2,
+    output_dir="/raid/raushan/idefics-train",
+    eval_strategy="steps",
+    fp16=True,
+    eval_steps=10,
+    save_steps=10,
+    remove_unused_columns=False,
+    report_to="wandb",
+    predict_with_generate=True,  # will generate in the eval step so we can compute text-based metrics
+    generation_config=generation_config,
+    metric_for_best_model="levenshtein",  # will save the model with the lowest Levenshtein distance
+    greater_is_better=False,
+)
+
+
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    data_collator=DataCollatorForGeneration(processor),
+    eval_data_collator=DataCollatorForGeneration(processor, eval_mode=True),
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    compute_metrics=calculate_levenshtein,
+)
+
+trainer.train()  # will run train and eval on the model
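As a quick sanity check of the collator contract that the new `prediction_step` relies on, the eval-mode collator from run_vlm.py above can be inspected directly. This is a sketch, not part of the patch; it reuses `processor` and `eval_dataset` exactly as defined in the example, and the printed shapes depend on the sampled examples:

    collator = DataCollatorForGeneration(processor, eval_mode=True)
    batch = collator([eval_dataset[0], eval_dataset[1]])
    print(batch["input_ids"].shape)                 # right-padded prompt + answer, used for the forward pass and loss
    print(batch["labels"].shape)                    # same shape as input_ids
    print(batch["generation_input_ids"].shape)      # left-padded prompt only, consumed by `generate()`
    print(batch["generation_attention_mask"].shape)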
diff --git a/src/transformers/models/idefics/configuration_idefics.py b/src/transformers/models/idefics/configuration_idefics.py
index 56b6025a8e89dd..7e66721189e980 100644
--- a/src/transformers/models/idefics/configuration_idefics.py
+++ b/src/transformers/models/idefics/configuration_idefics.py
@@ -236,6 +236,7 @@ class IdeficsConfig(PretrainedConfig):
 
     model_type = "idefics"
     is_composition = False
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
diff --git a/src/transformers/models/idefics2/configuration_idefics2.py b/src/transformers/models/idefics2/configuration_idefics2.py
index 1333895407e6e5..618cca9a54023c 100644
--- a/src/transformers/models/idefics2/configuration_idefics2.py
+++ b/src/transformers/models/idefics2/configuration_idefics2.py
@@ -213,6 +213,7 @@ class Idefics2Config(PretrainedConfig):
 
     model_type = "idefics2"
    is_composition = True
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py
index f2338a7c5a5df7..db0996d98b0e1f 100644
--- a/src/transformers/models/llava/configuration_llava.py
+++ b/src/transformers/models/llava/configuration_llava.py
@@ -74,6 +74,7 @@ class LlavaConfig(PretrainedConfig):
 
     model_type = "llava"
     is_composition = False
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
diff --git a/src/transformers/models/llava_next/configuration_llava_next.py b/src/transformers/models/llava_next/configuration_llava_next.py
index e8768dde85722b..e21b868e485521 100644
--- a/src/transformers/models/llava_next/configuration_llava_next.py
+++ b/src/transformers/models/llava_next/configuration_llava_next.py
@@ -79,6 +79,7 @@ class LlavaNextConfig(PretrainedConfig):
 
     model_type = "llava_next"
     is_composition = False
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
diff --git a/src/transformers/models/llava_next_video/configuration_llava_next_video.py b/src/transformers/models/llava_next_video/configuration_llava_next_video.py
index 3f310565b43747..e3b479a6219a2a 100644
--- a/src/transformers/models/llava_next_video/configuration_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/configuration_llava_next_video.py
@@ -94,6 +94,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
 
     model_type = "llava_next_video"
     is_composition = True
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
diff --git a/src/transformers/models/llava_next_video/diff_llava_next_video.py b/src/transformers/models/llava_next_video/diff_llava_next_video.py
index b4018db586e74e..dbcfa80615b2d2 100644
--- a/src/transformers/models/llava_next_video/diff_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/diff_llava_next_video.py
@@ -103,6 +103,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
 
     model_type = "llava_next_video"
     is_composition = True
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
diff --git a/src/transformers/models/paligemma/configuration_paligemma.py b/src/transformers/models/paligemma/configuration_paligemma.py
index 2f5a72bb6f7889..2f7d29ebe05972 100644
--- a/src/transformers/models/paligemma/configuration_paligemma.py
+++ b/src/transformers/models/paligemma/configuration_paligemma.py
@@ -74,6 +74,7 @@ class PaliGemmaConfig(PretrainedConfig):
 
     model_type = "paligemma"
     is_composition = False
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
diff --git a/src/transformers/models/video_llava/configuration_video_llava.py b/src/transformers/models/video_llava/configuration_video_llava.py
index 8738a02585e039..7f48ca638c3a60 100644
--- a/src/transformers/models/video_llava/configuration_video_llava.py
+++ b/src/transformers/models/video_llava/configuration_video_llava.py
@@ -79,6 +79,7 @@ class VideoLlavaConfig(PretrainedConfig):
 
     model_type = "video_llava"
     is_composition = False
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
diff --git a/src/transformers/models/vipllava/configuration_vipllava.py b/src/transformers/models/vipllava/configuration_vipllava.py
index f88be5adfba028..2f1dacfa1f7620 100644
--- a/src/transformers/models/vipllava/configuration_vipllava.py
+++ b/src/transformers/models/vipllava/configuration_vipllava.py
@@ -73,6 +73,7 @@ class VipLlavaConfig(PretrainedConfig):
 
     model_type = "vipllava"
     is_composition = False
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
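The repeated `keys_to_ignore_at_inference` additions above let the evaluation loop skip the `past_key_values` these models return from their forward pass. When `ignore_keys` is not passed explicitly, `Trainer.prediction_step` already falls back to this config attribute, roughly as follows (paraphrased from the existing Trainer code, not part of this patch):

    if ignore_keys is None:
        if hasattr(self.model, "config"):
            ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

Without the attribute, the cache tensors would be gathered together with the logits during evaluation.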
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c241fd4eb83c70..3fe743845667ca 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -60,8 +60,14 @@
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .feature_extraction_sequence_utils import SequenceFeatureExtractor
+from .generation.configuration_utils import GenerationConfig
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
-from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
+from .integrations.deepspeed import (
+    deepspeed_init,
+    deepspeed_load_checkpoint,
+    is_deepspeed_available,
+    is_deepspeed_zero3_enabled,
+)
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
@@ -306,9 +312,12 @@ class Trainer:
             The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
             `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
         data_collator (`DataCollator`, *optional*):
-            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
+            The function to use to form a batch from a list of elements of `train_dataset`. Will
             default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
             [`DataCollatorWithPadding`] otherwise.
+        eval_data_collator (`DataCollator`, *optional*):
+            The function to use to form a batch from a list of elements of `eval_dataset` and `test_dataset`. Will
+            default to `data_collator` if no `eval_data_collator` is provided.
         train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*):
             The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
             `model.forward()` method are automatically removed.
@@ -380,6 +389,7 @@ def __init__(
         model: Union[PreTrainedModel, nn.Module] = None,
         args: TrainingArguments = None,
         data_collator: Optional[DataCollator] = None,
+        eval_data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
@@ -542,6 +552,7 @@ def __init__(
             else default_data_collator
         )
         self.data_collator = data_collator if data_collator is not None else default_collator
+        self.eval_data_collator = eval_data_collator if eval_data_collator is not None else self.data_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
         self.tokenizer = tokenizer
@@ -981,7 +992,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
             if eval_dataset is not None
             else self.eval_dataset
         )
-        data_collator = self.data_collator
+        data_collator = self.eval_data_collator
 
         if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
@@ -1023,7 +1034,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
                 The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                 `model.forward()` method are automatically removed. It must implement `__len__`.
         """
-        data_collator = self.data_collator
+        data_collator = self.eval_data_collator
 
         if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
             test_dataset = self._remove_unused_columns(test_dataset, description="test")
@@ -3736,6 +3747,7 @@ def evaluate(
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
         ignore_keys: Optional[List[str]] = None,
         metric_key_prefix: str = "eval",
+        **gen_kwargs,
     ) -> Dict[str, float]:
         """
         Run evaluation and returns metrics.
@@ -3770,6 +3782,8 @@ def evaluate(
             metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                 An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                 "eval_bleu" if the prefix is "eval" (default)
+            gen_kwargs:
+                Additional `generate` specific kwargs.
 
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
@@ -3785,10 +3799,32 @@ def evaluate(
                     eval_dataset=_eval_dataset if override else eval_dataset_name,
                     ignore_keys=ignore_keys,
                     metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
+                    **gen_kwargs,
                 )
                 metrics.update(dataset_metrics)
             return metrics
 
+        # Set generation-related kwargs
+        if self.args.predict_with_generate:
+            if self.args.generation_config is None:
+                # We assume the model can generate if predict_with_generate is True
+                # Therefore, generation_config should be available
+                self.gen_config = self.model.generation_config
+            elif isinstance(self.args.generation_config, GenerationConfig):
+                gen_config = self.args.generation_config
+                self.gen_config = copy.deepcopy(gen_config)  # copy so we don't modify args.generation_config in-place
+            else:
+                # That means `args.generation_config` was passed as a dict
+                self.gen_config = self.model.generation_config
+                _ = self.gen_config.update(**self.args.generation_config)
+            unused_kwargs = self.gen_config.update(**gen_kwargs)
+            if unused_kwargs:
+                logger.warning_once(
+                    "The following generation-related kwargs were passed to `evaluate` but are not used by `generate()`: "
+                    f"{' '.join(unused_kwargs.keys())}. "
+                    "Make sure there are no typos in the passed kwargs or do not pass unused kwargs."
+                )
+
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
@@ -3836,7 +3872,11 @@ def evaluate(
         return output.metrics
 
     def predict(
-        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
+        self,
+        test_dataset: Dataset,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "test",
+        **gen_kwargs,
     ) -> PredictionOutput:
         """
         Run prediction and returns predictions and potential metrics.
@@ -3854,6 +3894,8 @@ def predict(
             metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                 An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                 "test_bleu" if the prefix is "test" (default)
+            gen_kwargs:
+                Additional `generate` specific kwargs.
 
@@ -3870,6 +3912,27 @@ def predict(
             - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset
               contained labels).
""" + # Set generation-related kwargs + if self.args.predict_with_generate: + if self.args.generation_config is None: + # We assume the model can generate if predict-with-generate is True + # Therefore, generation_config should be available + self.gen_config = self.model.generation_config + elif isinstance(self.args.generation_config, GenerationConfig): + gen_config = self.args.generation_config + self.gen_config = copy.deepcopy(gen_config) # copy so we don't modify args.gen_config in-place + else: + # That means `args.generation_config` is passed as a Dict + self.gen_config = self.model.generation_config + _ = self.gen_config.update(**self.args.generation_config) + unused_kwargs = self.gen_config.update(**gen_kwargs) + if unused_kwargs: + logger.warning_once( + f"Following generation related kwargs were passed to `evaluate` but not used by `generate()`: " + f"{' '.join(unused_kwargs.keys())} .", + "Make sure there are no typos in the passed kwargs or do not pass unused kwargs.", + ) + # memory metrics - must set up as early as possible self._memory_tracker.start() @@ -4137,6 +4200,7 @@ def prediction_step( inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool, ignore_keys: Optional[List[str]] = None, + **gen_kwargs, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform an evaluation step on `model` using `inputs`. @@ -4156,12 +4220,29 @@ def prediction_step( ignore_keys (`List[str]`, *optional*): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. + gen_kwargs: + Additional `generate` specific kwargs. Return: Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and labels (each being optional). """ has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + + # Prioroty: gen_kwargs > args.gen_config > model.generation_config > default GenerationConfig() + if self.args.predict_with_generate: + gen_config = self.gen_config + default_synced_gpus = True if is_deepspeed_zero3_enabled() else False + synced_gpus = gen_kwargs.get("synced_gpus", default_synced_gpus) + if len(gen_kwargs) > 0: + unused_kwargs = gen_config.update(**gen_kwargs) + if unused_kwargs: + logger.warning_once( + "Following generation related kwargs were passed to `prediction_step` but not " + f"used by `generate()`: {' '.join(unused_kwargs.keys())} .", + "Make sure there are no typos in the passed kwargs or do not pass unused kwargs.", + ) + # For CLIP-like models capable of returning loss values. # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` # is `True` in `model.forward`. @@ -4185,6 +4266,37 @@ def prediction_step( else: labels = None + # If the `generation_input_ids` was passed in inputs, the model can generate and we need to modify + # input keys. Otherwise, we don't know the `prompt` to generate from + if self.args.predict_with_generate and not prediction_loss_only: + generation_inputs = inputs.copy() + if "generation_input_ids" in generation_inputs: + # get inputs that are related to text and contain only generation prompt + generation_only_inputs = { + k.replace("generation_", ""): v for k, v in generation_inputs.items() if "generation_" in k + } + + # get common inputs that are not related to text, e.g. 
+                gen_keys = generation_only_inputs.keys()
+                generation_inputs_common = {
+                    k: v
+                    for k, v in generation_inputs.items()
+                    if k.replace("generation_", "") not in gen_keys and "generation" not in k
+                }
+                generated_tokens = model.generate(
+                    **generation_inputs_common,
+                    **generation_only_inputs,
+                    generation_config=gen_config,
+                    synced_gpus=synced_gpus,
+                )
+            else:
+                raise ValueError(
+                    "`predict_with_generate` is set to `True` but no inputs were passed for generation. "
+                    "Make sure you have `generation_input_ids` and `generation_attention_mask`."
+                )
+
+            # clean up generation-related input tensors from `inputs`, if there are any, before doing `forward`
+            inputs = {k: v for k, v in inputs.items() if "generation_" not in k}
 
         with torch.no_grad():
             if is_sagemaker_mp_enabled():
                 raw_outputs = smp_forward_only(model, inputs)
@@ -4230,6 +4342,9 @@ def prediction_step(
         if prediction_loss_only:
             return (loss, None, None)
 
+        if self.args.predict_with_generate and not prediction_loss_only:
+            return (loss, generated_tokens, labels)
+
         logits = nested_detach(logits)
         if len(logits) == 1:
             logits = logits[0]
diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py
index abc45cffe4aeea..081e9e2dc94498 100644
--- a/src/transformers/trainer_seq2seq.py
+++ b/src/transformers/trainer_seq2seq.py
@@ -307,7 +307,7 @@ def prediction_step(
             generation_inputs = {
                 k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
             }
-            generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
+            generated_tokens = model.generate(**generation_inputs, **gen_kwargs)
 
         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
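The precedence implemented by the trainer.py hunks above is: per-call `gen_kwargs` > `args.generation_config` > `model.generation_config`. A standalone sketch of that resolution using only the public `GenerationConfig.update` API (all values hypothetical, not taken from the patch):

    import copy

    from transformers import GenerationConfig

    base = GenerationConfig(num_beams=1)             # stands in for model.generation_config
    resolved = copy.deepcopy(base)                   # copied so the model's own config is not mutated
    resolved.update(num_beams=4)                     # stands in for args.generation_config passed as a dict
    unused = resolved.update(max_new_tokens=32, bogus_flag=True)  # stands in for **gen_kwargs
    print(resolved.num_beams, resolved.max_new_tokens)  # 4 32
    print(unused)  # {'bogus_flag': True} -> this is what triggers the "unused kwargs" warning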
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 6b587bdd65ae97..a8994766af06a7 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -28,6 +28,7 @@
 from packaging import version
 
 from .debug_utils import DebugOption
+from .generation import GenerationConfig
 from .trainer_utils import (
     EvaluationStrategy,
     FSDPOption,
@@ -190,6 +191,7 @@ class OptimizerNames(ExplicitEnum):
     "deepspeed",
     "gradient_checkpointing_kwargs",
     "lr_scheduler_kwargs",
+    "generation_config",
 ]
 
 
@@ -793,6 +795,12 @@ class TrainingArguments:
         eval_use_gather_object (`bool`, *optional*, defaults to `False`):
             Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices.
             This should only be enabled if users are not just returning tensors, and this is actively discouraged by
             PyTorch.
+        predict_with_generate (`bool`, *optional*, defaults to `False`):
+            Whether to use `generate` to calculate generative metrics (ROUGE, BLEU).
+        generation_config (Union[`~generation.GenerationConfig`, `Dict`], *optional*):
+            The [`~generation.GenerationConfig`] object that will be used during generation if `predict_with_generate`
+            is set to `True`. Arguments set in this config have higher priority than the model's generation config.
+            Anything not set by this config will fall back to `model.generation_config` by default.
         use_liger_kernel (`bool`, *optional*, defaults to `False`):
             Whether enable [Liger](https://github.com/linkedin/Liger-Kernel) Kernel for LLM model training.
@@ -1510,6 +1518,20 @@ class TrainingArguments:
         },
     )
 
+    predict_with_generate: bool = field(
+        default=False, metadata={"help": "Whether to use `generate` to calculate generative metrics (ROUGE, BLEU)."}
+    )
+    generation_config: Optional[Union[dict, str, GenerationConfig]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The GenerationConfig that will be used during prediction. Args from this config "
+                "will have higher priority than the model's generation config. Anything not set by this config "
+                "will fall back to `model.generation_config`."
+            )
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
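With the fields above in place, generation arguments can also be supplied per call through the new `**gen_kwargs`; a hypothetical call site, assuming the `trainer` constructed in run_vlm.py above:

    metrics = trainer.evaluate(max_new_tokens=20)             # overrides args.generation_config for this call only
    predictions = trainer.predict(eval_dataset, num_beams=3)  # same mechanism for predict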