diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 5625f9755b19..fd6c639a7cdf 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -30,6 +30,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = { "UNet2DConditionModel": _maybe_expand_lora_scales, + "UNetMotionModel": _maybe_expand_lora_scales, "SD3Transformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 196f947d599b..c8ea0ecc3feb 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -19,7 +19,7 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config -from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import logging from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -231,7 +231,7 @@ def forward(self, sample): pass -class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): +class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): r""" A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output.