Skip to content

Commit

Permalink
[LoRA] fix: animate diff lora stuff. (huggingface#8995)
Browse files Browse the repository at this point in the history
* fix: animate diff lora stuff.

* fix scaling function for UNetMotionModel

* emoty
  • Loading branch information
sayakpaul authored Jul 30, 2024
1 parent f240a93 commit 8c4856c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

_SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
}

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/models/unets/unet_motion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch.utils.checkpoint

from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import logging
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
Expand Down Expand Up @@ -231,7 +231,7 @@ def forward(self, sample):
pass


class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
r"""
A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
sample shaped output.
Expand Down

0 comments on commit 8c4856c

Please sign in to comment.