Skip to content

Commit

Permalink
Merge branch 'main' into fix-wy-onediffx-multi-lora-bad-quality
Browse files Browse the repository at this point in the history
  • Loading branch information
marigoold authored Jul 12, 2024
2 parents 325cbb4 + f498be2 commit 0f0f102
Showing 1 changed file with 58 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,58 @@


class TemporalTransformer3DModel_OF(TemporalTransformer3DModel_OF_CLS):
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options=None):
def get_cameractrl_effect(self, hidden_states: torch.Tensor) :
# if no raw camera_Ctrl, return None
if self.raw_cameractrl_effect is None:
return 1.0
# if raw_cameractrl is not a Tensor, return it (should be a float)
if type(self.raw_cameractrl_effect) != torch.Tensor:
return self.raw_cameractrl_effect
shape = hidden_states.shape
batch, channel, height, width = shape
# if temp_cameractrl already calculated, return it
if self.temp_cameractrl_effect != None:
# check if hidden_states batch matches
if batch == self.prev_cameractrl_hidden_states_batch:
if self.sub_idxs is not None:
return self.temp_cameractrl_effect[:, self.sub_idxs, :]
return self.temp_cameractrl_effect
# if does not match, reset cached temp_cameractrl and recalculate it
del self.temp_cameractrl_effect
self.temp_cameractrl_effect = None
# otherwise, calculate temp_cameractrl
self.prev_cameractrl_hidden_states_batch = batch
mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width))
mask = repeat_to_batch_size(mask, self.full_length)
# if mask not the same amount length as full length, make it match
if self.full_length != mask.shape[0]:
mask = broadcast_image_to(mask, self.full_length, 1)
# reshape mask to attention K shape (h*w, latent_count, 1)
batch, channel, height, width = mask.shape
# first, perform same operations as on hidden_states,
# turning (b, c, h, w) -> (b, h*w, c)
mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel)
# then, make it the same shape as attention's k, (h*w, b, c)
mask = mask.permute(1, 0, 2)
# make masks match the expected length of h*w
batched_number = shape[0] // self.video_length
if batched_number > 1:
mask = torch.cat([mask] * batched_number, dim=0)
# cache mask and set to proper device
self.temp_cameractrl_effect = mask
# move temp_cameractrl to proper dtype + device
self.temp_cameractrl_effect = self.temp_cameractrl_effect.to(dtype=hidden_states.dtype, device=hidden_states.device)
# return subset of masks, if needed
if self.sub_idxs is not None:
return self.temp_cameractrl_effect[:, self.sub_idxs, :]
return self.temp_cameractrl_effect


def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options=None, mm_kwargs: dict[str]=None):
batch, channel, height, width = hidden_states.shape
residual = hidden_states
cameractrl_effect = self.get_cameractrl_effect(hidden_states)

scale_mask = self.get_scale_mask(hidden_states)
# add some casts for fp8 purposes - does not affect speed otherwise
hidden_states = self.norm(hidden_states).to(hidden_states.dtype)
Expand All @@ -41,7 +90,9 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
attention_mask=attention_mask,
video_length=self.video_length,
scale_mask=scale_mask,
view_options=view_options
cameractrl_effect=cameractrl_effect,
view_options=view_options,
mm_kwargs=mm_kwargs
)

# output
Expand All @@ -67,6 +118,8 @@ def forward(
attention_mask=None,
video_length=None,
scale_mask=None,
cameractrl_effect= 1.0,
mm_kwargs: dict[str]={},
):
if self.attention_mode != "Temporal":
raise NotImplementedError
Expand All @@ -89,6 +142,9 @@ def forward(
if encoder_hidden_states is not None
else encoder_hidden_states
)
if self.camera_feature_enabled and self.qkv_merge is not None and mm_kwargs is not None and "camera_feature" in mm_kwargs:
camera_feature: torch.Tensor = mm_kwargs["camera_feature"]
hidden_states = (self.qkv_merge(hidden_states + camera_feature) + hidden_states) * cameractrl_effect + hidden_states * (1. - cameractrl_effect)

# hidden_states = super().forward(
# hidden_states,
Expand Down

0 comments on commit 0f0f102

Please sign in to comment.