From f6df22447c79f8ba268a44067e90367e4e809f2e Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 7 Aug 2024 11:12:30 +0530
Subject: [PATCH] [feat] allow sparsectrl to be loaded from single file (#9073)

* allow sparsectrl to be loaded with single file

* update

---------

Co-authored-by: Dhruv Nair
---
 src/diffusers/loaders/single_file_model.py      |  3 +++
 src/diffusers/loaders/single_file_utils.py      | 14 ++++++++++++--
 src/diffusers/models/controlnet_sparsectrl.py   |  5 ++++-
 src/diffusers/models/unets/unet_motion_model.py |  4 ++++
 4 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 23d0b0ab2e7d..3fe1abfbead5 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -75,6 +75,9 @@
     "MotionAdapter": {
         "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
     },
+    "SparseControlNetModel": {
+        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
+    },
     "FluxTransformer2DModel": {
         "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 0dce9d5c7aff..9c2a2cbf2942 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -74,9 +74,11 @@
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
     "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
-    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
+    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
+    "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
+    "animatediff_rgb": "controlnet_cond_embedding.weight",
     "flux": "double_blocks.0.img_attn.norm.key_norm.scale",
 }
 
@@ -111,6 +113,8 @@
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
     "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
+    "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
+    "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
 }
@@ -494,7 +498,13 @@ def infer_diffusers_model_type(checkpoint):
         model_type = "sd3"
 
     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
-        if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
+        if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
+            model_type = "animatediff_scribble"
+
+        elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
+            model_type = "animatediff_rgb"
+
+        elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
             model_type = "animatediff_v2"
 
         elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py
index cb577e33c670..e91551c70953 100644
--- a/src/diffusers/models/controlnet_sparsectrl.py
+++ b/src/diffusers/models/controlnet_sparsectrl.py
@@ -20,6 +20,7 @@
 from torch.nn import functional as F
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalModelMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -92,7 +93,7 @@ def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
         return embedding
 
 
-class SparseControlNetModel(ModelMixin, ConfigMixin):
+class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """
     A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
     Models](https://arxiv.org/abs/2311.16933).
@@ -314,6 +315,7 @@ def __init__(
                     temporal_num_attention_heads=motion_num_attention_heads[i],
                     temporal_max_seq_length=motion_max_seq_length,
                     temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+                    temporal_double_self_attention=False,
                 )
             elif down_block_type == "DownBlockMotion":
                 down_block = DownBlockMotion(
@@ -331,6 +333,7 @@ def __init__(
                     temporal_num_attention_heads=motion_num_attention_heads[i],
                     temporal_max_seq_length=motion_max_seq_length,
                     temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+                    temporal_double_self_attention=False,
                 )
             else:
                 raise ValueError(
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 7a9b7f2d5afe..73c9c70c4a11 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -233,6 +233,7 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_max_seq_length: int = 32,
         temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        temporal_double_self_attention: bool = True,
     ):
         super().__init__()
         resnets = []
@@ -282,6 +283,7 @@ def __init__(
                     positional_embeddings="sinusoidal",
                     num_positional_embeddings=temporal_max_seq_length,
                     attention_head_dim=out_channels // temporal_num_attention_heads[i],
+                    double_self_attention=temporal_double_self_attention,
                 )
             )
 
@@ -385,6 +387,7 @@ def __init__(
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
         temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        temporal_double_self_attention: bool = True,
     ):
         super().__init__()
         resnets = []
@@ -466,6 +469,7 @@ def __init__(
                     positional_embeddings="sinusoidal",
                     num_positional_embeddings=temporal_max_seq_length,
                     attention_head_dim=out_channels // temporal_num_attention_heads,
+                    double_self_attention=temporal_double_self_attention,
                 )
             )
 
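
A minimal usage sketch of what this patch enables, not part of the diff itself: with FromOriginalModelMixin added to SparseControlNetModel, a SparseCtrl checkpoint stored in its original single-file layout can be loaded through from_single_file. The checkpoint filename below is an assumed local path used only for illustration.

# Usage sketch (assumption: a local SparseCtrl scribble checkpoint in the original
# single-file layout; the filename is illustrative, not taken from the patch).
import torch
from diffusers import SparseControlNetModel

ckpt_path = "v3_sd15_sparsectrl_scribble.ckpt"  # hypothetical local checkpoint file
controlnet = SparseControlNetModel.from_single_file(ckpt_path, torch_dtype=torch.float16)
print(type(controlnet).__name__, controlnet.dtype)

The key type detection added to infer_diffusers_model_type distinguishes the scribble variant (which has controlnet_cond_embedding.conv_in.weight) from the RGB variant (controlnet_cond_embedding.weight), so the correct default diffusers config is picked automatically during single-file loading.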