Skip to content

Commit

Permalink
Add Learned PE selection for Auraflow (huggingface#9182)
Browse files Browse the repository at this point in the history
* add pe

* Update src/diffusers/models/transformers/auraflow_transformer_2d.py

Co-authored-by: Sayak Paul <[email protected]>

* Update src/diffusers/models/transformers/auraflow_transformer_2d.py

* beauty

* retrigger ci.

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
cloneofsimo and sayakpaul authored Aug 15, 2024
1 parent 0c1e63b commit 1a92bc0
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion src/diffusers/models/transformers/auraflow_transformer_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ def __init__(
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size

def pe_selection_index_based_on_dim(self, h, w):
# select subset of positional embedding based on H, W, where H, W is size of latent
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
# because original input are in flattened format, we have to flatten this 2d grid as well.
h_p, w_p = h // self.patch_size, w // self.patch_size
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
starth = h_max // 2 - h_p // 2
endh = starth + h_p
startw = w_max // 2 - w_p // 2
endw = startw + w_p
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
return original_pe_indexes.flatten()

def forward(self, latent):
batch_size, num_channels, height, width = latent.size()
latent = latent.view(
Expand All @@ -80,7 +95,8 @@ def forward(self, latent):
)
latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
latent = self.proj(latent)
return latent + self.pos_embed
pe_index = self.pe_selection_index_based_on_dim(height, width)
return latent + self.pos_embed[:, pe_index]


# Taken from the original Aura flow inference code.
Expand Down

0 comments on commit 1a92bc0

Please sign in to comment.