add PAG support for Stable Diffusion 3 (huggingface#8861)
add pag sd3


---------

Co-authored-by: HyoungwonCho <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: crepejung00 <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Aryan <[email protected]>
Co-authored-by: Aryan <[email protected]>
7 people authored Aug 6, 2024
1 parent 325a5de commit 926daa3
Showing 9 changed files with 1,629 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/source/en/api/pipelines/pag.md
@@ -74,6 +74,12 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- __call__


## StableDiffusion3PAGPipeline
[[autodoc]] StableDiffusion3PAGPipeline
- all
- __call__


## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
- all
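The newly documented StableDiffusion3PAGPipeline can be exercised with a minimal text-to-image sketch like the one below; the checkpoint id and the pag_applied_layers / pag_scale values are illustrative assumptions rather than defaults taken from this commit.

import torch
from diffusers import StableDiffusion3PAGPipeline

# Load SD3 with perturbed-attention guidance applied to selected transformer blocks.
# The checkpoint id and layer identifiers are illustrative; since RegEx matching of
# layer identifiers is supported, adjust the pattern to the blocks you want to perturb.
pipe = StableDiffusion3PAGPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    pag_applied_layers=["blocks.13"],
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    guidance_scale=7.0,
    pag_scale=3.0,  # strength of the perturbed-attention guidance term
).images[0]
image.save("sd3_pag.png")
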
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -308,6 +308,7 @@
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline",
"StableDiffusionAttendAndExcitePipeline",
@@ -741,6 +742,7 @@
StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
StableDiffusion3PAGPipeline,
StableDiffusion3Pipeline,
StableDiffusionAdapterPipeline,
StableDiffusionAttendAndExcitePipeline,
320 changes: 320 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -1106,6 +1106,326 @@ def __call__(
return hidden_states, encoder_hidden_states


class PAGJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
) -> torch.FloatTensor:
residual = hidden_states

input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

# store the length of image patch sequences to create a mask that prevents interaction between patches
# similar to making the self-attention map an identity matrix
identity_block_size = hidden_states.shape[1]

# chunk
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)

################## original path ##################
batch_size = encoder_hidden_states_org.shape[0]

# `sample` projections.
query_org = attn.to_q(hidden_states_org)
key_org = attn.to_k(hidden_states_org)
value_org = attn.to_v(hidden_states_org)

# `context` projections.
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)

# attention
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)

inner_dim = key_org.shape[-1]
head_dim = inner_dim // attn.heads
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

hidden_states_org = F.scaled_dot_product_attention(
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query_org.dtype)

# Split the attention outputs.
hidden_states_org, encoder_hidden_states_org = (
hidden_states_org[:, : residual.shape[1]],
hidden_states_org[:, residual.shape[1] :],
)

# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if not attn.context_pre_only:
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)

if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

################## perturbed path ##################

batch_size = encoder_hidden_states_ptb.shape[0]

# `sample` projections.
query_ptb = attn.to_q(hidden_states_ptb)
key_ptb = attn.to_k(hidden_states_ptb)
value_ptb = attn.to_v(hidden_states_ptb)

# `context` projections.
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)

# attention
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)

inner_dim = key_ptb.shape[-1]
head_dim = inner_dim // attn.heads
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# create a full mask with all entries set to 0
seq_len = query_ptb.size(2)
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)

# set the attention value between image patches to -inf
full_mask[:identity_block_size, :identity_block_size] = float("-inf")

# set the diagonal of the attention value between image patches to 0
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

# expand the mask to match the attention weights shape
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions

hidden_states_ptb = F.scaled_dot_product_attention(
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
)
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)

# split the attention outputs.
hidden_states_ptb, encoder_hidden_states_ptb = (
hidden_states_ptb[:, : residual.shape[1]],
hidden_states_ptb[:, residual.shape[1] :],
)

# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if not attn.context_pre_only:
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)

if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

################ concat ###############
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])

return hidden_states, encoder_hidden_states
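
A standalone sketch of the masking trick in the perturbed path above: image-patch queries are blocked from attending to other image patches (-inf), each patch keeps its own diagonal entry, and text tokens remain fully visible. The tensor sizes are made up for illustration and are not part of the commit.

import torch
import torch.nn.functional as F

identity_block_size, text_len = 4, 3  # 4 image-patch tokens followed by 3 text tokens
seq_len = identity_block_size + text_len

# Same mask construction as in PAGJointAttnProcessor2_0's perturbed path.
full_mask = torch.zeros(seq_len, seq_len)
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

q = k = v = torch.randn(1, 1, seq_len, 8)  # (batch, heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=full_mask[None, None])
print(out.shape)  # torch.Size([1, 1, 7, 8])
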


class PAGCFGJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states

input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

# patch embeddings width * height (corresponds to the self-attention map width or height)
identity_block_size = hidden_states.shape[1]

# chunk
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

(
encoder_hidden_states_uncond,
encoder_hidden_states_org,
encoder_hidden_states_ptb,
) = encoder_hidden_states.chunk(3)
encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])

################## original path ##################
batch_size = encoder_hidden_states_org.shape[0]

# `sample` projections.
query_org = attn.to_q(hidden_states_org)
key_org = attn.to_k(hidden_states_org)
value_org = attn.to_v(hidden_states_org)

# `context` projections.
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)

# attention
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)

inner_dim = key_org.shape[-1]
head_dim = inner_dim // attn.heads
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

hidden_states_org = F.scaled_dot_product_attention(
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query_org.dtype)

# Split the attention outputs.
hidden_states_org, encoder_hidden_states_org = (
hidden_states_org[:, : residual.shape[1]],
hidden_states_org[:, residual.shape[1] :],
)

# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if not attn.context_pre_only:
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)

if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

################## perturbed path ##################

batch_size = encoder_hidden_states_ptb.shape[0]

# `sample` projections.
query_ptb = attn.to_q(hidden_states_ptb)
key_ptb = attn.to_k(hidden_states_ptb)
value_ptb = attn.to_v(hidden_states_ptb)

# `context` projections.
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)

# attention
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)

inner_dim = key_ptb.shape[-1]
head_dim = inner_dim // attn.heads
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# create a full mask with all entries set to 0
seq_len = query_ptb.size(2)
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)

# set the attention value between image patches to -inf
full_mask[:identity_block_size, :identity_block_size] = float("-inf")

# set the diagonal of the attention value between image patches to 0
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

# expand the mask to match the attention weights shape
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions

hidden_states_ptb = F.scaled_dot_product_attention(
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
)
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)

# split the attention outputs.
hidden_states_ptb, encoder_hidden_states_ptb = (
hidden_states_ptb[:, : residual.shape[1]],
hidden_states_ptb[:, residual.shape[1] :],
)

# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if not attn.context_pre_only:
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)

if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

################ concat ###############
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])

return hidden_states, encoder_hidden_states


class FusedJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

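For orientation, a hedged sketch of the batch layout the two new processors rely on: PAGJointAttnProcessor2_0 chunks its input into [original, perturbed] halves, while PAGCFGJointAttnProcessor2_0 chunks it into [unconditional, conditional, perturbed] thirds. The helper below is hypothetical and only illustrates that bookkeeping; the actual duplication is performed by the pipeline before each denoising step.

import torch

def prepare_pag_latents(latents: torch.Tensor, do_classifier_free_guidance: bool) -> torch.Tensor:
    # Hypothetical helper: duplicate the latents so the chunk() calls inside the
    # processors recover the paths they expect.
    repeats = 3 if do_classifier_free_guidance else 2  # [uncond, cond, ptb] vs [cond, ptb]
    return torch.cat([latents] * repeats, dim=0)

latents = torch.randn(1, 16, 64, 64)  # illustrative SD3-like latent shape
print(prepare_pag_latents(latents, do_classifier_free_guidance=True).shape)   # torch.Size([3, 16, 64, 64])
print(prepare_pag_latents(latents, do_classifier_free_guidance=False).shape)  # torch.Size([2, 16, 64, 64])
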
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -147,6 +147,7 @@
[
"AnimateDiffPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
@@ -540,6 +541,7 @@
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/auto_pipeline.py
@@ -52,6 +52,7 @@
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
@@ -84,6 +85,7 @@
("stable-diffusion", StableDiffusionPipeline),
("stable-diffusion-xl", StableDiffusionXLPipeline),
("stable-diffusion-3", StableDiffusion3Pipeline),
("stable-diffusion-3-pag", StableDiffusion3PAGPipeline),
("if", IFPipeline),
("hunyuan", HunyuanDiTPipeline),
("hunyuan-pag", HunyuanDiTPAGPipeline),
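
The new "stable-diffusion-3-pag" entry lets the auto pipeline resolve an SD3 checkpoint to StableDiffusion3PAGPipeline when PAG is requested. A minimal sketch, assuming the enable_pag flag exposed by AutoPipelineForText2Image; the checkpoint id is illustrative.

import torch
from diffusers import AutoPipelineForText2Image

# enable_pag=True switches the lookup to the "-pag" variant of the matched
# text-to-image pipeline, i.e. StableDiffusion3PAGPipeline for an SD3 checkpoint.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint id
    enable_pag=True,
    torch_dtype=torch.float16,
)
print(type(pipe).__name__)  # StableDiffusion3PAGPipeline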