fix bugs when using deepspeed in sdxl (huggingface#7917)
fix bugs when using deepspeed

Co-authored-by: mhh001 <[email protected]>
HelloWorldBeginner and mhh001 authored May 11, 2024
1 parent be4afa0 commit 0267c52
Showing 1 changed file with 6 additions and 4 deletions.
examples/text_to_image/train_text_to_image_sdxl.py
```diff
@@ -35,7 +35,7 @@
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
 from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
```
```diff
@@ -742,7 +742,8 @@ def save_model_hook(models, weights, output_dir):
                     model.save_pretrained(os.path.join(output_dir, "unet"))
 
                     # make sure to pop weight so that corresponding model is not saved again
-                    weights.pop()
+                    if weights:
+                        weights.pop()
 
         def load_model_hook(models, input_dir):
             if args.use_ema:
```
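The guard matters because, when the script runs under DeepSpeed, Accelerate's checkpointing can invoke the registered save hook with an empty `weights` list (DeepSpeed serializes model state itself), so the unconditional `weights.pop()` raised an `IndexError`. A minimal sketch of the guarded hook pattern, with the empty-list behavior stated as an assumption based on this fix rather than a documented guarantee:

```python
import os

def save_model_hook(models, weights, output_dir):
    for model in models:
        model.save_pretrained(os.path.join(output_dir, "unet"))
        # Assumption: under DeepSpeed the hook may receive an empty `weights`
        # list, so only pop when there is something to pop.
        if weights:
            weights.pop()
```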
```diff
@@ -914,7 +915,7 @@ def preprocess_train(examples):
         train_dataset_with_vae = train_dataset.map(
             compute_vae_encodings_fn,
             batched=True,
-            batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
+            batch_size=args.train_batch_size,
             new_fingerprint=new_fingerprint_for_vae,
         )
         precomputed_dataset = concatenate_datasets(
```
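For context, `batch_size` in `datasets.map(..., batched=True)` only controls how many rows each call of the mapped function receives, so the old `num_processes * gradient_accumulation_steps` multiplier inflated the per-call VAE encoding batch, presumably risking out-of-memory during precomputation (the commit message does not state the exact failure mode). A tiny self-contained illustration of the `batch_size` semantics:

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(8))})

def double(batch):
    # With batched=True, `batch` holds up to `batch_size` rows per call.
    return {"y": [v * 2 for v in batch["x"]]}

ds = ds.map(double, batched=True, batch_size=4)  # two calls of 4 rows each
print(ds["y"])  # [0, 2, 4, 6, 8, 10, 12, 14]
```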
```diff
@@ -1160,7 +1161,8 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 accelerator.log({"train_loss": train_loss}, step=global_step)
                 train_loss = 0.0
 
-                if accelerator.is_main_process:
+                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
```
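The new condition reflects that DeepSpeed (ZeRO) shards model and optimizer state across ranks, so `accelerator.save_state` has to run on every process; gating it on `is_main_process` alone would drop the other ranks' shards and can hang the collective save. A hedged sketch of the pattern in isolation (the checkpoint path is illustrative):

```python
from accelerate import Accelerator
from accelerate.utils import DistributedType

accelerator = Accelerator()

# Every rank must participate in save_state under DeepSpeed; on other
# backends only the main process needs to write the checkpoint.
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
    accelerator.save_state("checkpoint-100")  # illustrative output directory
```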