
Handling mixed precision for dreambooth flux lora training (huggingface#9565)

Handle mixed precision and add unwrap_model.

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
3 people authored Nov 1, 2024
1 parent 7ffbc25 commit 3deed72
Showing 1 changed file with 4 additions and 2 deletions.
examples/dreambooth/train_dreambooth_lora_flux.py
@@ -177,7 +177,7 @@ def log_validation(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+    pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
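
This first hunk drops the explicit dtype cast when moving the validation pipeline to the device. Hard-casting every module to `torch_dtype` would also downcast the fp32 LoRA parameters still being trained; moving without a dtype keeps each module in its current precision and lets autocast cover the low-precision math. A minimal sketch of the idea, assuming bf16 mixed precision and reusing the script's `pipeline`, `accelerator`, and `args` names:

```python
import torch

# Sketch, not the script verbatim: `pipeline`, `accelerator`, and
# `args.validation_prompt` are assumed from the training script.
pipeline = pipeline.to(accelerator.device)  # no dtype: modules keep their own precision
pipeline.set_progress_bar_config(disable=True)

# Run inference under autocast instead of hard-casting the weights
# (assumption: --mixed_precision=bf16).
with torch.autocast(accelerator.device.type, dtype=torch.bfloat16):
    image = pipeline(prompt=args.validation_prompt).images[0]
```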
@@ -1706,7 +1706,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if accelerator.unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
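
The second hunk reads the transformer config through `accelerator.unwrap_model`. Once `accelerator.prepare()` wraps the model (for example in `DistributedDataParallel`), attributes such as `.config` are not proxied by the wrapper, so the guidance check would fail in multi-GPU runs; unwrapping returns the underlying module in both single- and multi-GPU setups. A minimal sketch, reusing the script's names:

```python
import torch

# `accelerator`, `transformer`, `args`, and `model_input` are assumed from
# the training script. unwrap_model() is a no-op for an unwrapped module and
# returns the inner module when it was wrapped by accelerator.prepare().
if accelerator.unwrap_model(transformer).config.guidance_embeds:
    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
    guidance = guidance.expand(model_input.shape[0])  # one value per batch item
else:
    guidance = None
```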
@@ -1819,6 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 # create pipeline
                 if not args.train_text_encoder:
                     text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+                    text_encoder_one.to(weight_dtype)
+                    text_encoder_two.to(weight_dtype)
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
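
The third hunk casts the reloaded text encoders to `weight_dtype`. Because validation no longer casts the whole pipeline (first hunk), frozen text encoders freshly loaded in fp32 would feed fp32 activations into half-precision layers and raise a dtype-mismatch error. A minimal sketch, assuming bf16 training and the script's `load_text_encoders` helper:

```python
import torch

weight_dtype = torch.bfloat16  # assumption: matches --mixed_precision=bf16

# Frozen text encoders are re-loaded for validation in full precision, so
# align them with the rest of the pipeline before building FluxPipeline.
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
```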
