From 53e9aacc1087bc1d6cf9a1c0d243e82ffd56e925 Mon Sep 17 00:00:00 2001
From: Anatoly Belikov
Date: Thu, 14 Mar 2024 09:11:43 +0300
Subject: [PATCH] log loss per image (#7278)

* log loss per image

* add commandline param for per image loss logging

* style

* debug-loss -> debug_loss

---------

Co-authored-by: Sayak Paul
---
 .../train_text_to_image_lora_sdxl.py | 23 ++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index c4d7a8b0450b..9e81e499aac6 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -425,6 +425,11 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--debug_loss",
+        action="store_true",
+        help="debug loss for each image, if filenames are available in the dataset",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -603,6 +608,7 @@ def main(args):
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     # The VAE is in float32 to avoid NaN losses.
     unet.to(accelerator.device, dtype=weight_dtype)
+
     if args.pretrained_vae_model_name_or_path is None:
         vae.to(accelerator.device, dtype=torch.float32)
     else:
@@ -890,13 +896,17 @@ def preprocess_train(examples):
         tokens_one, tokens_two = tokenize_captions(examples)
         examples["input_ids_one"] = tokens_one
         examples["input_ids_two"] = tokens_two
+        if args.debug_loss:
+            fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
+            if fnames:
+                examples["filenames"] = fnames
         return examples
 
     with accelerator.main_process_first():
         if args.max_train_samples is not None:
             dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
         # Set the training transforms
-        train_dataset = dataset["train"].with_transform(preprocess_train)
+        train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
 
     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -905,7 +915,7 @@ def collate_fn(examples):
         crop_top_lefts = [example["crop_top_lefts"] for example in examples]
         input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
         input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
-        return {
+        result = {
             "pixel_values": pixel_values,
             "input_ids_one": input_ids_one,
             "input_ids_two": input_ids_two,
@@ -913,6 +923,11 @@ def collate_fn(examples):
             "crop_top_lefts": crop_top_lefts,
         }
 
+        filenames = [example["filenames"] for example in examples if "filenames" in example]
+        if filenames:
+            result["filenames"] = filenames
+        return result
+
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
@@ -1105,7 +1120,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()
-
+                if args.debug_loss and "filenames" in batch:
+                    for fname in batch["filenames"]:
+                        accelerator.log({"loss_for_" + fname: loss}, step=global_step)
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                 train_loss += avg_loss.item() / args.gradient_accumulation_steps
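
Usage note (illustrative, not part of the patch): with --debug_loss enabled, the filenames collected in preprocess_train survive collation, and the per-batch loss is logged once per image under a "loss_for_<filename>" key through the Accelerate trackers. The sketch below shows only that logging pattern; the Tracker class and helper names are stand-ins for accelerator.log() and are not part of the training script.

# Minimal sketch of the per-image loss logging pattern enabled by --debug_loss.
# The Tracker class is a stand-in for Accelerate's accelerator.log(); all names
# here are illustrative only.
import os


class Tracker:
    def log(self, values, step=None):
        # Print metrics instead of sending them to wandb/tensorboard.
        for key, value in values.items():
            print(f"step={step} {key}={value:.4f}")


def log_debug_loss(tracker, batch, loss, global_step):
    # One metric key per image in the batch, mirroring the training-loop addition.
    for fname in batch.get("filenames", []):
        tracker.log({"loss_for_" + fname: loss}, step=global_step)


if __name__ == "__main__":
    # Filenames only survive preprocessing for images opened from disk, since
    # PIL sets .filename there; os.path.basename strips the directory part.
    tracker = Tracker()
    batch = {"filenames": [os.path.basename("/data/train/cat_001.png"), "dog_002.png"]}
    log_debug_loss(tracker, batch, loss=0.1234, global_step=42)

Since the value logged is the batch-mean loss, every filename in a batch receives the same number; the per-image keys are most informative when train_batch_size is 1.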