Fix the dtype issue in SD3 DreamBooth training with/without LoRA (huggingface#9419)
* Fix dtype error

* [bugfix] Fixed the issue on sd3 dreambooth training

---------

Co-authored-by: 蒋硕 <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
3 people authored Sep 14, 2024
1 parent 48e3635 commit e2ead7c
Showing 6 changed files with 24 additions and 6 deletions.
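
The change is identical across all six training scripts: log_validation now takes a torch_dtype argument, and the pipeline is cast to that dtype (the training weight_dtype) when it is moved to the accelerator device, so mixed-precision (fp16/bf16) runs no longer fail with a dtype mismatch during validation. A minimal sketch of the pattern the diffs below apply, with the function body abbreviated and names such as weight_dtype and pipeline_args taken from the scripts:

def log_validation(pipeline, args, accelerator, pipeline_args, epoch, torch_dtype, is_final_validation=False):
    # Cast the whole pipeline to the training weight dtype before inference so that
    # fp16/bf16 training no longer hits a dtype mismatch at validation time.
    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
    pipeline.set_progress_bar_config(disable=True)
    # ... run inference with pipeline(**pipeline_args) and log the resulting images ...

# Every call site forwards the training dtype, for example:
#     images = log_validation(pipeline, args, accelerator, pipeline_args, epoch, torch_dtype=weight_dtype)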
5 changes: 4 additions & 1 deletion examples/dreambooth/train_dreambooth_flux.py
@@ -154,13 +154,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
@@ -1717,6 +1718,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         accelerator=accelerator,
         pipeline_args=pipeline_args,
         epoch=epoch,
+        torch_dtype=weight_dtype,
     )
     if not args.train_text_encoder:
         del text_encoder_one, text_encoder_two
@@ -1761,6 +1763,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         pipeline_args=pipeline_args,
         epoch=epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )
 
     if args.push_to_hub:
5 changes: 4 additions & 1 deletion examples/dreambooth/train_dreambooth_lora.py
@@ -122,6 +122,7 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
@@ -141,7 +142,7 @@ def log_validation(
 
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
 
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
@@ -1360,6 +1361,7 @@ def compute_text_embeddings(prompt):
         accelerator,
         pipeline_args,
         epoch,
+        torch_dtype=weight_dtype,
     )
 
     # Save the lora layers
@@ -1402,6 +1404,7 @@ def compute_text_embeddings(prompt):
         pipeline_args,
         epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )
 
     if args.push_to_hub:
5 changes: 4 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_flux.py
@@ -170,13 +170,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
@@ -1785,6 +1786,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         accelerator=accelerator,
         pipeline_args=pipeline_args,
         epoch=epoch,
+        torch_dtype=weight_dtype,
     )
     if not args.train_text_encoder:
         del text_encoder_one, text_encoder_two
@@ -1832,6 +1834,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         pipeline_args=pipeline_args,
         epoch=epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )
 
     if args.push_to_hub:
5 changes: 4 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -179,13 +179,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
@@ -1788,6 +1789,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         accelerator=accelerator,
         pipeline_args=pipeline_args,
         epoch=epoch,
+        torch_dtype=weight_dtype,
     )
     objs = []
     if not args.train_text_encoder:
@@ -1840,6 +1842,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         pipeline_args=pipeline_args,
         epoch=epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )
 
     if args.push_to_hub:
5 changes: 4 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -180,6 +180,7 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
@@ -201,7 +202,7 @@ def log_validation(
 
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
 
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
@@ -1890,6 +1891,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         accelerator,
         pipeline_args,
         epoch,
+        torch_dtype=weight_dtype,
     )
 
     # Save the lora layers
@@ -1955,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         pipeline_args,
         epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )
 
     if args.push_to_hub:
5 changes: 4 additions & 1 deletion examples/dreambooth/train_dreambooth_sd3.py
@@ -157,13 +157,14 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,
     is_final_validation=False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
@@ -1725,6 +1726,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         accelerator=accelerator,
         pipeline_args=pipeline_args,
         epoch=epoch,
+        torch_dtype=weight_dtype,
     )
     if not args.train_text_encoder:
         del text_encoder_one, text_encoder_two, text_encoder_three
@@ -1775,6 +1777,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         pipeline_args=pipeline_args,
         epoch=epoch,
         is_final_validation=True,
+        torch_dtype=weight_dtype,
     )
 
     if args.push_to_hub:
