diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md
index 24c71d5c569d..18273746c283 100644
--- a/examples/dreambooth/README_flux.md
+++ b/examples/dreambooth/README_flux.md
@@ -263,9 +263,19 @@ This reduces memory requirements significantly w/o a significant quality loss. N
 ## Training Kontext
 
 [Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We
-provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply this script, too.
+provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply to this script, too.
 
-Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.
+**Important**
+
+> [!NOTE]
+> To make sure you can successfully run the latest version of the Kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
+> To do this, execute the following steps in a new virtual environment:
+> ```
+> git clone https://github.com/huggingface/diffusers
+> cd diffusers
+> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
+> pip install -e .
+> ```
 
 Below is an example training command:
 
@@ -294,6 +304,42 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \
 
 Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not perform as expected.
 
+Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets:
+
+* Condition image
+* Target image
+* Instruction
+
+[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If your dataset follows this format, you can use the command below to launch training:
+
+```bash
+accelerate launch train_dreambooth_lora_flux_kontext.py \
+  --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
+  --output_dir="kontext-i2i" \
+  --dataset_name="kontext-community/relighting" \
+  --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
+  --mixed_precision="bf16" \
+  --resolution=1024 \
+  --train_batch_size=1 \
+  --guidance_scale=1 \
+  --gradient_accumulation_steps=4 \
+  --gradient_checkpointing \
+  --optimizer="adamw" \
+  --use_8bit_adam \
+  --cache_latents \
+  --learning_rate=1e-4 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=200 \
+  --max_train_steps=1000 \
+  --rank=16 \
+  --seed="0"
+```
+
+More generally, when performing I2I fine-tuning, we expect you to:
+
+* Have a dataset structured like [kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting), i.e., containing condition images, target images, and edit instructions
+* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
+
 ### Misc notes
 
 * By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
@@ -307,4 +353,4 @@ To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a
 
 Since Flux Kontext finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 
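To sanity-check that the column flags in the I2I command above match your dataset, a quick inspection with 🤗 Datasets helps. The following is a minimal sketch, not part of the training script; the column names `output`, `file_name`, and `instruction` are simply the ones used by the example command and will differ for other datasets:

```python
# Minimal sketch: inspect an I2I dataset before launching training.
# Column names follow the example command above
# (--image_column="output", --cond_image_column="file_name", --caption_column="instruction").
from datasets import load_dataset

ds = load_dataset("kontext-community/relighting", split="train")
print(ds.column_names)

sample = ds[0]
cond_image = sample["file_name"]     # condition image  -> --cond_image_column
target_image = sample["output"]      # target image     -> --image_column
instruction = sample["instruction"]  # edit instruction -> --caption_column
print(type(cond_image), type(target_image), instruction)
```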
🤗 ## Other notes -Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ \ No newline at end of file +Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 9f97567b06b8..5bd9b8684d42 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -40,7 +40,7 @@ from torch.utils.data import Dataset from torch.utils.data.sampler import BatchSampler from torchvision import transforms -from torchvision.transforms.functional import crop +from torchvision.transforms import functional as TF from tqdm.auto import tqdm from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast @@ -62,11 +62,7 @@ free_memory, parse_buckets_string, ) -from diffusers.utils import ( - check_min_version, - convert_unet_state_dict_to_peft, - is_wandb_available, -) +from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module @@ -186,6 +182,7 @@ def log_validation( ) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) + pipeline_args_cp = pipeline_args.copy() # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None @@ -193,14 +190,16 @@ def log_validation( # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( - pipeline_args["prompt"], prompt_2=pipeline_args["prompt"] - ) + prompt = pipeline_args_cp.pop("prompt") + prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None) images = [] for _ in range(args.num_validation_images): with autocast_ctx: image = pipeline( - prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator + **pipeline_args_cp, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + generator=generator, ).images[0] images.append(image) @@ -310,6 +309,12 @@ def parse_args(input_args=None): "default, the standard Image Dataset maps out 'file_name' " "to 'image'.", ) + parser.add_argument( + "--cond_image_column", + type=str, + default=None, + help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning", + ) parser.add_argument( "--caption_column", type=str, @@ -330,7 +335,6 @@ def parse_args(input_args=None): "--instance_prompt", type=str, default=None, - required=True, help="The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'", ) parser.add_argument( @@ -351,6 +355,12 @@ def parse_args(input_args=None): default=None, help="A prompt that is used during validation to verify that the model is learning.", ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.", + ) parser.add_argument( "--num_validation_images", type=int, @@ -399,7 +409,7 @@ def parse_args(input_args=None): parser.add_argument( "--output_dir", type=str, - default="flux-dreambooth-lora", + default="flux-kontext-lora", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") @@ -716,6 +726,8 @@ def parse_args(input_args=None): raise ValueError("You must specify a data directory for class images.") if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") + if args.cond_image_column is not None: + raise ValueError("Prior preservation isn't supported with I2I training.") else: # logger is not available yet if args.class_data_dir is not None: @@ -723,6 +735,14 @@ def parse_args(input_args=None): if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + if args.cond_image_column is not None: + assert args.image_column is not None + assert args.caption_column is not None + assert args.dataset_name is not None + assert not args.train_text_encoder + if args.validation_prompt is not None: + assert args.validation_image is None and os.path.exists(args.validation_image) + return args @@ -742,6 +762,7 @@ def __init__( repeats=1, center_crop=False, buckets=None, + args=None, ): self.center_crop = center_crop @@ -774,6 +795,10 @@ def __init__( column_names = dataset["train"].column_names # 6. Get the column names for input/target. + if args.cond_image_column is not None and args.cond_image_column not in column_names: + raise ValueError( + f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) if args.image_column is None: image_column = column_names[0] logger.info(f"image column defaulting to {image_column}") @@ -783,7 +808,12 @@ def __init__( raise ValueError( f"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" ) - instance_images = dataset["train"][image_column] + instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))] + cond_images = None + cond_image_column = args.cond_image_column + if cond_image_column is not None: + cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))] + assert len(instance_images) == len(cond_images) if args.caption_column is None: logger.info( @@ -811,14 +841,23 @@ def __init__( self.custom_instance_prompts = None self.instance_images = [] - for img in instance_images: + self.cond_images = [] + for i, img in enumerate(instance_images): self.instance_images.extend(itertools.repeat(img, repeats)) + if args.dataset_name is not None and cond_images is not None: + self.cond_images.extend(itertools.repeat(cond_images[i], repeats)) self.pixel_values = [] - for image in self.instance_images: + self.cond_pixel_values = [] + for i, image in enumerate(self.instance_images): image = exif_transpose(image) if not image.mode == "RGB": image = image.convert("RGB") + dest_image = None + if self.cond_images: + dest_image = exif_transpose(self.cond_images[i]) + if not dest_image.mode == "RGB": + dest_image = dest_image.convert("RGB") width, height = image.size @@ -828,25 +867,16 @@ def __init__( self.size = (target_height, target_width) # based on the bucket assignment, define the transformations - train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size) - train_flip = transforms.RandomHorizontalFlip(p=1.0) - train_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] + image, dest_image = self.paired_transform( + image, + dest_image=dest_image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, ) - image = train_resize(image) - if args.center_crop: - image = train_crop(image) - else: - y1, x1, h, w = train_crop.get_params(image, self.size) - image = crop(image, y1, x1, h, w) - if args.random_flip and random.random() < 0.5: - image = train_flip(image) - image = train_transforms(image) self.pixel_values.append((image, bucket_idx)) + if dest_image is not None: + self.cond_pixel_values.append((dest_image, bucket_idx)) self.num_instance_images = len(self.instance_images) self._length = self.num_instance_images @@ -880,6 +910,9 @@ def __getitem__(self, index): instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] example["instance_images"] = instance_image example["bucket_idx"] = bucket_idx + if self.cond_pixel_values: + dest_image, _ = self.cond_pixel_values[index % self.num_instance_images] + example["cond_images"] = dest_image if self.custom_instance_prompts: caption = self.custom_instance_prompts[index % self.num_instance_images] @@ -902,6 +935,43 @@ def __getitem__(self, index): return example + def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + if dest_image is not None: + dest_image = resize(dest_image) + + # 2. 
Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + if dest_image is not None: + dest_image = crop(dest_image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + if dest_image is not None: + dest_image = TF.crop(dest_image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + if dest_image is not None: + dest_image = TF.hflip(dest_image) + + # 4. ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + if dest_image is not None: + dest_image = normalize(to_tensor(dest_image)) + + return (image, dest_image) if dest_image is not None else (image, None) + def collate_fn(examples, with_prior_preservation=False): pixel_values = [example["instance_images"] for example in examples] @@ -917,6 +987,11 @@ def collate_fn(examples, with_prior_preservation=False): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() batch = {"pixel_values": pixel_values, "prompts": prompts} + if any("cond_images" in example for example in examples): + cond_pixel_values = [example["cond_images"] for example in examples] + cond_pixel_values = torch.stack(cond_pixel_values) + cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float() + batch.update({"cond_pixel_values": cond_pixel_values}) return batch @@ -1318,6 +1393,7 @@ def main(args): "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2", + "proj_mlp", ] # now we will add new LoRA weights the transformer layers @@ -1534,7 +1610,10 @@ def load_model_hook(models, input_dir): buckets=buckets, repeats=args.repeats, center_crop=args.center_crop, + args=args, ) + if args.cond_image_column is not None: + logger.info("I2I fine-tuning enabled.") batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False) train_dataloader = torch.utils.data.DataLoader( train_dataset, @@ -1574,6 +1653,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + text_encoder_one.cpu(), text_encoder_two.cpu() del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two free_memory() @@ -1605,19 +1685,41 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + elif train_dataset.custom_instance_prompts and not args.train_text_encoder: + cached_text_embeddings = [] + for batch in tqdm(train_dataloader, desc="Embedding prompts"): + batch_prompts = batch["prompts"] + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + batch_prompts, text_encoders, tokenizers + ) + cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids)) + + if args.validation_prompt is None: + text_encoder_one.cpu(), text_encoder_two.cpu() + del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two + free_memory() + vae_config_shift_factor = vae.config.shift_factor vae_config_scaling_factor = vae.config.scaling_factor vae_config_block_out_channels = vae.config.block_out_channels + has_image_input = args.cond_image_column is not None if 
args.cache_latents: latents_cache = [] + cond_latents_cache = [] for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( accelerator.device, non_blocking=True, dtype=weight_dtype ) latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if has_image_input: + batch["cond_pixel_values"] = batch["cond_pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist) if args.validation_prompt is None: + vae.cpu() del vae free_memory() @@ -1678,7 +1780,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_name = "dreambooth-flux-dev-lora" + tracker_name = "dreambooth-flux-kontext-lora" accelerator.init_trackers(tracker_name, config=vars(args)) # Train! @@ -1742,6 +1844,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigma = sigma.unsqueeze(-1) return sigma + has_guidance = unwrap_model(transformer).config.guidance_embeds for epoch in range(first_epoch, args.num_train_epochs): transformer.train() if args.train_text_encoder: @@ -1759,9 +1862,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if not args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( - prompts, text_encoders, tokenizers - ) + prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step] else: tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) tokens_two = tokenize_prompt( @@ -1794,16 +1895,29 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.cache_latents: if args.vae_encode_mode == "sample": model_input = latents_cache[step].sample() + if has_image_input: + cond_model_input = cond_latents_cache[step].sample() else: model_input = latents_cache[step].mode() + if has_image_input: + cond_model_input = cond_latents_cache[step].mode() else: pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + if has_image_input: + cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype) if args.vae_encode_mode == "sample": model_input = vae.encode(pixel_values).latent_dist.sample() + if has_image_input: + cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample() else: model_input = vae.encode(pixel_values).latent_dist.mode() + if has_image_input: + cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) + if has_image_input: + cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor + cond_model_input = cond_model_input.to(dtype=weight_dtype) vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) @@ -1814,6 +1928,17 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.device, weight_dtype, ) + if has_image_input: + cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids( + cond_model_input.shape[0], + cond_model_input.shape[2] // 2, + cond_model_input.shape[3] // 2, + accelerator.device, + weight_dtype, + ) + cond_latents_ids[..., 0] = 1 + latent_image_ids = torch.cat([latent_image_ids, 
cond_latents_ids], dim=0) + # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) bsz = model_input.shape[0] @@ -1834,7 +1959,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - packed_noisy_model_input = FluxKontextPipeline._pack_latents( noisy_model_input, batch_size=model_input.shape[0], @@ -1842,13 +1966,22 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): height=model_input.shape[2], width=model_input.shape[3], ) + orig_inp_shape = packed_noisy_model_input.shape + if has_image_input: + packed_cond_input = FluxKontextPipeline._pack_latents( + cond_model_input, + batch_size=cond_model_input.shape[0], + num_channels_latents=cond_model_input.shape[1], + height=cond_model_input.shape[2], + width=cond_model_input.shape[3], + ) + packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1) - # handle guidance - if unwrap_model(transformer).config.guidance_embeds: + # Kontext always has guidance + guidance = None + if has_guidance: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) - else: - guidance = None # Predict the noise residual model_pred = transformer( @@ -1862,6 +1995,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids=latent_image_ids, return_dict=False, )[0] + if has_image_input: + model_pred = model_pred[:, : orig_inp_shape[1]] model_pred = FluxKontextPipeline._unpack_latents( model_pred, height=model_input.shape[2] * vae_scale_factor, @@ -1970,6 +2105,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): torch_dtype=weight_dtype, ) pipeline_args = {"prompt": args.validation_prompt} + if has_image_input and args.validation_image: + pipeline_args.update({"image": load_image(args.validation_image)}) images = log_validation( pipeline=pipeline, args=args, @@ -2030,6 +2167,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): images = [] if args.validation_prompt and args.num_validation_images > 0: pipeline_args = {"prompt": args.validation_prompt} + if has_image_input and args.validation_image: + pipeline_args.update({"image": load_image(args.validation_image)}) images = log_validation( pipeline=pipeline, args=args,
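For reference, the heart of the I2I change in the training script is how the condition image reaches the transformer: its clean latents are packed exactly like the noisy target latents, appended along the sequence dimension, given position ids whose first coordinate is set to 1, and the model prediction is sliced back to the target sequence length before unpacking and computing the loss. The snippet below is a shape-level sketch of that flow using dummy tensors; it assumes a 1024x1024 training resolution and the standard Flux latent layout (16 channels, 8x VAE downsampling) and is illustrative only, not part of the script.

```python
# Shape-level sketch of the I2I latent handling (dummy tensors, illustrative only).
import torch
from diffusers import FluxKontextPipeline

bsz, ch, h, w = 1, 16, 128, 128  # latent shape for a 1024x1024 image: 16 channels, /8 downsampling

noisy_target = torch.randn(bsz, ch, h, w)  # noisy latents of the target image
cond_latents = torch.randn(bsz, ch, h, w)  # clean latents of the condition image

# Pack both into (batch, seq_len, channels * 4) token sequences.
packed_target = FluxKontextPipeline._pack_latents(noisy_target, bsz, ch, h, w)
packed_cond = FluxKontextPipeline._pack_latents(cond_latents, bsz, ch, h, w)
orig_seq_len = packed_target.shape[1]

# Condition tokens are appended after the target tokens along the sequence dimension.
model_input = torch.cat([packed_target, packed_cond], dim=1)

# Position ids: the condition ids get their first coordinate set to 1 to mark them.
target_ids = FluxKontextPipeline._prepare_latent_image_ids(bsz, h // 2, w // 2, "cpu", torch.float32)
cond_ids = FluxKontextPipeline._prepare_latent_image_ids(bsz, h // 2, w // 2, "cpu", torch.float32)
cond_ids[..., 0] = 1
img_ids = torch.cat([target_ids, cond_ids], dim=0)

# After the transformer forward pass, only the target part of the sequence is kept:
# model_pred = model_pred[:, :orig_seq_len]
print(model_input.shape, img_ids.shape, orig_seq_len)  # (1, 8192, 64), (8192, 3), 4096
```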