Cleanup ControlnetXS (huggingface#7701)
* update

* update
DN6 authored Apr 19, 2024
1 parent 90250d9 commit 3cfe187
Showing 3 changed files with 61 additions and 35 deletions.
35 changes: 29 additions & 6 deletions src/diffusers/models/controlnet_xs.py
@@ -22,7 +22,14 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, is_torch_version, logging
from ..utils.torch_utils import apply_freeu
from .attention_processor import Attention, AttentionProcessor
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .controlnet import ControlNetConditioningEmbedding
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
@@ -869,7 +876,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:

return processors

# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -904,7 +911,23 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
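For orientation, a minimal usage sketch (not part of this commit): `set_attn_processor` accepts either a single processor applied everywhere or a dict keyed by the names exposed through the `attn_processors` property. `AttnProcessor2_0` is just an example processor class from diffusers, and the helper name is made up.

```python
from diffusers.models.attention_processor import AttnProcessor2_0


def install_sdpa_processors(model) -> None:
    """Install the same processor on every attention layer of `model`.

    `model` is assumed to be an instance of the model defined in this file
    (or any diffusers model exposing `set_attn_processor`/`attn_processors`).
    """
    model.set_attn_processor(AttnProcessor2_0())
    # Per-layer variant: reuse the keys from the `attn_processors` property.
    # model.set_attn_processor({name: AttnProcessor2_0() for name in model.attn_processors})
```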

# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)

self.set_attn_processor(processor)
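The method picks `AttnAddedKVProcessor` when every installed processor belongs to the added-KV family, plain `AttnProcessor` when they are all cross-attention processors, and raises on a mix. A hedged sketch of calling it to undo a custom setup (the helper and the diagnostic print are illustrative):

```python
def reset_attention_processors(model) -> None:
    """Report what is currently installed, then restore the library defaults
    via the `set_default_attn_processor` method shown above."""
    installed = sorted({type(proc).__name__ for proc in model.attn_processors.values()})
    print(f"replacing processors: {installed}")
    model.set_default_attn_processor()
```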

# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -929,7 +952,7 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)

# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
@@ -938,7 +961,7 @@ def disable_freeu(self):
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
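FreeU rescales the decoder's backbone features (`b1`, `b2`) and skip-connection features (`s1`, `s2`) at inference time without retraining. A hedged usage sketch; the numbers are values commonly suggested for SD 1.5-style UNets (SDXL usually takes different ones), and `pipe`/`control_image` are assumed to exist.

```python
def generate_with_freeu(pipe, prompt: str, control_image):
    """Enable FreeU on the wrapped UNet for one generation, then switch it off."""
    pipe.unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)  # assumed SD 1.5-style values
    try:
        image = pipe(prompt, image=control_image).images[0]
    finally:
        pipe.unet.disable_freeu()
    return image
```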

# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -962,7 +985,7 @@ def fuse_qkv_projections(self):
if isinstance(module, Attention):
module.fuse_projections(fuse=True)

# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
@@ -41,7 +41,6 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -462,7 +461,6 @@ def check_inputs(
prompt,
prompt_2,
image,
callback_steps,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
@@ -474,12 +472,6 @@
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)

if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
@@ -749,7 +741,6 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -878,30 +869,13 @@ def __call__(
returned, otherwise a `tuple` is returned containing the output images.
"""

callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)

if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)

unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
image,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
@@ -1089,9 +1063,6 @@ def __call__(
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

# manually for max memory savings
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
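With this cleanup the deprecated `callback`/`callback_steps` keywords are removed from the pipeline entirely; per-step hooks now go through `callback_on_step_end`, and the tensors it receives are selected with `callback_on_step_end_tensor_inputs`. A hedged sketch of the replacement pattern, following the convention used across diffusers pipelines (verify the exact call signature against the pipeline source):

```python
def log_latents(pipe, step: int, timestep, callback_kwargs: dict) -> dict:
    """Runs once per denoising step; only tensors named in
    `callback_on_step_end_tensor_inputs` appear in `callback_kwargs`."""
    latents = callback_kwargs["latents"]
    print(f"step {step} (t={timestep}): latents shape {tuple(latents.shape)}")
    return callback_kwargs  # returned values are written back into the denoising loop


# Hypothetical call, assuming `pipe` and `control_image` exist:
# result = pipe(
#     prompt="a photo of a house",
#     image=control_image,
#     callback_on_step_end=log_latents,
#     callback_on_step_end_tensor_inputs=["latents"],
# )
```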
32 changes: 32 additions & 0 deletions tests/pipelines/controlnet_xs/test_controlnetxs.py
@@ -69,6 +69,13 @@
enable_full_determinism()


def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()

return tensor


# Will be run via run_test_in_subprocess
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
error = None
@@ -299,6 +306,31 @@ def test_multi_vae(self):

assert out_vae_np.shape == out_np.shape

@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)

pipe.to("cpu")
# pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the device from pipe.components
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cpu" for device in model_devices))

output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)

pipe.to("cuda")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cuda" for device in model_devices))

output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
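As the in-test comment notes, the pipeline assembles a combined UNet/ControlNet-XS model internally, so the modules passed to the constructor are not necessarily the ones that move when `pipe.to(...)` is called; the test therefore inspects `pipe.components`. The same check, extracted as a small helper for reuse (a sketch, not part of the commit):

```python
def component_devices(pipe) -> dict:
    """Return {component_name: device_type} for every pipeline component that
    carries a `device` attribute, mirroring the assertions in `test_to_device`."""
    return {
        name: component.device.type
        for name, component in pipe.components.items()
        if hasattr(component, "device")
    }
```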


@slow
@require_torch_gpu
