From d6e4a1a5a4d3217ffbda6a7ba1e51218606309a7 Mon Sep 17 00:00:00 2001
From: Parag Ekbote
Date: Sat, 21 Jun 2025 17:53:09 +0000
Subject: [PATCH 1/4] update the training script.

---
 .../dreambooth/train_dreambooth_lora_hidream.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py
index a1337e8dbaa4..47842e73df3b 100644
--- a/examples/dreambooth/train_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/train_dreambooth_lora_hidream.py
@@ -54,6 +54,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -420,6 +421,13 @@ def parse_args(input_args=None):
 
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
 
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
+
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
@@ -1163,7 +1171,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1180,10 +1188,12 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
+            modules_to_save = {}
 
             for model in models:
                 if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                     model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1194,6 +1204,7 @@ def save_model_hook(models, weights, output_dir):
             HiDreamImagePipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
 
     def load_model_hook(models, input_dir):
@@ -1496,6 +1507,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
+        modules_to_save = {}
         tracker_name = "dreambooth-hidream-lora"
         accelerator.init_trackers(tracker_name, config=vars(args))
 
@@ -1737,6 +1749,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         HiDreamImagePipeline.save_lora_weights(
             save_directory=args.output_dir,

From 0dc1cdd408e43f955d272f9ac3ef8f5f66b3afdf Mon Sep 17 00:00:00 2001
From: Parag Ekbote
Date: Sat, 21 Jun 2025 17:57:53 +0000
Subject: [PATCH 2/4] update the test case file.

---
 .../test_dreambooth_lora_hidream.py | 46 +++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/examples/dreambooth/test_dreambooth_lora_hidream.py b/examples/dreambooth/test_dreambooth_lora_hidream.py
index df4c70e2e86f..5347b1d194bc 100644
--- a/examples/dreambooth/test_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/test_dreambooth_lora_hidream.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import os
 import sys
@@ -20,6 +21,8 @@
 
 import safetensors
 
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
 
 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -175,6 +178,49 @@ def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self):
             {"checkpoint-4", "checkpoint-6"},
         )
 
+    def test_dreambooth_lora_with_metadata(self):
+        # Use a `lora_alpha` that is different from `rank`.
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
+
+
     def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""

From 8d4ad6527a21aeb46213ba884925d225ba22e09b Mon Sep 17 00:00:00 2001
From: Parag Ekbote
Date: Mon, 23 Jun 2025 09:06:44 +0000
Subject: [PATCH 3/4] make style.

---
 examples/dreambooth/test_dreambooth_lora_hidream.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/dreambooth/test_dreambooth_lora_hidream.py b/examples/dreambooth/test_dreambooth_lora_hidream.py
index 5347b1d194bc..ca992a236c2d 100644
--- a/examples/dreambooth/test_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/test_dreambooth_lora_hidream.py
@@ -220,7 +220,6 @@ def test_dreambooth_lora_with_metadata(self):
             loaded_lora_rank = raw["transformer.r"]
             self.assertTrue(loaded_lora_rank == rank)
-
 
     def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""

From ee2c2200d18c57b39b859991eabe264691c91616 Mon Sep 17 00:00:00 2001
From: Parag Ekbote
Date: Mon, 30 Jun 2025 15:40:45 +0000
Subject: [PATCH 4/4] remove unused param.

---
 examples/dreambooth/test_dreambooth_lora_hidream.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/dreambooth/test_dreambooth_lora_hidream.py b/examples/dreambooth/test_dreambooth_lora_hidream.py
index ca992a236c2d..d5f025423069 100644
--- a/examples/dreambooth/test_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/test_dreambooth_lora_hidream.py
@@ -187,7 +187,6 @@ def test_dreambooth_lora_with_metadata(self):
                 {self.script_path}
                 --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                 --instance_data_dir {self.instance_data_dir}
-                --instance_prompt {self.instance_prompt}
                 --resolution 64
                 --train_batch_size 1
                 --gradient_accumulation_steps 1