Skip to content

Commit

Permalink
Merge branch 'main' of github.com:huggingface/diffusers
Browse files Browse the repository at this point in the history
  • Loading branch information
tolgacangoz committed Aug 12, 2024
2 parents 5041d40 + 413ca29 commit 6b15dc5
Show file tree
Hide file tree
Showing 10 changed files with 574 additions and 108 deletions.
8 changes: 4 additions & 4 deletions examples/dreambooth/README_flux.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced
>
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](documentation/quickstart/FLUX.md)
> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)

> [!NOTE]
Expand Down Expand Up @@ -96,7 +96,7 @@ accelerate launch train_dreambooth_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="fp16" \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
Expand Down Expand Up @@ -140,7 +140,7 @@ accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="fp16" \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
Expand Down Expand Up @@ -175,7 +175,7 @@ accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="fp16" \
--mixed_precision="bf16" \
--train_text_encoder\
--instance_prompt="a photo of sks dog" \
--resolution=512 \
Expand Down
8 changes: 8 additions & 0 deletions examples/dreambooth/requirements_flux.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
accelerate>=0.31.0
torchvision
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft>=0.11.1
sentencepiece
203 changes: 203 additions & 0 deletions examples/dreambooth/test_dreambooth_flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import shutil
import sys
import tempfile

from diffusers import DiffusionPipeline, FluxTransformer2DModel


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothFlux(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
script_path = "examples/dreambooth/train_dreambooth_flux.py"

def test_dreambooth(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

def test_dreambooth_checkpointing(self):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4

initial_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--seed=0
""".split()

run_command(self._launch_args + initial_run_args)

# check can run the original fully trained output pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir)
pipe(self.instance_prompt, num_inference_steps=1)

# check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))

# check can run an intermediate checkpoint
transformer = FluxTransformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
pipe(self.instance_prompt, num_inference_steps=1)

# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))

# Run training script for 7 total steps resuming from checkpoint 4

resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--seed=0
""".split()

run_command(self._launch_args + resume_run_args)

# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir)
pipe(self.instance_prompt, num_inference_steps=1)

# check old checkpoints do not exist
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))

# check new checkpoints exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))

def test_dreambooth_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()

run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)

def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
""".split()

run_command(self._launch_args + test_args)

self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4"},
)

resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + resume_run_args)

self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
Loading

0 comments on commit 6b15dc5

Please sign in to comment.