diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index 502ee68a4d38..f19571dafb18 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -242,9 +242,12 @@ def _set_gradient_checkpointing(self, module, value=False): module.gradient_checkpointing = value @classmethod - def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True): + def from_transformer( + cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True + ): config = transformer.config config["num_layers"] = num_layers or config.num_layers + config["extra_conditioning_channels"] = num_extra_conditioning_channels controlnet = cls(**config) if load_weights_from_transformer: