From 7872c0698b80939274f7e61279edf5aff547c584 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 23 Jun 2025 13:32:18 +0530
Subject: [PATCH 1/2] add an offload utility that can be used as a context manager.

---
 .github/workflows/pr_tests_gpu.yml   |  1 +
 .../train_dreambooth_lora_hidream.py | 62 ++++++++-----------
 src/diffusers/training_utils.py      | 29 +++++++++
 3 files changed, 56 insertions(+), 36 deletions(-)

diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml
index 87d51773888e..48d9e7553885 100644
--- a/.github/workflows/pr_tests_gpu.yml
+++ b/.github/workflows/pr_tests_gpu.yml
@@ -13,6 +13,7 @@ on:
       - "src/diffusers/loaders/peft.py"
       - "tests/pipelines/test_pipelines_common.py"
       - "tests/models/test_modeling_common.py"
+      - "examples/**/*.py"
   workflow_dispatch:
 
 concurrency:

diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py
index a1337e8dbaa4..9b38151415c8 100644
--- a/examples/dreambooth/train_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/train_dreambooth_lora_hidream.py
@@ -58,6 +58,7 @@
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
     free_memory,
+    offload_models,
 )
 from diffusers.utils import (
     check_min_version,
@@ -1364,43 +1365,34 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            instance_prompt_hidden_states_t5,
-            instance_prompt_hidden_states_llama3,
-            instance_pooled_prompt_embeds,
-            _,
-            _,
-            _,
-        ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                instance_prompt_hidden_states_t5,
+                instance_prompt_hidden_states_llama3,
+                instance_pooled_prompt_embeds,
+                _,
+                _,
+                _,
+            ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
 
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
-            compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
-        )
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
+                compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
+            )
 
     validation_embeddings = {}
     if args.validation_prompt is not None:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            validation_embeddings["prompt_embeds_t5"],
-            validation_embeddings["prompt_embeds_llama3"],
-            validation_embeddings["pooled_prompt_embeds"],
-            validation_embeddings["negative_prompt_embeds_t5"],
-            validation_embeddings["negative_prompt_embeds_llama3"],
-            validation_embeddings["negative_pooled_prompt_embeds"],
-        ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                validation_embeddings["prompt_embeds_t5"],
+                validation_embeddings["prompt_embeds_llama3"],
+                validation_embeddings["pooled_prompt_embeds"],
+                validation_embeddings["negative_prompt_embeds_t5"],
+                validation_embeddings["negative_prompt_embeds_llama3"],
+                validation_embeddings["negative_pooled_prompt_embeds"],
+            ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1581,12 +1573,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
                 else:
-                    if args.offload:
-                        vae = vae.to(accelerator.device)
-                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    with offload_models(vae, device=accelerator.device, offload=args.offload):
+                        pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                         model_input = vae.encode(pixel_values).latent_dist.sample()
-                    if args.offload:
-                        vae = vae.to("cpu")
+
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index bc30411d8726..a92b48090bba 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -3,12 +3,14 @@
 import gc
 import math
 import random
+from contextlib import contextmanager
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from .models import UNet2DConditionModel
+from .pipelines import DiffusionPipeline
 from .schedulers import SchedulerMixin
 from .utils import (
     convert_state_dict_to_diffusers,
@@ -316,6 +318,33 @@ def free_memory():
         torch.xpu.empty_cache()
 
 
+@contextmanager
+def offload_models(*modules: torch.nn.Module | DiffusionPipeline, device: str | torch.device, offload: bool = True):
+    """
+    Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
+    device on exit.
+    """
+    if offload:
+        is_model = not any(isinstance(m, DiffusionPipeline) for m in modules)
+        # record where each module was
+        if is_model:
+            original_devices = [next(m.parameters()).device for m in modules]
+        else:
+            assert len(modules) == 1
+            original_devices = [modules[0].device]
+        # move to target device
+        for m in modules:
+            m.to(device)
+
+    try:
+        yield
+    finally:
+        if offload:
+            # move back to original devices
+            for m, orig_dev in zip(modules, original_devices):
+                m.to(orig_dev)
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """

From 6beace41e8cafdd51c25f4489f9e656a5332fd3b Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 23 Jun 2025 13:36:17 +0530
Subject: [PATCH 2/2] update

---
 src/diffusers/training_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index a92b48090bba..e0c9aebf65bd 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -319,7 +319,9 @@ def free_memory():
 
 
 @contextmanager
-def offload_models(*modules: torch.nn.Module | DiffusionPipeline, device: str | torch.device, offload: bool = True):
+def offload_models(
+    *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True
+):
     """
     Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
     device on exit.
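
For reference, a minimal sketch of how the new `offload_models` context manager behaves after this patch. It assumes a CUDA device is available; the `TinyTextEncoder` module and the hard-coded `"cuda"` device are illustrative stand-ins, not part of the diff:

```python
import torch

from diffusers.training_utils import offload_models


class TinyTextEncoder(torch.nn.Module):
    """Hypothetical stand-in for a large model that normally stays on CPU between uses."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 32)

    def forward(self, x):
        return self.proj(x)


encoder = TinyTextEncoder()  # parameters start on CPU
inputs = torch.randn(4, 16, device="cuda")

# With offload=True the module is moved to `device` on enter and restored to its
# original device (CPU here) on exit. With offload=False the context manager is a
# no-op, which is why the training script can pass offload=args.offload and keep a
# single code path for both modes.
with offload_models(encoder, device="cuda", offload=True):
    embeddings = encoder(inputs)

assert next(encoder.parameters()).device.type == "cpu"  # back on CPU after the block
```

The same pattern is what the training script now uses for the text-encoding pipeline and the VAE, replacing the repeated `if args.offload: ... .to(...)` bookkeeping.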