Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def __call__(
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.0,
guidance_scale: Optional[float] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
Expand All @@ -462,7 +462,12 @@ def __call__(
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` in equation 2
of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
setting `true_cfg_scale > 1` and providing a `negative_prompt`. A higher guidance scale encourages the
model to generate images that are closely linked to the text `prompt`, usually at the expense of lower
image quality.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
Expand All @@ -474,17 +479,16 @@ def __call__(
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.

This parameter in the pipeline is there to support future guidance-distilled models when they come up.
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
enable classifier-free guidance computations.
guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance-distilled models. Unlike traditional classifier-free guidance,
where the guidance scale is applied during inference through noise prediction rescaling,
guidance-distilled models take the guidance scale directly as an input parameter during the forward
pass. Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the
model to generate images that are closely linked to the text `prompt`, usually at the expense of lower
image quality. This parameter in the pipeline is there to support future guidance-distilled models when
they come up. It is ignored when not using guidance-distilled models. To enable traditional
classifier-free guidance, please pass `true_cfg_scale > 1.0` and a `negative_prompt` (even an empty
negative prompt like " " should enable classifier-free guidance computations).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
Expand Down Expand Up @@ -564,6 +568,16 @@ def __call__(
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
)
elif true_cfg_scale <= 1 and has_neg_prompt:
logger.warning(
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
)

do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
Expand Down Expand Up @@ -618,10 +632,17 @@ def __call__(
self._num_timesteps = len(timesteps)

# handle guidance
if self.transformer.config.guidance_embeds:
if self.transformer.config.guidance_embeds and guidance_scale is None:
raise ValueError("guidance_scale is required for guidance-distilled model.")
elif self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
logger.warning(
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
)
guidance = None
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
guidance = None

if self.attention_kwargs is None:
Expand Down
46 changes: 36 additions & 10 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def __call__(
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.0,
guidance_scale: Optional[float] = None,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_image: PipelineImageInput = None,
Expand Down Expand Up @@ -566,7 +566,12 @@ def __call__(
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` in equation 2
of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
setting `true_cfg_scale > 1` and providing a `negative_prompt`. A higher guidance scale encourages the
model to generate images that are closely linked to the text `prompt`, usually at the expense of lower
image quality.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
Expand All @@ -578,12 +583,16 @@ def __call__(
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance-distilled models. Unlike traditional classifier-free guidance,
where the guidance scale is applied during inference through noise prediction rescaling,
guidance-distilled models take the guidance scale directly as an input parameter during the forward
pass. Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the
model to generate images that are closely linked to the text `prompt`, usually at the expense of lower
image quality. This parameter in the pipeline is there to support future guidance-distilled models when
they come up. It is ignored when not using guidance-distilled models. To enable traditional
classifier-free guidance, please pass `true_cfg_scale > 1.0` and a `negative_prompt` (even an empty
negative prompt like " " should enable classifier-free guidance computations).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
Expand Down Expand Up @@ -674,6 +683,16 @@ def __call__(
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
)
elif true_cfg_scale <= 1 and has_neg_prompt:
logger.warning(
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
)

do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
Expand Down Expand Up @@ -822,10 +841,17 @@ def __call__(
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps)

# handle guidance
if self.transformer.config.guidance_embeds:
if self.transformer.config.guidance_embeds and guidance_scale is None:
raise ValueError("guidance_scale is required for guidance-distilled model.")
elif self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
logger.warning(
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
)
guidance = None
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
guidance = None

if self.attention_kwargs is None:
Expand Down
51 changes: 36 additions & 15 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def __call__(
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.0,
guidance_scale: Optional[float] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
Expand All @@ -559,7 +559,12 @@ def __call__(
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` in equation 2
of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
setting `true_cfg_scale > 1` and providing a `negative_prompt`. A higher guidance scale encourages the
model to generate images that are closely linked to the text `prompt`, usually at the expense of lower
image quality.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
Expand All @@ -571,17 +576,16 @@ def __call__(
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.

This parameter in the pipeline is there to support future guidance-distilled models when they come up.
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
enable classifier-free guidance computations.
guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance-distilled models. Unlike traditional classifier-free guidance,
where the guidance scale is applied during inference through noise prediction rescaling,
guidance-distilled models take the guidance scale directly as an input parameter during the forward
pass. Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the
model to generate images that are closely linked to the text `prompt`, usually at the expense of lower
image quality. This parameter in the pipeline is there to support future guidance-distilled models when
they come up. It is ignored when not using guidance-distilled models. To enable traditional
classifier-free guidance, please pass `true_cfg_scale > 1.0` and a `negative_prompt` (even an empty
negative prompt like " " should enable classifier-free guidance computations).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
Expand Down Expand Up @@ -672,6 +676,16 @@ def __call__(
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)

if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
)
elif true_cfg_scale <= 1 and has_neg_prompt:
logger.warning(
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
)

do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
image=prompt_image,
Expand Down Expand Up @@ -734,10 +748,17 @@ def __call__(
self._num_timesteps = len(timesteps)

# handle guidance
if self.transformer.config.guidance_embeds:
if self.transformer.config.guidance_embeds and guidance_scale is None:
raise ValueError("guidance_scale is required for guidance-distilled model.")
elif self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
logger.warning(
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
)
guidance = None
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
guidance = None

if self.attention_kwargs is None:
Expand Down
Loading
Loading