Skip to content

Commit

Permalink
Added Support of Xlabs controlnet to FluxControlNetInpaintPipeline (h…
Browse files Browse the repository at this point in the history
…uggingface#9770)

* added xlabs support
  • Loading branch information
SahilCarterr authored Oct 25, 2024
1 parent 73b59f5 commit 298ab6e
Showing 1 changed file with 34 additions and 27 deletions.
61 changes: 34 additions & 27 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,19 +932,22 @@ def __call__(
)
height, width = control_image.shape[-2:]

# vae encode
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# xlab controlnet has a input_hint_block and instantx controlnet does not
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
if self.controlnet.input_hint_block is None:
# vae encode
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)

# set control mode
if control_mode is not None:
Expand All @@ -954,7 +957,9 @@ def __call__(
elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []

for control_image_ in control_image:
# xlab controlnet has a input_hint_block and instantx controlnet does not
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
for i, control_image_ in enumerate(control_image):
control_image_ = self.prepare_image(
image=control_image_,
width=width,
Expand All @@ -966,19 +971,20 @@ def __call__(
)
height, width = control_image_.shape[-2:]

# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
if self.controlnet.nets[0].input_hint_block is None:
# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)

control_images.append(control_image_)

Expand Down Expand Up @@ -1129,6 +1135,7 @@ def __call__(
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
controlnet_blocks_repeat=controlnet_blocks_repeat,
)[0]

# compute the previous noisy sample x_t -> x_t-1
Expand Down

0 comments on commit 298ab6e

Please sign in to comment.