
StableDiffusionXLControlNetUnionInpaintPipeline requires a float for controlnet_conditioning_scale & control_guidance_end #11828

Open
@ewwwgiddings

Describe the bug

Summary

StableDiffusionXLControlNetUnionInpaintPipeline requires a float for controlnet_conditioning_scale & control_guidance_end when they should also accept a List[float].

After PR #10723 the txt-to-img Union pipeline pipeline_controlnet_union_sd_xl.py accepts a list/tuple so each active ControlNet branch can have its own conditioning scale. The in-paint counterpart pipeline_controlnet_union_inpaint_sd_xl.py still contains the old check:

elif isinstance(self.controlnet, ControlNetUnionModel):
    if not isinstance(controlnet_conditioning_scale, float):
        raise TypeError(
            "For single controlnet: `controlnet_conditioning_scale` must be type `float`."
        )

Passing a list raises: TypeError: For single controlnet: `controlnet_conditioning_scale` must be type `float`.
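
For illustration, a relaxed check could accept either a single float or one float per control image. The sketch below is standalone and is not the diffusers implementation or the change made in PR #10723; the helper name `check_union_conditioning_scale` and its `num_control_images` parameter are hypothetical.

```python
from typing import List, Sequence, Union

Number = Union[int, float]


def _is_number(value: object) -> bool:
    # bool is a subclass of int in Python, so exclude it explicitly.
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def check_union_conditioning_scale(
    scale: Union[Number, Sequence[Number]],
    num_control_images: int,
) -> List[float]:
    """Hypothetical helper (not part of diffusers).

    Accepts a single number or a list/tuple with one number per control
    image and returns a list with one float per control image.
    """
    if _is_number(scale):
        # A scalar is broadcast to every control image.
        return [float(scale)] * num_control_images

    if isinstance(scale, (list, tuple)):
        if not all(_is_number(s) for s in scale):
            raise TypeError("every entry of `controlnet_conditioning_scale` must be a float")
        if len(scale) != num_control_images:
            raise ValueError(
                f"expected {num_control_images} scales, got {len(scale)}"
            )
        return [float(s) for s in scale]

    raise TypeError("`controlnet_conditioning_scale` must be a float or a list of floats")
```

With a check of this shape, `controlnet_conditioning_scale=[1.0]` from the reproduction would pass validation for a single control image instead of raising.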

A comment on that PR said list support would be added to the in-paint pipeline as well: #10723 (comment), in reply to #10723 (comment)

I don't see any mention of the control_guidance_end behavior in any PRs, so it may have been missed.
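
The issue does not spell out how a list-valued control_guidance_end should behave, so the following is only a sketch of one plausible reading: each control branch gets its own start/end fraction of the denoising schedule, and scalars are broadcast to all branches. The function names `normalize_guidance_window` and `branch_active` and the `num_branches` parameter are hypothetical and do not come from diffusers.

```python
from typing import List, Sequence, Tuple, Union

FloatOrSeq = Union[float, Sequence[float]]


def _as_list(value: FloatOrSeq, num_branches: int) -> List[float]:
    # Broadcast a scalar to every branch; copy a list/tuple as floats.
    if isinstance(value, (list, tuple)):
        return [float(v) for v in value]
    return [float(value)] * num_branches


def normalize_guidance_window(
    start: FloatOrSeq,
    end: FloatOrSeq,
    num_branches: int,
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: return per-branch (starts, ends) lists."""
    starts = _as_list(start, num_branches)
    ends = _as_list(end, num_branches)
    if len(starts) != num_branches or len(ends) != num_branches:
        raise ValueError("need one start and one end value per control branch")
    for s, e in zip(starts, ends):
        if not 0.0 <= s < e <= 1.0:
            raise ValueError(f"invalid guidance window: start={s}, end={e}")
    return starts, ends


def branch_active(step_index: int, num_steps: int, start: float, end: float) -> bool:
    """True if the step lies fully inside the branch's [start, end] window."""
    return step_index / num_steps >= start and (step_index + 1) / num_steps <= end
```

Under this reading, `control_guidance_end=[1.0]` with a scalar start of 0.0 keeps the single branch active for every step.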

Expected Behaviour

Both controlnet_conditioning_scale and control_guidance_end should accept a List[float] in the in-paint Union pipeline, as described in the summary above.

Reproduction

from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
from diffusers.utils import load_image
import torch
import numpy as np
from PIL import Image

prompt = "A cat"
# download an image
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo.png"
).resize((1024, 1024))
mask = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
).resize((1024, 1024))
# initialize the models and pipeline
controlnet = ControlNetUnionModel.from_pretrained(
    "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetUnionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload()
controlnet_img = image.copy()
controlnet_img_np = np.array(controlnet_img)
mask_np = np.array(mask)
controlnet_img_np[mask_np > 0] = 0
controlnet_img = Image.fromarray(controlnet_img_np)
# generate image
image = pipe(
    prompt,
    image=image,
    mask_image=mask,
    control_image=[controlnet_img],
    control_mode=[7],
    controlnet_conditioning_scale=[1.0],  # list value triggers the TypeError
    control_guidance_end=[1.0],
).images[0]
image.save("inpaint.png")
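
Until lists are accepted, a possible stopgap when only one control mode is active is to collapse each single-element list to a plain float before calling the pipeline. This assumes the float path works as the error message implies, which has not been verified here; the helper name `collapse_single` is hypothetical.

```python
from typing import Sequence, Union


def collapse_single(value: Union[float, Sequence[float]]) -> float:
    """Hypothetical helper: turn a one-element list/tuple into a float."""
    if isinstance(value, (list, tuple)):
        if len(value) != 1:
            raise ValueError("only a single-element list can be collapsed to a float")
        return float(value[0])
    return float(value)


# Keyword arguments to substitute for the list-valued ones in the call above.
call_kwargs = {
    "controlnet_conditioning_scale": collapse_single([1.0]),
    "control_guidance_end": collapse_single([1.0]),
}
```

These could then be passed as `pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7], **call_kwargs)`. This does not help when several control modes need different scales, which is the case the issue is about.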

Logs

Traceback (most recent call last):
  File "source\background_reference_cnet.py", line 146, in <module>
    BackgroundGeneration("images/inputs/", "images/outputs/", "test.jpg", "test.png", "RunDiffusion/Juggernaut-XL-v9")
  File "source\background_reference_cnet.py", line 113, in BackgroundGeneration
    base_result = pipeline(
                  ^^^^^^^^^
  File "source\.venv\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "source\.venv\Lib\site-packages\diffusers\pipelines\controlnet\pipeline_controlnet_union_inpaint_sd_xl.py", line 1372, in __call__
    self.check_inputs(
  File "source\.venv\Lib\site-packages\diffusers\pipelines\controlnet\pipeline_controlnet_union_inpaint_sd_xl.py", line 770, in check_inputs
    raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
TypeError: For single controlnet: `controlnet_conditioning_scale` must be type `float`.

System Info

  • 🤗 Diffusers version: 0.32.2
  • Platform: Windows-11-10.0.26100-SP0
  • Running on Google Colab?: No
  • Python version: 3.12.10
  • PyTorch version (GPU?): 2.5.1+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.33.1
  • Transformers version: 4.42.3
  • Accelerate version: 1.4.0
  • PEFT version: 0.9.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA RTX 3500 Ada Generation Laptop GPU, 12282 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@yiyixuxu @sayakpaul
