FIX Setting device for DoRA parameters (huggingface#7655)
Fix a bug that causes the call to set_lora_device to ignore the DoRA parameters.
BenjaminBossan authored Apr 12, 2024
1 parent 279de3c commit 2523390
Showing 2 changed files with 56 additions and 0 deletions.
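For orientation, a minimal usage sketch of the code path this commit fixes. This snippet is not part of the commit; it assumes a CUDA device, a diffusers build with PEFT support, and reuses the model id and target modules from the new test below:

```python
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# use_dora=True gives each adapted module a lora_magnitude_vector parameter
pipe.unet.add_adapter(
    LoraConfig(target_modules=["to_k", "to_q", "to_v", "to_out.0"], use_dora=True),
    "adapter-1",
)
# Before this fix, only lora_A/lora_B moved; the DoRA magnitude vectors stayed behind.
pipe.set_lora_device(["adapter-1"], "cuda")
```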
src/diffusers/loaders/lora.py (8 additions, 0 deletions)
@@ -1267,6 +1267,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
                 for adapter_name in adapter_names:
                     unet_module.lora_A[adapter_name].to(device)
                     unet_module.lora_B[adapter_name].to(device)
+                    # this is a param, not a module, so device placement is not in-place -> re-assign
+                    unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
+                        adapter_name
+                    ].to(device)
 
         # Handle the text encoder
         modules_to_process = []
@@ -1283,6 +1287,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
                     for adapter_name in adapter_names:
                         text_encoder_module.lora_A[adapter_name].to(device)
                         text_encoder_module.lora_B[adapter_name].to(device)
+                        # this is a param, not a module, so device placement is not in-place -> re-assign
+                        text_encoder_module.lora_magnitude_vector[
+                            adapter_name
+                        ] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
 
 
 class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
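The re-assignment added in both hunks is needed because nn.Module.to() moves a module's parameters in place, while Parameter.to(), like any Tensor.to(), returns a new tensor and leaves the original untouched. A self-contained sketch of that distinction (plain PyTorch, assuming a CUDA device is available):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)            # lora_A / lora_B entries are modules like this
vec = nn.Parameter(torch.ones(4))  # lora_magnitude_vector entries are bare parameters

layer.to("cuda")        # in place: layer.weight now lives on CUDA
moved = vec.to("cuda")  # not in place: returns a new tensor, vec is still on CPU

assert layer.weight.device.type == "cuda"
assert vec.device.type == "cpu" and moved.device.type == "cuda"

vec = moved  # hence the explicit re-assignment in set_lora_device
```

Since lora_A and lora_B are module containers while lora_magnitude_vector holds bare parameters (per the inline comment in the diff), only the latter needs the stored-back result.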
tests/lora/test_lora_layers_sd.py (48 additions, 0 deletions)
@@ -150,6 +150,54 @@ def test_integration_move_lora_cpu(self):
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))

+    @require_torch_gpu
+    def test_integration_move_lora_dora_cpu(self):
+        from peft import LoraConfig
+
+        path = "runwayml/stable-diffusion-v1-5"
+        unet_lora_config = LoraConfig(
+            init_lora_weights="gaussian",
+            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+            use_dora=True,
+        )
+        text_lora_config = LoraConfig(
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+            use_dora=True,
+        )
+
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+        pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.text_encoder),
+            "Lora not correctly set in text encoder",
+        )
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in unet",
+        )
+
+        for name, param in pipe.unet.named_parameters():
+            if "lora_" in name:
+                self.assertEqual(param.device, torch.device("cpu"))
+
+        for name, param in pipe.text_encoder.named_parameters():
+            if "lora_" in name:
+                self.assertEqual(param.device, torch.device("cpu"))
+
+        pipe.set_lora_device(["adapter-1"], torch_device)
+
+        for name, param in pipe.unet.named_parameters():
+            if "lora_" in name:
+                self.assertNotEqual(param.device, torch.device("cpu"))
+
+        for name, param in pipe.text_encoder.named_parameters():
+            if "lora_" in name:
+                self.assertNotEqual(param.device, torch.device("cpu"))
+

     @slow
     @require_torch_gpu
