From b09a2aa308f9e0a9ec06cb82f1f2a7bccf21a787 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 19 Mar 2024 17:53:38 +0530 Subject: [PATCH] [LoRA] fix `cross_attention_kwargs` problems and tighten tests (#7388) * debugging * let's see the numbers * let's see the numbers * let's see the numbers * restrict tolerance. * increase inference steps. * shallow copy of cross_attentionkwargs * remove print --- src/diffusers/models/unets/unet_2d_condition.py | 1 + tests/lora/test_lora_layers_peft.py | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index c0db7df2ec16..9a710919d067 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -1178,6 +1178,7 @@ def forward( # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. if cross_attention_kwargs is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() lora_scale = cross_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 67d28fe19e7e..95689b71ca49 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -158,7 +158,7 @@ def get_dummy_inputs(self, with_generator=True): pipeline_inputs = { "prompt": "A painting of a squirrel eating a burger", - "num_inference_steps": 2, + "num_inference_steps": 5, "guidance_scale": 6.0, "output_type": "np", } @@ -589,7 +589,7 @@ def test_simple_inference_with_text_unet_lora_and_scale(self): **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} ).images self.assertTrue( - not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), + not np.allclose(output_lora, output_lora_scale, atol=1e-4, rtol=1e-4), "Lora + scale should change the output", ) @@ -1300,6 +1300,11 @@ def test_integration_logits_with_scale(self): pipe.load_lora_weights(lora_id) pipe = pipe.to("cuda") + self.assertTrue( + self.check_if_lora_correctly_set(pipe.unet), + "Lora not correctly set in UNet", + ) + self.assertTrue( self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder 2",