From dac623b59f52c58383a39207d5147aa34e0047cd Mon Sep 17 00:00:00 2001 From: Eliseu Silva Date: Fri, 8 Nov 2024 22:40:51 -0300 Subject: [PATCH] Feature IP Adapter Xformers Attention Processor (#9881) * Feature IP Adapter Xformers Attention Processor: this fix error loading incorrect attention processor when setting Xformers attn after load ip adapter scale, issues: #8863 #8872 --- src/diffusers/loaders/ip_adapter.py | 14 +- src/diffusers/loaders/unet.py | 13 +- src/diffusers/models/attention_processor.py | 262 +++++++++++++++++++- 3 files changed, 278 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 1006dab9e4b9..49b46c4fc615 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -33,16 +33,14 @@ if is_transformers_available(): - from transformers import ( - CLIPImageProcessor, - CLIPVisionModelWithProjection, - ) + from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ..models.attention_processor import ( AttnProcessor, AttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, + IPAdapterXFormersAttnProcessor, ) logger = logging.get_logger(__name__) @@ -284,7 +282,9 @@ def set_ip_adapter_scale(self, scale): scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0) for attn_name, attn_processor in unet.attn_processors.items(): - if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): + if isinstance( + attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) + ): if len(scale_configs) != len(attn_processor.scale): raise ValueError( f"Cannot assign {len(scale_configs)} scale_configs to " @@ -342,7 +342,9 @@ def unload_ip_adapter(self): ) attn_procs[name] = ( attn_processor_class - if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)) + if isinstance( + value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor) + ) else value.__class__() ) self.unet.set_attn_processor(attn_procs) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 2fa7732a6a3b..b37b681ae8fe 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -765,6 +765,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F from ..models.attention_processor import ( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, + IPAdapterXFormersAttnProcessor, ) if low_cpu_mem_usage: @@ -804,11 +805,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F if cross_attention_dim is None or "motion_modules" in name: attn_processor_class = self.attn_processors[name].__class__ attn_procs[name] = attn_processor_class() - else: - attn_processor_class = ( - IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor - ) + if "XFormers" in str(self.attn_processors[name].__class__): + attn_processor_class = IPAdapterXFormersAttnProcessor + else: + attn_processor_class = ( + IPAdapterAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else IPAdapterAttnProcessor + ) num_image_text_embeds = [] for state_dict in state_dicts: if "proj.weight" in state_dict["image_proj"]: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index da01b7a1edcd..772aae7fcd2f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -318,7 +318,10 @@ def set_use_memory_efficient_attention_xformers( XFormersAttnAddedKVProcessor, ), ) - + is_ip_adapter = hasattr(self, "processor") and isinstance( + self.processor, + (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor), + ) if use_memory_efficient_attention_xformers: if is_added_kv_processor and is_custom_diffusion: raise NotImplementedError( @@ -368,6 +371,19 @@ def set_use_memory_efficient_attention_xformers( "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." ) processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + elif is_ip_adapter: + processor = IPAdapterXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + num_tokens=self.processor.num_tokens, + scale=self.processor.scale, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_ip"): + processor.to( + device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype + ) else: processor = XFormersAttnProcessor(attention_op=attention_op) else: @@ -386,6 +402,18 @@ def set_use_memory_efficient_attention_xformers( processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_ip_adapter: + processor = IPAdapterAttnProcessor2_0( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + num_tokens=self.processor.num_tokens, + scale=self.processor.scale, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_ip"): + processor.to( + device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype + ) else: # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses @@ -4542,6 +4570,238 @@ def __call__( return hidden_states +class IPAdapterXFormersAttnProcessor(torch.nn.Module): + r""" + Attention processor for IP-Adapter using xFormers. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): + The context length of the image features. + scale (`float` or `List[float]`, defaults to 1.0): + the weight scale of image prompt. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__( + self, + hidden_size, + cross_attention_dim=None, + num_tokens=(4,), + scale=1.0, + attention_op: Optional[Callable] = None, + ): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.attention_op = attention_op + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + self.num_tokens = num_tokens + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + self.to_v_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.FloatTensor] = None, + ): + residual = hidden_states + + # separate ip_hidden_states from encoder_hidden_states + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, tuple): + encoder_hidden_states, ip_hidden_states = encoder_hidden_states + else: + deprecation_message = ( + "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." + ) + deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) + end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + [encoder_hidden_states[:, end_pos:, :]], + ) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if ip_hidden_states: + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, List): + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): + raise ValueError( + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" + ) + else: + for index, (mask, scale, ip_state) in enumerate( + zip(ip_adapter_masks, self.scale, ip_hidden_states) + ): + if mask is None: + continue + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ): + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + mask = mask.to(torch.float16) + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = attn.head_to_batch_dim(ip_key).contiguous() + ip_value = attn.head_to_batch_dim(ip_value).contiguous() + + _current_ip_hidden_states = xformers.ops.memory_efficient_attention( + query, ip_key, ip_value, op=self.attention_op + ) + _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) + _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key).contiguous() + ip_value = attn.head_to_batch_dim(ip_value).contiguous() + + current_ip_hidden_states = xformers.ops.memory_efficient_attention( + query, ip_key, ip_value, op=self.attention_op + ) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) + + hidden_states = hidden_states + scale * current_ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class PAGIdentitySelfAttnProcessor2_0: r""" Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).