Commit

fix vae dtype when accelerate config using --mixed_precision="fp16" (huggingface#9601)

* fix vae dtype when accelerate config using --mixed_precision="fp16"

* Add param for upcast vae
xduzhangjiayu authored Oct 7, 2024
1 parent 31010ec commit 2cb383f
Showing 1 changed file with 9 additions and 1 deletion.

examples/controlnet/train_controlnet_sd3.py
@@ -357,6 +357,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
+    parser.add_argument(
+        "--upcast_vae",
+        action="store_true",
+        help="Whether or not to upcast vae to fp32",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -1094,7 +1099,10 @@ def load_model_hook(models, input_dir):
         weight_dtype = torch.bfloat16
 
     # Move vae, transformer and text_encoder to device and cast to weight_dtype
-    vae.to(accelerator.device, dtype=torch.float32)
+    if args.upcast_vae:
+        vae.to(accelerator.device, dtype=torch.float32)
+    else:
+        vae.to(accelerator.device, dtype=weight_dtype)
     transformer.to(accelerator.device, dtype=weight_dtype)
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
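
For context, below is a minimal, self-contained sketch of the dtype selection this commit changes. It assumes weight_dtype is derived from the mixed-precision setting as in the surrounding script; the SimpleNamespace args and nn.Identity module are hypothetical stand-ins for the parsed CLI arguments and the real AutoencoderKL, used here only for illustration.

# Hypothetical invocation of the training script with the new flag (other
# required arguments omitted):
#   accelerate launch --mixed_precision="fp16" examples/controlnet/train_controlnet_sd3.py --upcast_vae ...
import torch
from types import SimpleNamespace

args = SimpleNamespace(upcast_vae=False)  # stand-in for parsed CLI args; True keeps the VAE in fp32
mixed_precision = "fp16"                  # as set via --mixed_precision="fp16"

# weight_dtype follows the mixed-precision setting, mirroring the script.
weight_dtype = torch.float32
if mixed_precision == "fp16":
    weight_dtype = torch.float16
elif mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

vae = torch.nn.Identity()                 # stand-in for the real AutoencoderKL

# The branch added by this commit: upcast to fp32 only when explicitly
# requested, otherwise cast the VAE to weight_dtype (fp16 here).
if args.upcast_vae:
    vae.to(dtype=torch.float32)
else:
    vae.to(dtype=weight_dtype)

Before this commit the VAE was unconditionally kept in fp32; with the change it follows weight_dtype unless --upcast_vae is passed.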
