log loss per image (huggingface#7278)
* log loss per image

* add command-line param for per-image loss logging

* style

* debug-loss -> debug_loss

---------

Co-authored-by: Sayak Paul <[email protected]>
noskill and sayakpaul authored Mar 14, 2024
1 parent 4142446 commit 53e9aac
Showing 1 changed file with 20 additions and 3 deletions.

examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -425,6 +425,11 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--debug_loss",
+        action="store_true",
+        help="debug loss for each image, if filenames are available in the dataset",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
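
The flag uses argparse's `store_true` action, so per-image logging stays off unless explicitly requested. A minimal stand-alone sketch of the flag's behavior (a toy parser, not the training script itself):

import argparse

# Hypothetical stand-alone parser mirroring the new flag: `store_true`
# defaults to False and flips to True only when the flag is passed.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--debug_loss",
    action="store_true",
    help="debug loss for each image, if filenames are available in the dataset",
)

print(parser.parse_args([]).debug_loss)                # False
print(parser.parse_args(["--debug_loss"]).debug_loss)  # True
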
@@ -603,6 +608,7 @@ def main(args):
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     # The VAE is in float32 to avoid NaN losses.
     unet.to(accelerator.device, dtype=weight_dtype)
+
     if args.pretrained_vae_model_name_or_path is None:
         vae.to(accelerator.device, dtype=torch.float32)
     else:
@@ -890,13 +896,17 @@ def preprocess_train(examples):
         tokens_one, tokens_two = tokenize_captions(examples)
         examples["input_ids_one"] = tokens_one
         examples["input_ids_two"] = tokens_two
+        if args.debug_loss:
+            fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
+            if fnames:
+                examples["filenames"] = fnames
         return examples
 
     with accelerator.main_process_first():
         if args.max_train_samples is not None:
             dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
         # Set the training transforms
-        train_dataset = dataset["train"].with_transform(preprocess_train)
+        train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
 
     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
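
`preprocess_train` records only basenames of images that actually carry a filename; images decoded purely in memory may have an empty `filename`, and the guard skips them. A toy sketch of the same filtering pattern with plain strings (the paths are made up):

import os

# Made-up paths standing in for `image.filename`; the empty string mimics an
# image with no filename, which the `if p` guard filters out.
paths = ["/data/train/cat_001.png", "", "/data/train/dog_042.png"]
fnames = [os.path.basename(p) for p in paths if p]
print(fnames)  # ['cat_001.png', 'dog_042.png']

Passing `output_all_columns=True` to `with_transform` appears intended to keep additional columns in the formatted output, so the extra `filenames` entry can survive through to `collate_fn`.
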
@@ -905,14 +915,19 @@ def collate_fn(examples):
         crop_top_lefts = [example["crop_top_lefts"] for example in examples]
         input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
         input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
-        return {
+        result = {
             "pixel_values": pixel_values,
             "input_ids_one": input_ids_one,
             "input_ids_two": input_ids_two,
             "original_sizes": original_sizes,
             "crop_top_lefts": crop_top_lefts,
         }
+
+        filenames = [example["filenames"] for example in examples if "filenames" in example]
+        if filenames:
+            result["filenames"] = filenames
+        return result
 
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
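
`collate_fn` adds a `filenames` key only when at least one example carries one, so downstream code can simply test for the key's presence. A self-contained sketch of that pass-through logic with toy dicts (names are illustrative):

# Toy version of the filename pass-through in collate_fn: the key is added
# only if some example carries it, so `"filenames" in batch` works as a
# cheap feature test downstream.
def collate_filenames(examples):
    result = {"batch_size": len(examples)}
    filenames = [ex["filenames"] for ex in examples if "filenames" in ex]
    if filenames:
        result["filenames"] = filenames
    return result

print(collate_filenames([{"filenames": "cat_001.png"}, {"filenames": "dog_042.png"}]))
# {'batch_size': 2, 'filenames': ['cat_001.png', 'dog_042.png']}
print(collate_filenames([{}, {}]))
# {'batch_size': 2}
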
@@ -1105,7 +1120,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                 loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                 loss = loss.mean()
-
+                if args.debug_loss and "filenames" in batch:
+                    for fname in batch["filenames"]:
+                        accelerator.log({"loss_for_" + fname: loss}, step=global_step)
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                 train_loss += avg_loss.item() / args.gradient_accumulation_steps
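
Each filename becomes its own scalar series in the tracker, so a sample whose `loss_for_<fname>` curve stays high is easy to spot. Note that the logged value is the batch-mean loss computed above, so the per-filename curves are exact per-image losses only when the per-device batch size is 1. A minimal sketch of the logging call, assuming an `accelerate` tracker such as tensorboard is configured (project name and values are made up):

from accelerate import Accelerator

# Hypothetical setup: a tensorboard tracker writing under ./logs.
accelerator = Accelerator(log_with="tensorboard", project_dir="logs")
accelerator.init_trackers("debug_loss_demo")

# Fake (filename, loss) pairs; in the training script the value is the
# batch-level scalar loss computed in the step above.
for global_step, (fname, loss) in enumerate([("cat_001.png", 0.12), ("dog_042.png", 0.57)]):
    accelerator.log({"loss_for_" + fname: loss}, step=global_step)

accelerator.end_training()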
