Skip to content

Add --lora_alpha and metadata handling to train_dreambooth_lora_sd3.py #11783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
import tempfile

import safetensors

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
Expand Down Expand Up @@ -207,6 +209,46 @@ def test_dreambooth_lora_layer(self):
starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)

def test_dreambooth_lora_sd3_with_metadata(self):
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--lora_alpha={lora_alpha}
--rank={rank}
--checkpointing_steps=2
--max_sequence_length 166
""".split()

test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)

state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))

# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}

metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)

loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)


def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
Expand Down
15 changes: 14 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
_set_state_dict_into_text_encoder,
cast_training_params,
compute_density_for_timestep_sampling,
Expand Down Expand Up @@ -321,6 +322,12 @@ def parse_args(input_args=None):
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",
)
parser.add_argument(
"--instance_prompt",
type=str,
Expand Down Expand Up @@ -1266,7 +1273,7 @@ def main(args):
# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
Expand Down Expand Up @@ -1295,13 +1302,15 @@ def save_model_hook(models, weights, output_dir):
transformer_lora_layers_to_save = None
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
modules_to_save = {}

for model in models:
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
model = unwrap_model(model)
if args.upcast_before_saving:
model = model.to(torch.float32)
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
elif args.train_text_encoder and isinstance(
unwrap_model(model), type(unwrap_model(text_encoder_one))
): # or text_encoder_two
Expand All @@ -1324,6 +1333,7 @@ def save_model_hook(models, weights, output_dir):
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
**_collate_lora_metadata(modules_to_save),
)

def load_model_hook(models, input_dir):
Expand Down Expand Up @@ -1925,10 +1935,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
modules_to_save = {}
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
modules_to_save["transformer"] = transformer
transformer_lora_layers = get_peft_model_state_dict(transformer)

if args.train_text_encoder:
Expand All @@ -1945,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
transformer_lora_layers=transformer_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
**_collate_lora_metadata(modules_to_save),
)

# Final inference
Expand Down