Commit 7d0b9c4

[LoRA] feat: save_lora_adapter() (huggingface#9862)
* feat: save_lora_adapter.
sayakpaul authored Nov 19, 2024
1 parent acf479b commit 7d0b9c4
Showing 6 changed files with 210 additions and 55 deletions.
6 changes: 4 additions & 2 deletions src/diffusers/loaders/lora_pipeline.py
@@ -298,8 +298,9 @@ def load_lora_into_unet(
if not only_text_encoder:
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
unet.load_lora_adapter(
state_dict,
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
@@ -827,8 +828,9 @@ def load_lora_into_unet(
if not only_text_encoder:
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
unet.load_lora_adapter(
state_dict,
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
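To make the new routing concrete, here is a minimal sketch (illustrative, not the library's exact code) of the key filtering that the `prefix` argument enables inside `load_lora_adapter()`: only keys under the given prefix are kept, and the prefix is stripped before the state dict is handed to PEFT. The state-dict keys and tensor shapes below are hypothetical.

import torch

# Hypothetical pipeline-level LoRA state dict mixing UNet and text-encoder keys.
state_dict = {
    "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": torch.zeros(4, 320),
    "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.zeros(320, 4),
    "text_encoder.text_model.encoder.layers.0.q_proj.lora_A.weight": torch.zeros(4, 768),
}

prefix = "unet"  # what load_lora_into_unet now passes as prefix=cls.unet_name
model_keys = [k for k in state_dict if k.startswith(f"{prefix}.")]
unet_state_dict = {k.replace(f"{prefix}.", ""): state_dict[k] for k in model_keys}

# Only the two UNet keys survive, now without the "unet." prefix; the
# text-encoder key is left for the text-encoder loading path.
assert sorted(unet_state_dict) == [
    "down_blocks.0.attentions.0.to_q.lora_A.weight",
    "down_blocks.0.attentions.0.to_q.lora_B.weight",
]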
104 changes: 88 additions & 16 deletions src/diffusers/loaders/peft.py
@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Union

import safetensors
import torch
import torch.nn as nn

from ..utils import (
@@ -189,40 +193,45 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
user_agent=user_agent,
allow_pickle=allow_pickle,
)
if network_alphas is not None and prefix is None:
raise ValueError("`network_alphas` cannot be used when `prefix` is None.")

keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(prefix)]
if len(transformer_keys) > 0:
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
if prefix is not None:
keys = list(state_dict.keys())
model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
if len(model_keys) > 0:
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}

if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
)

if len(state_dict.keys()) > 0:
# check with the first key whether the state dict is in the PEFT format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)

if adapter_name in getattr(self, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)

rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]

if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)

# adapter_name
@@ -276,6 +285,69 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />

def save_lora_adapter(
self,
save_directory,
adapter_name: str = "default",
upcast_before_saving: bool = False,
safe_serialization: bool = True,
weight_name: Optional[str] = None,
):
"""
Save the LoRA parameters corresponding to the underlying model.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
adapter_name (`str`, defaults to `"default"`): The name of the adapter to serialize. Useful when the
underlying model has multiple adapters loaded.
upcast_before_saving (`bool`, defaults to `False`):
Whether to cast the underlying model to `torch.float32` before serialization.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
weight_name (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
"""
from peft.utils import get_peft_model_state_dict

from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

if adapter_name is None:
adapter_name = get_adapter_name(self)

if adapter_name not in getattr(self, "peft_config", {}):
raise ValueError(f"Adapter name {adapter_name} not found in the model.")

lora_layers_to_save = get_peft_model_state_dict(
self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
)
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")

if safe_serialization:

def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

else:
save_function = torch.save

os.makedirs(save_directory, exist_ok=True)

if weight_name is None:
if safe_serialization:
weight_name = LORA_WEIGHT_NAME_SAFE
else:
weight_name = LORA_WEIGHT_NAME

# TODO: we could consider saving the `peft_config` as well.
save_path = Path(save_directory, weight_name).as_posix()
save_function(lora_layers_to_save, save_path)
logger.info(f"Model weights saved in {save_path}")

def set_adapters(
self,
adapter_names: Union[List[str], str],
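For orientation, a minimal end-to-end sketch of the model-level API added above, assuming a UNet checkpoint at a placeholder path and the default adapter name that `add_adapter()` assigns; it is not an example shipped with the commit. The adapter is saved to a directory and reloaded with `prefix=None`, since the saved file holds bare model-level keys rather than pipeline-level `unet.`-prefixed ones.

from peft import LoraConfig

from diffusers import UNet2DConditionModel

# Placeholder path; any Stable Diffusion-style checkpoint with a "unet" subfolder works.
unet = UNet2DConditionModel.from_pretrained("path/to/sd-checkpoint", subfolder="unet")

# Attach a fresh LoRA adapter to the attention projections (same targets as the new tests).
unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

# New in this commit: serialize only the LoRA parameters of the "default" adapter.
unet.save_lora_adapter("my_lora", safe_serialization=True)
# -> my_lora/pytorch_lora_weights.safetensors

# Later: drop the in-memory adapter and load the saved one back at the model level.
unet.unload_lora()
unet.load_lora_adapter("my_lora", prefix=None)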
5 changes: 5 additions & 0 deletions src/diffusers/loaders/unet.py
@@ -36,6 +36,7 @@
USE_PEFT_BACKEND,
_get_model_file,
convert_unet_state_dict_to_peft,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
@@ -209,6 +210,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
is_model_cpu_offload = False
is_sequential_cpu_offload = False

if is_lora:
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
deprecate("load_attn_procs", "0.40.0", deprecation_message)

if is_custom_diffusion:
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
elif is_lora:
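In practice, the deprecation means LoRA callers of the UNet-level loader should migrate roughly as sketched below (a hedged sketch, not part of the commit; `my_lora` is a placeholder for any directory produced by `save_lora_adapter()` or a Hub repo containing `pytorch_lora_weights.safetensors`, and `unet` continues from the previous sketch).

# Deprecated for LoRA state dicts and slated for removal in diffusers 0.40.0:
# unet.load_attn_procs("my_lora")

# Forward-compatible, PEFT-backed replacement; prefix=None because model-level
# files carry bare keys without the pipeline-level "unet." prefix.
unet.load_lora_adapter("my_lora", prefix=None)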
12 changes: 2 additions & 10 deletions tests/lora/utils.py
@@ -1784,11 +1784,7 @@ def test_missing_keys_warning(self):
missing_key = [k for k in state_dict if "lora_A" in k][0]
del state_dict[missing_key]

logger = (
logging.get_logger("diffusers.loaders.unet")
if self.unet_kwargs is not None
else logging.get_logger("diffusers.loaders.peft")
)
logger = logging.get_logger("diffusers.loaders.peft")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(state_dict)
@@ -1823,11 +1819,7 @@ def test_unexpected_keys_warning(self):
unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)

logger = (
logging.get_logger("diffusers.loaders.unet")
if self.unet_kwargs is not None
else logging.get_logger("diffusers.loaders.peft")
)
logger = logging.get_logger("diffusers.loaders.peft")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(state_dict)
105 changes: 103 additions & 2 deletions tests/models/test_modeling_common.py
@@ -44,6 +44,7 @@
from diffusers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_INDEX_NAME,
is_peft_available,
is_torch_npu_available,
is_xformers_available,
logging,
@@ -65,6 +66,10 @@
from ..others.test_utils import TOKEN, USER, is_staging_test


if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer


def caculate_expected_num_shards(index_map_path):
with open(index_map_path) as f:
weight_map_dict = json.load(f)["weight_map"]
@@ -74,6 +79,16 @@ def caculate_expected_num_shards(index_map_path):
return expected_num_shards


def check_if_lora_correctly_set(model) -> bool:
"""
Checks if the LoRA layers are correctly set with peft
"""
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False


# Will be run via run_test_in_subprocess
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
error = None
@@ -877,8 +892,6 @@ def _set_gradient_checkpointing_new(self, module, value=False):
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()

print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")

assert set(modules_with_gc_enabled.keys()) == expected_set
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"

@@ -902,6 +915,94 @@ def test_deprecated_kwargs(self):
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)

@parameterized.expand([True, False])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_save_load_lora_adapter(self, use_dora=False):
import safetensors
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

from diffusers.loaders.peft import PeftAdapterMixin

init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)

if not issubclass(model.__class__, PeftAdapterMixin):
return

torch.manual_seed(0)
output_no_lora = model(**inputs_dict, return_dict=False)[0]

denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")

torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]

self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))

with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))

model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")

model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")

for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
self.assertTrue(torch.allclose(loaded_v, retrieved_v))

self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")

torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]

self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))

@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_wrong_adapter_name_raises_error(self):
from peft import LoraConfig

from diffusers.loaders.peft import PeftAdapterMixin

init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)

if not issubclass(model.__class__, PeftAdapterMixin):
return

denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")

with tempfile.TemporaryDirectory() as tmpdir:
wrong_name = "foo"
with self.assertRaises(ValueError) as err_context:
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)

self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))

@require_torch_gpu
def test_cpu_offload(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()