Commit

fix bug, not refine

marigoold committed Jul 3, 2024
1 parent 2ad3be0 commit b93f544
Showing 6 changed files with 269 additions and 48 deletions.
2 changes: 2 additions & 0 deletions onediff_diffusers_extensions/onediffx/lora/__init__.py
@@ -4,6 +4,8 @@
set_and_fuse_adapters,
delete_adapters,
get_active_adapters,
# fuse_lora,
load_lora_and_optionally_fuse,
)

from onediff.infer_compiler.backends.oneflow.param_utils import update_graph_with_constant_folding_info
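
For context, a minimal usage sketch of the newly exported load_lora_and_optionally_fuse entry point; the pipeline class, model ID, and LoRA path below are placeholders for illustration, not part of this commit:

import torch
from diffusers import StableDiffusionXLPipeline
from onediffx.lora import load_lora_and_optionally_fuse

# Placeholder pipeline setup for illustration.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# fuse=True reproduces the old load_and_fuse_lora behavior: the LoRA weights
# are folded into the base weights immediately.
load_lora_and_optionally_fuse(
    pipe,
    "path/to/lora.safetensors",
    adapter_name="my_lora",
    lora_scale=1.0,
    fuse=True,
)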
76 changes: 71 additions & 5 deletions onediff_diffusers_extensions/onediffx/lora/lora.py
@@ -1,4 +1,5 @@
from pathlib import Path
import warnings
from typing import Optional, Union, Dict, Tuple, List
from collections import OrderedDict, defaultdict
from packaging import version
@@ -30,6 +31,10 @@
import peft
is_onediffx_lora_available = version.parse(diffusers.__version__) >= version.parse("0.19.3")

class OneDiffXWarning(Warning):
pass

warnings.filterwarnings("always", category=OneDiffXWarning)

USE_PEFT_BACKEND = False

@@ -43,15 +48,59 @@ def load_and_fuse_lora(
offload_device="cuda",
use_cache=False,
**kwargs,
):
return load_lora_and_optionally_fuse(
pipeline,
pretrained_model_name_or_path_or_dict,
adapter_name,
lora_scale=lora_scale,
offload_device=offload_device,
use_cache=use_cache,
fuse=True,
**kwargs,
)

def load_lora_and_optionally_fuse(
pipeline: LoraLoaderMixin,
pretrained_model_name_or_path_or_dict: Union[str, Path, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
*,
fuse,
lora_scale: Optional[float] = None,
offload_device="cuda",
use_cache=False,
**kwargs,
) -> None:
if not is_onediffx_lora_available:
raise RuntimeError(
"onediffx.lora only supports diffusers of at least version 0.19.3"
)

_init_adapters_info(pipeline)

if not fuse and lora_scale is not None:
warnings.warn("When fuse=False, the lora_scale will be ignored and set to 1.0 as default", category=OneDiffXWarning)
lora_scale = 1.0

if fuse and len(pipeline._active_adapter_names) > 0:
warnings.warn(
"The current API is supported for operating with a single LoRA file. "
"You are trying to load and fuse more than one LoRA "
"which is not well-supported and may lead to accuracy issues.",
category=OneDiffXWarning,
)

if adapter_name is None:
adapter_name = create_adapter_names(pipeline)

if adapter_name in pipeline._adapter_names:
warnings.warn(f"adapter_name {adapter_name} already exists, will be ignored", category=OneDiffXWarning)
return

pipeline._adapter_names.add(adapter_name)
pipeline._active_adapter_names[adapter_name] = 1.0

if fuse:
pipeline._active_adapter_names[adapter_name] = lora_scale

self = pipeline

@@ -84,6 +133,7 @@ def load_and_fuse_lora(
lora_scale=lora_scale,
offload_device=offload_device,
use_cache=use_cache,
fuse=fuse,
)

# load lora weights into text encoder
@@ -98,6 +148,7 @@ def load_and_fuse_lora(
lora_scale=lora_scale,
adapter_name=adapter_name,
_pipeline=self,
fuse=fuse,
)

text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -111,6 +162,7 @@ def load_and_fuse_lora(
lora_scale=lora_scale,
adapter_name=adapter_name,
_pipeline=self,
fuse=fuse,
)


@@ -123,7 +175,7 @@ def _unfuse_lora_apply(m: torch.nn.Module):
):
_unfuse_lora(m.base_layer)

pipeline._adapter_names.clear()
# pipeline._adapter_names.clear()
pipeline._active_adapter_names.clear()

pipeline.unet.apply(_unfuse_lora_apply)
@@ -135,9 +187,14 @@ def _unfuse_lora_apply(m: torch.nn.Module):

def set_and_fuse_adapters(
pipeline: LoraLoaderMixin,
adapter_names: Union[List[str], str],
adapter_names: Optional[Union[List[str], str]] = None,
adapter_weights: Optional[List[float]] = None,
):
if not hasattr(pipeline, "_adapter_names"):
raise RuntimeError("Didn't find any LoRA, please load LoRA first")
if adapter_names is None:
adapter_names = pipeline.active_adapter_names

if isinstance(adapter_names, str):
adapter_names = [adapter_names]

@@ -146,8 +203,10 @@ def set_and_fuse_adapters(
elif isinstance(adapter_weights, float):
adapter_weights = [adapter_weights, ] * len(adapter_names)

_init_adapters_info(pipeline)
pipeline._adapter_names |= set(adapter_names)
# _init_adapters_info(pipeline)
adapter_names = [x for x in adapter_names if x in pipeline._adapter_names]
# pipeline._adapter_names |= set(adapter_names)
pipeline._active_adapter_names = {k: v for k, v in zip(adapter_names, adapter_weights)}

def set_adapters_apply(m):
@@ -243,3 +302,10 @@ def load_state_dict_cached(


CachedLoRAs = LRUCacheDict(100)

def create_adapter_names(pipe):
for i in range(0, 10000):
result = f"default_{i}"
if result not in pipe._adapter_names:
return result
raise RuntimeError("Too much LoRA loaded")
49 changes: 28 additions & 21 deletions onediff_diffusers_extensions/onediffx/lora/text_encoder.py
@@ -21,7 +21,7 @@
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from onediff.utils import logger

from .utils import fuse_lora, get_adapter_names
from .utils import _fuse_lora, get_adapter_names, _load_lora_and_optionally_fuse

USE_PEFT_BACKEND = False

@@ -39,6 +39,7 @@ def load_lora_into_text_encoder(
low_cpu_mem_usage=None,
adapter_name=None,
_pipeline=None,
fuse: bool = False,
):
"""
This will load and fuse the LoRA layers specified in `state_dict` into `text_encoder`
@@ -71,20 +72,20 @@
low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
)

if adapter_name is None:
adapter_name = get_adapter_names(text_encoder)

if hasattr(text_encoder, "adapter_names"):
if adapter_name in text_encoder.adapter_names:
raise ValueError(
f"[OneDiffX load_lora_into_text_encoder] The adapter name {adapter_name} already exists in text_encoder"
)
else:
text_encoder.adapter_name.add(adapter_name)
text_encoder.active_adapter_name[adapter_name] = 1.0
else:
text_encoder.adapter_name = set([adapter_name])
text_encoder.active_adapter_name = {adapter_name: 1.0}
# if adapter_name is None:
# adapter_name = get_adapter_names(text_encoder)

# if hasattr(text_encoder, "adapter_names"):
# if adapter_name in text_encoder.adapter_names:
# raise ValueError(
# f"[OneDiffX load_lora_into_text_encoder] The adapter name {adapter_name} already exists in text_encoder"
# )
# else:
# text_encoder.adapter_name.add(adapter_name)
# text_encoder.active_adapter_name[adapter_name] = 1.0
# else:
# text_encoder.adapter_name = set([adapter_name])
# text_encoder.active_adapter_name = {adapter_name: 1.0}

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
@@ -204,41 +205,45 @@ def load_lora_into_text_encoder(
else:
current_rank = rank

fuse_lora(
_load_lora_and_optionally_fuse(
attn_module.q_proj,
te_lora_grouped_dict.pop(f"{name}.q_proj"),
lora_scale,
query_alpha,
current_rank,
adapter_name=adapter_name,
prefix="lora_linear_layer",
fuse=fuse,
)
fuse_lora(
_load_lora_and_optionally_fuse(
attn_module.k_proj,
te_lora_grouped_dict.pop(f"{name}.k_proj"),
lora_scale,
key_alpha,
current_rank,
adapter_name=adapter_name,
prefix="lora_linear_layer",
fuse=fuse,
)
fuse_lora(
_load_lora_and_optionally_fuse(
attn_module.v_proj,
te_lora_grouped_dict.pop(f"{name}.v_proj"),
lora_scale,
value_alpha,
current_rank,
adapter_name=adapter_name,
prefix="lora_linear_layer",
fuse=fuse,
)
fuse_lora(
_load_lora_and_optionally_fuse(
attn_module.out_proj,
te_lora_grouped_dict.pop(f"{name}.out_proj"),
lora_scale,
out_alpha,
current_rank,
adapter_name=adapter_name,
prefix="lora_linear_layer",
fuse=fuse,
)

if patch_mlp:
Expand All @@ -253,23 +258,25 @@ def load_lora_into_text_encoder(
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

fuse_lora(
_load_lora_and_optionally_fuse(
mlp_module.fc1,
te_lora_grouped_dict.pop(f"{name}.fc1"),
lora_scale,
fc1_alpha,
current_rank_fc1,
adapter_name=adapter_name,
prefix="lora_linear_layer",
fuse=fuse,
)
fuse_lora(
_load_lora_and_optionally_fuse(
mlp_module.fc2,
te_lora_grouped_dict.pop(f"{name}.fc2"),
lora_scale,
fc2_alpha,
current_rank_fc2,
adapter_name=adapter_name,
prefix="lora_linear_layer",
fuse=fuse,
)

if is_network_alphas_populated and len(network_alphas) > 0:
39 changes: 22 additions & 17 deletions onediff_diffusers_extensions/onediffx/lora/unet.py
@@ -9,7 +9,7 @@
LoRACompatibleConv,
LoRACompatibleLinear,
)
from .utils import fuse_lora, get_adapter_names, is_peft_available
from .utils import _fuse_lora, get_adapter_names, is_peft_available, _load_lora_and_optionally_fuse
from diffusers.utils import is_accelerate_available

if is_peft_available():
@@ -29,24 +29,25 @@ def load_lora_into_unet(
adapter_name=None,
_pipeline=None,
*,
fuse: bool = False,
lora_scale: float = 1.0,
offload_device="cpu",
use_cache=False,
):
if adapter_name is None:
adapter_name = get_adapter_names(unet)

if hasattr(unet, "adapter_names"):
if adapter_name in unet.adapter_names:
raise ValueError(
f"[OneDiffX load_lora_into_unet] The adapter name {adapter_name} already exists in UNet"
)
else:
unet.adapter_name.add(adapter_name)
unet.active_adapter_name[adapter_name] = 1.0
else:
unet.adapter_name = set([adapter_name])
unet.active_adapter_name = {adapter_name: 1.0}
# if adapter_name is None:
# adapter_name = get_adapter_names(unet)

# if hasattr(unet, "adapter_names"):
# if adapter_name in unet.adapter_names:
# raise ValueError(
# f"[OneDiffX load_lora_into_unet] The adapter name {adapter_name} already exists in UNet"
# )
# else:
# unet.adapter_name.add(adapter_name)
# unet.active_adapter_name[adapter_name] = 1.0
# else:
# unet.adapter_name = set([adapter_name])
# unet.active_adapter_name = {adapter_name: 1.0}

keys = list(state_dict.keys())
cls = type(self)
@@ -92,6 +93,7 @@ def load_lora_into_unet(
lora_scale=lora_scale,
offload_device=offload_device,
use_cache=use_cache,
fuse=fuse,
)


@@ -149,6 +151,7 @@ def _load_attn_procs(
_pipeline = kwargs.pop("_pipeline", None)
network_alphas = kwargs.pop("network_alphas", None)
adapter_name = kwargs.pop("adapter_name", None)
fuse = kwargs.pop("fuse", False)
state_dict = pretrained_model_name_or_path_or_dict

is_network_alphas_none = network_alphas is None
@@ -226,27 +229,29 @@ def _load_attn_procs(
torch.nn.Linear,
),
):
fuse_lora(
_load_lora_and_optionally_fuse(
attn_processor,
value_dict,
lora_scale,
mapped_network_alphas.get(key),
rank,
offload_device=offload_device,
adapter_name=adapter_name,
fuse=fuse,
)
elif is_peft_available() and isinstance(
attn_processor,
(peft.tuners.lora.layer.Linear, peft.tuners.lora.layer.Conv2d),
):
fuse_lora(
_load_lora_and_optionally_fuse(
attn_processor.base_layer,
value_dict,
lora_scale,
mapped_network_alphas.get(key),
rank,
offload_device=offload_device,
adapter_name=adapter_name,
fuse=fuse,
)
else:
raise ValueError(
(Diffs for the remaining two changed files did not load and are not shown here.)
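
One of the files whose diff is not shown is onediffx/lora/utils.py, which now provides the per-module helper _load_lora_and_optionally_fuse used above. Purely for orientation, here is a sketch of what such a helper does for a Linear layer, assuming the standard LoRA update rule and the key layout implied by the call sites; this is an illustration, not the actual utils.py implementation:

import torch

def _load_lora_and_optionally_fuse_sketch(
    module: torch.nn.Linear,
    state_dict: dict,       # assumed keys: f"{prefix}.up.weight", f"{prefix}.down.weight"
    lora_scale: float,
    alpha,                  # network alpha, may be None
    rank: int,
    *,
    adapter_name=None,
    prefix: str = "lora_linear_layer",
    offload_device: str = "cpu",
    fuse: bool = False,
):
    w_up = state_dict[f"{prefix}.up.weight"].float()
    w_down = state_dict[f"{prefix}.down.weight"].float()
    scale = (alpha / rank) if alpha is not None else 1.0

    if fuse:
        # Fold the low-rank update into the base weight right away (classic LoRA fusion):
        # W <- W + lora_scale * (alpha / rank) * (up @ down)
        delta = (lora_scale if lora_scale is not None else 1.0) * scale * (w_up @ w_down)
        module.weight.data += delta.to(module.weight.dtype).to(module.weight.device)
    else:
        # Keep the adapter weights around so set_and_fuse_adapters can fuse them later
        # with per-adapter scales; the stored structure here is an assumption.
        registry = getattr(module, "_lora_adapters", {})
        registry[adapter_name] = (w_up.to(offload_device), w_down.to(offload_device), scale)
        module._lora_adapters = registry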
