From f28a8c257afe8eeb16b4deb973c6b1829f6aea59 Mon Sep 17 00:00:00 2001 From: captainzz <73270275+xduzhangjiayu@users.noreply.github.com> Date: Tue, 10 Sep 2024 01:51:48 +0800 Subject: [PATCH] fix from_transformer() with extra conditioning channels (#9364) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix from_transformer() with extra conditioning channels * style fix --------- Co-authored-by: YiYi Xu Co-authored-by: Álvaro Somoza --- src/diffusers/models/controlnet_sd3.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index 502ee68a4d38..f19571dafb18 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -242,9 +242,12 @@ def _set_gradient_checkpointing(self, module, value=False): module.gradient_checkpointing = value @classmethod - def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True): + def from_transformer( + cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True + ): config = transformer.config config["num_layers"] = num_layers or config.num_layers + config["extra_conditioning_channels"] = num_extra_conditioning_channels controlnet = cls(**config) if load_weights_from_transformer: