[LoRA] pop the LoRA scale so that it doesn't get propagated to the weeds (huggingface#7338)

* pop scale from the top-level unet instead of getting it.

* improve readability.

* Apply suggestions from code review

Co-authored-by: YiYi Xu <[email protected]>

* fix a little bit.

---------

Co-authored-by: YiYi Xu <[email protected]>
sayakpaul and yiyixuxu authored Mar 19, 2024
1 parent 85f9d92 commit ce9825b
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions src/diffusers/models/unets/unet_2d_condition.py
@@ -1081,25 +1081,15 @@ def forward(
A tuple of tensors that if specified are added to the residuals of down unet blocks.
mid_block_additional_residual: (`torch.Tensor`, *optional*):
A tensor that if specified is added to the residual of the middle unet block.
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
example from ControlNet side model(s)
mid_block_additional_residual (`torch.Tensor`, *optional*):
additional residual to be added to UNet mid block output, for example from ControlNet side model
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns:
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -1185,7 +1175,13 @@ def forward(
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

# 3. down
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
if cross_attention_kwargs is not None:
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
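
For context, here is a minimal standalone sketch (illustrative only, not diffusers internals; the helper name `forward_sketch` and the example kwargs are made up) of why the change uses `pop` rather than `get`: popping consumes the `scale` key at the top level, so it is never forwarded to the inner blocks and their attention processors, which would otherwise emit deprecation warnings for the stray `scale` argument.

def forward_sketch(cross_attention_kwargs=None):
    # Pop `scale` so downstream blocks never receive it; default to 1.0 when absent.
    if cross_attention_kwargs is not None:
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    return lora_scale, cross_attention_kwargs

# With .get(), "scale" would remain in the dict and be passed along to inner blocks;
# with .pop() it is consumed here and the remaining kwargs are forwarded without it.
scale, remaining = forward_sketch({"scale": 0.5})
assert scale == 0.5 and remaining == {}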
