def pe_selection_index_based_on_dim(self, h, w):
    """Return flat indices of the centred positional-embedding window for a latent.

    The learned positional embedding ``self.pos_embed`` stores one entry per
    position along dim 1. Those entries are viewed as a square 2-D grid with
    ``int(self.pos_embed_max_size ** 0.5)`` rows and columns, and the centred
    ``(h // patch_size) x (w // patch_size)`` sub-grid is selected. Because the
    patchified latent is a flattened sequence, the selected 2-D window is
    flattened (row-major) as well.

    Args:
        h: Height of the latent (before patchification).
        w: Width of the latent (before patchification).

    Returns:
        A 1-D integer tensor of length ``(h // patch_size) * (w // patch_size)``
        holding indices into dim 1 of ``self.pos_embed``, created on the same
        device as ``self.pos_embed``.

    Raises:
        ValueError: If the patched latent grid is larger than the positional
            embedding grid in either dimension.
    """
    h_p, w_p = h // self.patch_size, w // self.patch_size
    h_max = w_max = int(self.pos_embed_max_size**0.5)

    # Without this guard a too-large latent yields a negative start offset and
    # the slice below silently returns a truncated (wrong) set of indices.
    if h_p > h_max or w_p > w_max:
        raise ValueError(
            f"Latent of size ({h}, {w}) needs a {h_p}x{w_p} positional embedding grid, "
            f"but only {h_max}x{w_max} positions are available."
        )

    # Top-left corner of the window centred inside the full PE grid.
    starth = h_max // 2 - h_p // 2
    startw = w_max // 2 - w_p // 2

    # Build the index grid on the embedding's device so the index tensor lives
    # alongside the table it indexes. `view` also asserts that the table really
    # holds h_max * w_max positions.
    pe_grid = torch.arange(self.pos_embed.shape[1], device=self.pos_embed.device).view(h_max, w_max)
    return pe_grid[starth : starth + h_p, startw : startw + w_p].flatten()