Skip to content

Commit

Permalink
handle lora scale and clip skip in lpw sd and sdxl community pipelines (
Browse files Browse the repository at this point in the history
huggingface#8988)

* handle lora scale and clip skip in lpw sd and sdxl

* use StableDiffusionLoraLoaderMixin

* use StableDiffusionXLLoraLoaderMixin

* style

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
noskill and sayakpaul authored Jul 30, 2024
1 parent 00d8d46 commit f240a93
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 12 deletions.
64 changes: 58 additions & 6 deletions examples/community/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
deprecate,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor

Expand Down Expand Up @@ -199,6 +203,7 @@ def get_unweighted_text_embeddings(
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
clip_skip: Optional[int] = None,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
Expand All @@ -214,7 +219,20 @@ def get_unweighted_text_embeddings(
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk)[0]
if clip_skip is None:
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
text_embedding = prompt_embeds[0]
else:
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)

if no_boseos_middle:
if i == 0:
Expand All @@ -230,7 +248,10 @@ def get_unweighted_text_embeddings(
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
text_embeddings = pipe.text_encoder(text_input)[0]
if clip_skip is None:
clip_skip = 0
prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
return text_embeddings


Expand All @@ -242,6 +263,8 @@ def get_weighted_text_embeddings(
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
clip_skip=None,
lora_scale=None,
):
r"""
Prompts can be assigned with local weights using brackets. For example,
Expand All @@ -268,6 +291,16 @@ def get_weighted_text_embeddings(
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
pipe._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
else:
scale_lora_layers(pipe.text_encoder, lora_scale)
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
Expand Down Expand Up @@ -334,10 +367,7 @@ def get_weighted_text_embeddings(

# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
)
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
if uncond_prompt is not None:
Expand All @@ -346,6 +376,7 @@ def get_weighted_text_embeddings(
uncond_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
clip_skip=clip_skip,
)
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)

Expand All @@ -362,6 +393,11 @@ def get_weighted_text_embeddings(
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

if pipe.text_encoder is not None:
if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(pipe.text_encoder, lora_scale)

if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None
Expand Down Expand Up @@ -549,6 +585,8 @@ def _encode_prompt(
max_embeddings_multiples=3,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
clip_skip: Optional[int] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Expand Down Expand Up @@ -597,6 +635,8 @@ def _encode_prompt(
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=clip_skip,
lora_scale=lora_scale,
)
if prompt_embeds is None:
prompt_embeds = prompt_embeds1
Expand Down Expand Up @@ -790,6 +830,7 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
clip_skip: Optional[int] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
Expand Down Expand Up @@ -865,6 +906,9 @@ def __call__(
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Expand Down Expand Up @@ -903,6 +947,7 @@ def __call__(
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None

# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
Expand All @@ -914,6 +959,8 @@ def __call__(
max_embeddings_multiples,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
clip_skip=clip_skip,
lora_scale=lora_scale,
)
dtype = prompt_embeds.dtype

Expand Down Expand Up @@ -1044,6 +1091,7 @@ def text2img(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
clip_skip=None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
Expand Down Expand Up @@ -1101,6 +1149,9 @@ def text2img(
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Expand Down Expand Up @@ -1135,6 +1186,7 @@ def text2img(
return_dict=return_dict,
callback=callback,
is_cancelled_callback=is_cancelled_callback,
clip_skip=clip_skip,
callback_steps=callback_steps,
cross_attention_kwargs=cross_attention_kwargs,
)
Expand Down
48 changes: 42 additions & 6 deletions examples/community/lpw_stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,25 @@
from diffusers.loaders import (
FromSingleFileMixin,
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor

Expand Down Expand Up @@ -261,6 +265,7 @@ def get_weighted_text_embeddings_sdxl(
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
clip_skip: Optional[int] = None,
lora_scale: Optional[int] = None,
):
"""
This function can process long prompt with weights, no length limitation
Expand All @@ -281,6 +286,24 @@ def get_weighted_text_embeddings_sdxl(
"""
device = device or pipe._execution_device

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
pipe._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if pipe.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
else:
scale_lora_layers(pipe.text_encoder, lora_scale)

if pipe.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
else:
scale_lora_layers(pipe.text_encoder_2, lora_scale)

if prompt_2:
prompt = f"{prompt} {prompt_2}"

Expand Down Expand Up @@ -429,6 +452,16 @@ def get_weighted_text_embeddings_sdxl(
bs_embed * num_images_per_prompt, -1
)

if pipe.text_encoder is not None:
if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(pipe.text_encoder, lora_scale)

if pipe.text_encoder_2 is not None:
if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(pipe.text_encoder_2, lora_scale)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


Expand Down Expand Up @@ -549,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
StableDiffusionMixin,
FromSingleFileMixin,
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
):
r"""
Expand All @@ -561,8 +594,8 @@ class SDXLLongPromptWeightingPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
Args:
Expand Down Expand Up @@ -743,7 +776,7 @@ def encode_prompt(

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
self._lora_scale = lora_scale

if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1612,7 +1645,9 @@ def __call__(
image_embeds = torch.cat([negative_image_embeds, image_embeds])

# 3. Encode input prompt
(self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
lora_scale = (
self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
)

negative_prompt = negative_prompt if negative_prompt is not None else ""

Expand All @@ -1627,6 +1662,7 @@ def __call__(
neg_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip,
lora_scale=lora_scale,
)
dtype = prompt_embeds.dtype

Expand Down

0 comments on commit f240a93

Please sign in to comment.