Skip to content

Commit

Permalink
[refactor] apply qk norm in attention processors (huggingface#9071)
Browse files Browse the repository at this point in the history
* apply qk norm in attention processors

* revert attention processor

* qk-norm in only attention proc 2.0 and fused variant
  • Loading branch information
a-r-r-o-w authored Aug 4, 2024
1 parent 4f0d01d commit 2b76099
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,6 +1785,11 @@ def __call__(
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
Expand Down Expand Up @@ -2314,6 +2319,11 @@ def __call__(
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
Expand Down

0 comments on commit 2b76099

Please sign in to comment.