diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index df7dfbcd8871..5d89658830f1 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1267,6 +1267,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, for adapter_name in adapter_names: unet_module.lora_A[adapter_name].to(device) unet_module.lora_B[adapter_name].to(device) + # this is a param, not a module, so device placement is not in-place -> re-assign + unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[ + adapter_name + ].to(device) # Handle the text encoder modules_to_process = [] @@ -1283,6 +1287,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, for adapter_name in adapter_names: text_encoder_module.lora_A[adapter_name].to(device) text_encoder_module.lora_B[adapter_name].to(device) + # this is a param, not a module, so device placement is not in-place -> re-assign + text_encoder_module.lora_magnitude_vector[ + adapter_name + ] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device) class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py index ebf46e396284..fc28d94c240b 100644 --- a/tests/lora/test_lora_layers_sd.py +++ b/tests/lora/test_lora_layers_sd.py @@ -150,6 +150,54 @@ def test_integration_move_lora_cpu(self): if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)): self.assertTrue(m.weight.device != torch.device("cpu")) + @require_torch_gpu + def test_integration_move_lora_dora_cpu(self): + from peft import LoraConfig + + path = "runwayml/stable-diffusion-v1-5" + unet_lora_config = LoraConfig( + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + use_dora=True, + ) + text_lora_config = LoraConfig( + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + use_dora=True, + ) + + pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16) + pipe.unet.add_adapter(unet_lora_config, "adapter-1") + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), + "Lora not correctly set in text encoder", + ) + + self.assertTrue( + check_if_lora_correctly_set(pipe.unet), + "Lora not correctly set in text encoder", + ) + + for name, param in pipe.unet.named_parameters(): + if "lora_" in name: + self.assertEqual(param.device, torch.device("cpu")) + + for name, param in pipe.text_encoder.named_parameters(): + if "lora_" in name: + self.assertEqual(param.device, torch.device("cpu")) + + pipe.set_lora_device(["adapter-1"], torch_device) + + for name, param in pipe.unet.named_parameters(): + if "lora_" in name: + self.assertNotEqual(param.device, torch.device("cpu")) + + for name, param in pipe.text_encoder.named_parameters(): + if "lora_" in name: + self.assertNotEqual(param.device, torch.device("cpu")) + @slow @require_torch_gpu