PAG variant for HunyuanDiT, PAG refactor (huggingface#8936)
* copy hunyuandit pipeline

* pag variant of hunyuan dit

* add tests

* update docs

* make style

* make fix-copies

* Update src/diffusers/pipelines/pag/pag_utils.py

* remove incorrect copied from

* remove pag hunyuan attn procs to resolve conflicts

* add pag attn procs again

* new implementation for pag_utils

* revert pag changes

* add pag refactor back; update pixart sigma

* update pixart pag tests

* apply suggestions from review

Co-Authored-By: [email protected]

* make style

* update docs, fix tests

* fix tests

* fix test_components_function since list not accepted as valid __init__ param

* apply patch to fix broken tests

Co-Authored-By: Sayak Paul <[email protected]>

* make style

* fix hunyuan tests

---------

Co-authored-by: Sayak Paul <[email protected]>
a-r-r-o-w and sayakpaul authored Aug 5, 2024
1 parent e1d508a commit b7058d1
Showing 16 changed files with 1,737 additions and 354 deletions.
20 changes: 19 additions & 1 deletion docs/source/en/api/pipelines/pag.md
@@ -20,11 +20,29 @@ The abstract from the paper is:

*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*

PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers.

- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor`
- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor`
- Partial identifier as a RegEx: `down_blocks.2`, or `attn1`
- List of identifiers (can be a combination of strings and RegEx): `["blocks.1", "blocks.(14|20)", r"down_blocks\.(2|3)"]`
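
To make the identifier forms above concrete, here is a minimal sketch of how such patterns could select layers. It assumes identifiers are tested against module names with `re.search`; the module names and the `matching_layers` helper are hypothetical, and the library's own validation checks are not reproduced.

```python
import re

# Hypothetical module names, standing in for the names of a transformer's submodules.
module_names = [
    "blocks.1.attn1",
    "blocks.14.attn1",
    "blocks.20.attn1",
    "blocks.21.attn2",
]


def matching_layers(patterns, names):
    """Return the names matched by a single identifier or a list of identifiers."""
    if isinstance(patterns, str):
        patterns = [patterns]
    return [name for name in names if any(re.search(p, name) for p in patterns)]


# An alternation selects exactly the listed block indices.
selected = matching_layers("blocks.(14|20)", module_names)

# A bare partial identifier is a prefix match under plain re.search,
# so "blocks.1" also selects "blocks.14" in this naive sketch.
loose = matching_layers("blocks.1", module_names)
```

In this sketch the second call picks up `blocks.14.attn1` as well as `blocks.1.attn1`, which is the kind of surprise a partial identifier can cause when no extra validation is applied.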

<Tip warning={true}>

Because layer identifiers are matched as regular expressions, an imprecise pattern can select layers you did not intend. The recommended way to use PAG is to specify layers as `blocks.{layer_index}` or `blocks.({layer_index_1}|{layer_index_2}|...)`. Other forms work, but they may bypass our basic validation checks and give you unexpected results.

</Tip>
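
As a usage sketch for the recommended identifier form, the snippet below builds the keyword arguments for loading a PAG pipeline through `AutoPipelineForText2Image`. The model ID, the chosen layer indices, and the `load_hunyuan_pag_pipeline` helper are assumptions made for illustration, not values prescribed by this commit.

```python
# Keyword arguments that ask an auto pipeline for its PAG variant.
PAG_KWARGS = {
    "enable_pag": True,
    # Illustrative layer choice, written in the recommended `blocks.(i|j|...)` form.
    "pag_applied_layers": ["blocks.(16|17|18|19)"],
}


def load_hunyuan_pag_pipeline(model_id="Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers"):
    """Load a HunyuanDiT checkpoint with PAG enabled (assumed model ID)."""
    # Imports are deferred so the sketch can be read without the heavy dependencies installed.
    import torch
    from diffusers import AutoPipelineForText2Image

    return AutoPipelineForText2Image.from_pretrained(
        model_id, torch_dtype=torch.float16, **PAG_KWARGS
    )
```

A call such as `load_hunyuan_pag_pipeline().to("cuda")("an astronaut riding a horse", pag_scale=3.0).images[0]` would then generate an image, with `pag_scale` controlling the strength of the guidance.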

## AnimateDiffPAGPipeline
[[autodoc]] AnimateDiffPAGPipeline
- all
- __call__

## HunyuanDiTPAGPipeline
[[autodoc]] HunyuanDiTPAGPipeline
- all
- __call__

## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
@@ -59,4 +77,4 @@ The abstract from the paper is:
## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
- all
- __call__
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -252,6 +252,7 @@
"CycleDiffusionPipeline",
"FluxPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
@@ -675,6 +676,7 @@
CycleDiffusionPipeline,
FluxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
249 changes: 249 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -2147,6 +2147,253 @@ def __call__(
return hidden_states


class PAGHunyuanAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding to the query and key vectors.
    This variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or later."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # chunk: split the batch into the original half and the perturbed half
        hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)

        # 1. Original Path
        batch_size, sequence_length, _ = (
            hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states_org)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states_org
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states_org = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states_org = hidden_states_org.to(query.dtype)

        # linear proj
        hidden_states_org = attn.to_out[0](hidden_states_org)
        # dropout
        hidden_states_org = attn.to_out[1](hidden_states_org)

        if input_ndim == 4:
            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # 2. Perturbed Path
        if attn.group_norm is not None:
            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

        # only the value projection is applied: no query, key, or attention weights are computed
        hidden_states_ptb = attn.to_v(hidden_states_ptb)
        hidden_states_ptb = hidden_states_ptb.to(query.dtype)

        # linear proj
        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
        # dropout
        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

        if input_ndim == 4:
            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # cat: restore the original batch order (original, perturbed)
        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
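
The perturbed path above applies only `to_v` and the output projection. One way to read this, following the paper abstract quoted earlier, is that replacing the self-attention map with an identity matrix makes the attention output equal to the value tensor, so the query and key never need to be computed. A minimal pure-Python sketch of that identity, with hypothetical helper names and toy numbers (not the library code):

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def identity(n):
    """Build an n x n identity matrix."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


# Toy value tensor for one head: sequence length 3, head dimension 2.
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# Attention output is (attention map) @ (values). With an identity attention map,
# every token attends only to itself, so the output equals the values unchanged.
perturbed_output = matmul(identity(3), values)
```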


class PAGCFGHunyuanAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding to the query and key vectors.
    This variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377) together
    with classifier-free guidance.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or later."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # chunk: split the batch into unconditional, conditional, and perturbed thirds;
        # the first two go through the original path together
        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

        # 1. Original Path
        batch_size, sequence_length, _ = (
            hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states_org)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states_org
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states_org = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states_org = hidden_states_org.to(query.dtype)

        # linear proj
        hidden_states_org = attn.to_out[0](hidden_states_org)
        # dropout
        hidden_states_org = attn.to_out[1](hidden_states_org)

        if input_ndim == 4:
            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # 2. Perturbed Path
        if attn.group_norm is not None:
            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

        # only the value projection is applied: no query, key, or attention weights are computed
        hidden_states_ptb = attn.to_v(hidden_states_ptb)
        hidden_states_ptb = hidden_states_ptb.to(query.dtype)

        # linear proj
        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
        # dropout
        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

        if input_ndim == 4:
            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # cat: restore the original batch order (unconditional, conditional, perturbed)
        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
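
The batch bookkeeping in the CFG variant can be sketched without tensors. The example below uses plain lists and a hypothetical `chunk` helper that mimics an even three-way split; it assumes the batch size is divisible by three and that the batch is ordered as unconditional, conditional, perturbed, as the variable names in the processor suggest.

```python
def chunk(batch, n):
    """Split a list into n equal consecutive parts (assumes len(batch) is divisible by n)."""
    size = len(batch) // n
    return [batch[i * size:(i + 1) * size] for i in range(n)]


# Toy batch with two samples per branch, in the order the processor appears to expect.
batch = ["uncond_0", "uncond_1", "cond_0", "cond_1", "ptb_0", "ptb_1"]

uncond, cond, ptb = chunk(batch, 3)

# Unconditional and conditional samples share the regular attention path...
original_path = uncond + cond
# ...while the last third goes through the perturbed (value-only) path.
perturbed_path = ptb

# Concatenating the two results restores the original batch order.
recombined = original_path + perturbed_path
```

Because the order is preserved end to end, the pipeline can later split the model output back into its three guidance branches by position alone.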


class LuminaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -3468,4 +3715,6 @@ def __init__(self):
CustomDiffusionAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
]
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -145,6 +145,7 @@
_import_structure["pag"].extend(
[
"AnimateDiffPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
@@ -532,6 +533,7 @@
from .musicldm import MusicLDMPipeline
from .pag import (
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/auto_pipeline.py
@@ -50,6 +50,7 @@
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
@@ -85,6 +86,7 @@
("stable-diffusion-3", StableDiffusion3Pipeline),
("if", IFPipeline),
("hunyuan", HunyuanDiTPipeline),
("hunyuan-pag", HunyuanDiTPAGPipeline),
("kandinsky", KandinskyCombinedPipeline),
("kandinsky22", KandinskyV22CombinedPipeline),
("kandinsky3", Kandinsky3Pipeline),
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/pag/__init__.py
@@ -24,6 +24,7 @@
else:
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
@@ -41,6 +42,7 @@
else:
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline