Skip to content

Commit

Permalink
PAG variant for AnimateDiff (huggingface#8789)
Browse files Browse the repository at this point in the history
* add animatediff pag pipeline

* remove unnecessary print

* make fix-copies

* fix ip-adapter bug

* update docs

* add fast tests and fix bugs

* update

* update

* address review comments

* update ip adapter single test expected slice

* implement test_from_pipe_consistent_config; fix expected slice values

* LoraLoaderMixin->StableDiffusionLoraLoaderMixin; add latest freeinit test
  • Loading branch information
a-r-r-o-w authored Aug 1, 2024
1 parent ea1b4ea commit 05b706c
Show file tree
Hide file tree
Showing 8 changed files with 1,395 additions and 14 deletions.
5 changes: 5 additions & 0 deletions docs/source/en/api/pipelines/pag.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ The abstract from the paper is:

*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*

## AnimateDiffPAGPipeline
[[autodoc]] AnimateDiffPAGPipeline
- all
- __call__

## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffControlNetPipeline",
"AnimateDiffPAGPipeline",
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
Expand Down Expand Up @@ -654,6 +655,7 @@
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffControlNetPipeline,
AnimateDiffPAGPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
)
_import_structure["pag"].extend(
[
"AnimateDiffPAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
Expand Down Expand Up @@ -527,6 +528,7 @@
)
from .musicldm import MusicLDMPipeline
from .pag import (
AnimateDiffPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/pag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
Expand All @@ -40,6 +41,7 @@
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
Expand Down
45 changes: 31 additions & 14 deletions src/diffusers/pipelines/pag/pag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _check_input_pag_applied_layer(layer):
Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats:
"{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type`
can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be
in the format of "attentions_{j}".
in the format of "attentions_{j}". `motion_modules_index` should be in the format of "motion_modules_{j}".
"""

layer_splits = layer.split(".")
Expand All @@ -52,8 +52,11 @@ def _check_input_pag_applied_layer(layer):
raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'")

if len(layer_splits) == 3:
if not layer_splits[2].startswith("attentions_"):
raise ValueError(f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_'")
layer_2 = layer_splits[2]
if not layer_2.startswith("attentions_") and not layer_2.startswith("motion_modules_"):
raise ValueError(
f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_' or 'motion_modules_'"
)

def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
r"""
Expand All @@ -72,33 +75,46 @@ def is_self_attn(module_name):

def get_block_type(module_name):
r"""
Get the block type from the module name. can be "down", "mid", "up".
Get the block type from the module name. Can be "down", "mid", "up".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down"
# down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "down"
return module_name.split(".")[0].split("_")[0]

def get_block_index(module_name):
r"""
Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g.
Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g.
mid_block) and index is omitted from the name, it will be "block_0".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0"
if "attentions" in module_name.split(".")[1]:
module_name_splits = module_name.split(".")
block_index = module_name_splits[1]
if "attentions" in block_index or "motion_modules" in block_index:
return "block_0"
else:
return f"block_{module_name.split('.')[1]}"
return f"block_{block_index}"

def get_attn_index(module_name):
r"""
Get the attention index from the module name. can be "attentions_0", "attentions_1", ...
Get the attention index from the module name. Can be "attentions_0", "attentions_1", "motion_modules_0",
"motion_modules_1", ...
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
if "attentions" in module_name.split(".")[2]:
return f"attentions_{module_name.split('.')[3]}"
elif "attentions" in module_name.split(".")[1]:
return f"attentions_{module_name.split('.')[2]}"
# down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
# mid_block.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
module_name_split = module_name.split(".")
mid_name = module_name_split[1]
down_name = module_name_split[2]
if "attentions" in down_name:
return f"attentions_{module_name_split[3]}"
if "attentions" in mid_name:
return f"attentions_{module_name_split[2]}"
if "motion_modules" in down_name:
return f"motion_modules_{module_name_split[3]}"
if "motion_modules" in mid_name:
return f"motion_modules_{module_name_split[2]}"

for pag_layer_input in pag_applied_layers:
# for each PAG layer input, we find corresponding self-attention layers in the unet model
Expand All @@ -114,7 +130,7 @@ def get_attn_index(module_name):
target_modules.append(module)

elif len(pag_layer_input_splits) == 2:
# when the layer inpput contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
# when the layer input contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
for name, module in self.unet.named_modules():
Expand All @@ -126,7 +142,8 @@ def get_attn_index(module_name):
target_modules.append(module)

elif len(pag_layer_input_splits) == 3:
# when the layer input contains block_type, block_index and attention_index. e.g. "down.blocks_1.attentions_1"
# when the layer input contains block_type, block_index and attention_index.
# e.g. "down.block_1.attentions_1" or "down.block_1.motion_modules_1"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
attn_index = pag_layer_input_splits[2]
Expand Down
Loading

0 comments on commit 05b706c

Please sign in to comment.