Skip to content

Commit

Permalink
[feat] allow sparsectrl to be loaded from single file (huggingface#9073)
Browse files: browse the repository at this point in the history
* allow sparsectrl to be loaded from a single file

* update

---------

Co-authored-by: Dhruv Nair <[email protected]>
  • Loading branch information
a-r-r-o-w and DN6 authored Aug 7, 2024
1 parent 9b5180c commit f6df224
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@
"MotionAdapter": {
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
},
"SparseControlNetModel": {
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
},
"FluxTransformer2DModel": {
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
Expand Down
14 changes: 12 additions & 2 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight",
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight",
"flux": "double_blocks.0.img_attn.norm.key_norm.scale",
}

Expand Down Expand Up @@ -111,6 +113,8 @@
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
}
Expand Down Expand Up @@ -494,7 +498,13 @@ def infer_diffusers_model_type(checkpoint):
model_type = "sd3"

elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
model_type = "animatediff_scribble"

elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
model_type = "animatediff_rgb"

elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
model_type = "animatediff_v2"

elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
Expand Down
5 changes: 4 additions & 1 deletion src/diffusers/models/controlnet_sparsectrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch.nn import functional as F

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalModelMixin
from ..utils import BaseOutput, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
Expand Down Expand Up @@ -92,7 +93,7 @@ def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
return embedding


class SparseControlNetModel(ModelMixin, ConfigMixin):
class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
Models](https://arxiv.org/abs/2311.16933).
Expand Down Expand Up @@ -314,6 +315,7 @@ def __init__(
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
temporal_double_self_attention=False,
)
elif down_block_type == "DownBlockMotion":
down_block = DownBlockMotion(
Expand All @@ -331,6 +333,7 @@ def __init__(
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
temporal_double_self_attention=False,
)
else:
raise ValueError(
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/models/unets/unet_motion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __init__(
temporal_cross_attention_dim: Optional[int] = None,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
temporal_double_self_attention: bool = True,
):
super().__init__()
resnets = []
Expand Down Expand Up @@ -282,6 +283,7 @@ def __init__(
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads[i],
double_self_attention=temporal_double_self_attention,
)
)

Expand Down Expand Up @@ -385,6 +387,7 @@ def __init__(
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
temporal_double_self_attention: bool = True,
):
super().__init__()
resnets = []
Expand Down Expand Up @@ -466,6 +469,7 @@ def __init__(
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
double_self_attention=temporal_double_self_attention,
)
)

Expand Down

0 comments on commit f6df224

Please sign in to comment.