From f240a936da871ddbde2d855b80b6415f2edda705 Mon Sep 17 00:00:00 2001
From: Anatoly Belikov
Date: Tue, 30 Jul 2024 04:49:28 +0300
Subject: [PATCH] handle lora scale and clip skip in lpw sd and sdxl community
 pipelines (#8988)

* handle lora scale and clip skip in lpw sd and sdxl

* use StableDiffusionLoraLoaderMixin

* use StableDiffusionXLLoraLoaderMixin

* style

---------

Co-authored-by: Sayak Paul
---
 examples/community/lpw_stable_diffusion.py    | 64 +++++++++++++++++--
 examples/community/lpw_stable_diffusion_xl.py | 48 ++++++++++++--
 2 files changed, 100 insertions(+), 12 deletions(-)

diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index d57a7c228097..ec27acdce331 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -13,13 +13,17 @@
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor

@@ -199,6 +203,7 @@ def get_unweighted_text_embeddings(
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
+    clip_skip: Optional[int] = None,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -214,7 +219,20 @@ def get_unweighted_text_embeddings(
             # cover the head and the tail by the starting and the ending tokens
             text_input_chunk[:, 0] = text_input[0, 0]
             text_input_chunk[:, -1] = text_input[0, -1]
-            text_embedding = pipe.text_encoder(text_input_chunk)[0]
+            if clip_skip is None:
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
+                text_embedding = prompt_embeds[0]
+            else:
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
+                # Access `hidden_states` first, which contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here so as not to alter the
+                # representations: the `last_hidden_state` that we typically use for
+                # obtaining the final prompt representations passes through that LayerNorm
+                # layer.
+                text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)

             if no_boseos_middle:
                 if i == 0:
@@ -230,7 +248,10 @@ def get_unweighted_text_embeddings(
             text_embeddings.append(text_embedding)
         text_embeddings = torch.concat(text_embeddings, axis=1)
     else:
-        text_embeddings = pipe.text_encoder(text_input)[0]
+        if clip_skip is None:
+            clip_skip = 0
+        prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
+        text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
     return text_embeddings


@@ -242,6 +263,8 @@ def get_weighted_text_embeddings(
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
+    clip_skip: Optional[int] = None,
+    lora_scale: Optional[float] = None,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -268,6 +291,16 @@ def get_weighted_text_embeddings(
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
     """
+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if not USE_PEFT_BACKEND:
+            adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+        else:
+            scale_lora_layers(pipe.text_encoder, lora_scale)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
         prompt = [prompt]
@@ -334,10 +367,7 @@ def get_weighted_text_embeddings(

     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe,
-        prompt_tokens,
-        pipe.tokenizer.model_max_length,
-        no_boseos_middle=no_boseos_middle,
+        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
     if uncond_prompt is not None:
@@ -346,6 +376,7 @@ def get_weighted_text_embeddings(
             uncond_tokens,
             pipe.tokenizer.model_max_length,
             no_boseos_middle=no_boseos_middle,
+            clip_skip=clip_skip,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)

@@ -362,6 +393,11 @@ def get_weighted_text_embeddings(
             current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
             uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
     if uncond_prompt is not None:
         return text_embeddings, uncond_embeddings
     return text_embeddings, None
@@ -549,6 +585,8 @@ def _encode_prompt(
         max_embeddings_multiples=3,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        clip_skip: Optional[int] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -597,6 +635,8 @@ def _encode_prompt(
             prompt=prompt,
             uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
             max_embeddings_multiples=max_embeddings_multiples,
+            clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         if prompt_embeds is None:
             prompt_embeds = prompt_embeds1
@@ -790,6 +830,7 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip: Optional[int] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -865,6 +906,9 @@ def __call__(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -903,6 +947,7 @@ def __call__(
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None

         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
@@ -914,6 +959,8 @@ def __call__(
             max_embeddings_multiples,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         dtype = prompt_embeds.dtype

@@ -1044,6 +1091,7 @@ def text2img(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip: Optional[int] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -1101,6 +1149,9 @@ def text2img(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1135,6 +1186,7 @@ def text2img(
             return_dict=return_dict,
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
+            clip_skip=clip_skip,
             callback_steps=callback_steps,
             cross_attention_kwargs=cross_attention_kwargs,
         )
diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py
index eaa675d1628f..13d1e2a1156a 100644
--- a/examples/community/lpw_stable_diffusion_xl.py
+++ b/examples/community/lpw_stable_diffusion_xl.py
@@ -25,21 +25,25 @@
 from diffusers.loaders import (
     FromSingleFileMixin,
     IPAdapterMixin,
-    StableDiffusionLoraLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
+    USE_PEFT_BACKEND,
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
     is_invisible_watermark_available,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor

@@ -261,6 +265,7 @@ def get_weighted_text_embeddings_sdxl(
     num_images_per_prompt: int = 1,
     device: Optional[torch.device] = None,
     clip_skip: Optional[int] = None,
+    lora_scale: Optional[float] = None,
 ):
     """
     This function can process long prompt with weights, no length limitation
@@ -281,6 +286,24 @@ def get_weighted_text_embeddings_sdxl(
     """
     device = device or pipe._execution_device

+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if pipe.text_encoder is not None:
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(pipe.text_encoder, lora_scale)
+
+        if pipe.text_encoder_2 is not None:
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
+            else:
+                scale_lora_layers(pipe.text_encoder_2, lora_scale)
+
     if prompt_2:
         prompt = f"{prompt} {prompt_2}"

@@ -429,6 +452,16 @@ def get_weighted_text_embeddings_sdxl(
         bs_embed * num_images_per_prompt, -1
     )

+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
+    if pipe.text_encoder_2 is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder_2, lora_scale)
+
     return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


@@ -549,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
     StableDiffusionMixin,
     FromSingleFileMixin,
     IPAdapterMixin,
-    StableDiffusionLoraLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 ):
     r"""
@@ -561,8 +594,8 @@ class SDXLLongPromptWeightingPipeline(

     The pipeline also inherits the following loading methods:
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings

     Args:
@@ -743,7 +776,7 @@ def encode_prompt(

         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
             self._lora_scale = lora_scale

         if prompt is not None and isinstance(prompt, str):
@@ -1612,7 +1645,9 @@ def __call__(
             image_embeds = torch.cat([negative_image_embeds, image_embeds])

         # 3. Encode input prompt
-        (self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
+        lora_scale = (
+            self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
+        )

         negative_prompt = negative_prompt if negative_prompt is not None else ""

@@ -1627,6 +1662,8 @@ def __call__(
             neg_prompt=negative_prompt,
             num_images_per_prompt=num_images_per_prompt,
             clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )

         dtype = prompt_embeds.dtype
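
Usage sketch (not part of the patch): the example below shows how the new arguments surface to callers, assuming the community pipeline is loaded via `custom_pipeline` as usual for `examples/community`; the model id and LoRA file are placeholders. `clip_skip` is passed straight to `__call__`, while the LoRA scale travels through `cross_attention_kwargs={"scale": ...}` and is forwarded to the text encoder as `lora_scale`.

    # Minimal sketch of the new `clip_skip` / LoRA-scale handling.
    # The model id and LoRA weight file below are placeholders, not part of this PR.
    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",  # placeholder model id
        custom_pipeline="lpw_stable_diffusion",
        torch_dtype=torch.float16,
    ).to("cuda")
    pipe.load_lora_weights("path/to/lora.safetensors")  # placeholder LoRA weights

    image = pipe(
        prompt="a (photorealistic:1.2) portrait of an astronaut",
        negative_prompt="lowres, (bad anatomy:1.3)",
        clip_skip=2,  # 1 = use the pre-final CLIP layer; 2 skips one layer more
        cross_attention_kwargs={"scale": 0.8},  # forwarded as `lora_scale`
    ).images[0]

The SDXL variant (`custom_pipeline="lpw_stable_diffusion_xl"`) accepts the same two arguments; there the scale is applied to both `text_encoder` and `text_encoder_2` and, under the PEFT backend, unscaled again after encoding.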