From 802bd1a6ef575c396dedddc563b852ec831116af Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 25 Jun 2025 17:47:53 +0530
Subject: [PATCH 1/5] feat: use exclude_modules in LoraConfig.

---
 src/diffusers/loaders/peft.py     | 13 +++++++---
 src/diffusers/utils/peft_utils.py | 41 +++++++++++++++++++++++++------
 2 files changed, 43 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 343623071340..3da1c4454f22 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -243,13 +243,20 @@ def load_lora_adapter(
                     k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }
 
-            # create LoraConfig
-            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
-
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)
 
+            # create LoraConfig
+            lora_config = _create_lora_config(
+                state_dict,
+                network_alphas,
+                metadata,
+                rank,
+                model_state_dict=self.state_dict(),
+                adapter_name=adapter_name,
+            )
+
             # ...
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ ... @@ ... None:
 def _create_lora_config(
-    state_dict,
-    network_alphas,
-    metadata,
-    rank_pattern_dict,
-    is_unet: bool = True,
+    state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
 ):
     from peft import LoraConfig
 
@@ -306,7 +309,12 @@ def _create_lora_config(
         lora_config_kwargs = metadata
     else:
         lora_config_kwargs = get_peft_kwargs(
-            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+            rank_pattern_dict,
+            network_alpha_dict=network_alphas,
+            peft_state_dict=state_dict,
+            is_unet=is_unet,
+            model_state_dict=model_state_dict,
+            adapter_name=adapter_name,
         )
 
     _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -371,3 +379,20 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
 
     if warn_msg:
         logger.warning(warn_msg)
+
+
+def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
+    all_modules = set()
+    string_to_replace = f"{adapter_name}." if adapter_name else ""
+
+    for name in model_state_dict.keys():
+        if string_to_replace:
+            name = name.replace(string_to_replace, "")
+        if "." in name:
+            module_name = name.rsplit(".", 1)[0]
+            all_modules.add(module_name)
+
+    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
+    exclude_modules = list(all_modules - target_modules_set)
+
+    return exclude_modules
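The heart of the series is the set difference introduced above: every module name that can be derived from the model's state dict but not from the incoming LoRA state dict becomes a candidate for exclusion. A minimal, self-contained sketch of that logic (the state-dict keys are made up for illustration):

    # Module names come from dropping the trailing ".weight"/".bias" component;
    # LoRA keys are cut at ".lora" (e.g. ".lora_A.weight") to recover their target module.
    model_state_dict = {
        "blocks.0.attn.to_q.weight": None,
        "blocks.0.attn.to_k.weight": None,
        "vace_blocks.0.proj.weight": None,  # present in the model, untouched by the LoRA
    }
    peft_state_dict = {
        "blocks.0.attn.to_q.lora_A.weight": None,
        "blocks.0.attn.to_k.lora_A.weight": None,
    }

    all_modules = {name.rsplit(".", 1)[0] for name in model_state_dict if "." in name}
    target_modules = {name.split(".lora")[0] for name in peft_state_dict}
    print(sorted(all_modules - target_modules))  # ['vace_blocks.0.proj']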
From 76356ea13e06893f55f5da4b4597c870c7366c0e Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 27 Jun 2025 16:32:17 +0530
Subject: [PATCH 2/5] version-guard.

---
 src/diffusers/utils/peft_utils.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 46e3f81e4556..9d04d58f5d7c 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -187,21 +187,25 @@ def get_peft_kwargs(
     # for now we know that the "bias" keys are only associated with `lora_B`.
     lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)
 
-    # Example: load FusionX LoRA into Wan VACE
-    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
-    if not exclude_modules:
-        exclude_modules = None
-
     lora_config_kwargs = {
         "r": r,
         "lora_alpha": lora_alpha,
         "rank_pattern": rank_pattern,
         "alpha_pattern": alpha_pattern,
         "target_modules": target_modules,
-        "exclude_modules": exclude_modules,
         "use_dora": use_dora,
         "lora_bias": lora_bias,
     }
+
+    # Example: trying to load a FusionX LoRA into Wan VACE
+    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
+    if exclude_modules:
+        if not is_peft_version(">=", "0.14.0"):
+            msg = "It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft` version doesn't support passing `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U peft`."
+            logger.warning(msg)
+        else:
+            lora_config_kwargs.update({"exclude_modules": exclude_modules})
+
     return lora_config_kwargs
 
@@ -382,6 +386,11 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
 
 
 def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
+    """
+    Derives the modules to exclude when initializing `LoraConfig` through `exclude_modules`. It works by comparing
+    `model_state_dict` with `peft_state_dict`, adding a module from `model_state_dict` to the exclusion set if it
+    doesn't exist in `peft_state_dict`.
+    """
     all_modules = set()
     string_to_replace = f"{adapter_name}." if adapter_name else ""
 
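The guard above matters because `exclude_modules` only exists on `LoraConfig` from peft 0.14.0 onwards; on older versions the unknown keyword would raise a `TypeError`, so the kwarg is dropped with a warning instead of breaking the load. A hedged sketch of the config that gets built once the version check passes (rank and module names are illustrative, not taken from a real checkpoint):

    from peft import LoraConfig  # peft >= 0.14.0 is required for exclude_modules

    lora_config = LoraConfig(
        r=4,
        lora_alpha=4,
        target_modules=["blocks.0.attn.to_q", "blocks.0.attn.to_k"],
        exclude_modules=["vace_blocks.0.proj"],  # never wrapped in LoRA layers
    )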
From 04d6ddbe6a1e512bce4f6c9de1d1bfed25b8b697 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 27 Jun 2025 17:59:04 +0530
Subject: [PATCH 3/5] tests and version guard.

---
 src/diffusers/utils/peft_utils.py |  3 ++
 tests/lora/utils.py               | 59 +++++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+)

diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 9d04d58f5d7c..6551602b3f8f 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -391,6 +391,8 @@ def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
     `model_state_dict` with `peft_state_dict`, adding a module from `model_state_dict` to the exclusion set if it
     doesn't exist in `peft_state_dict`.
     """
+    if model_state_dict is None:
+        return
     all_modules = set()
     string_to_replace = f"{adapter_name}." if adapter_name else ""
 
@@ -402,6 +404,7 @@ def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
         all_modules.add(module_name)
 
     target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
+    print(f"{target_modules_set=}")
     exclude_modules = list(all_modules - target_modules_set)
 
     return exclude_modules
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 93dc4a2c37e3..059b925f15ad 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import os
 import re
@@ -290,6 +291,20 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
 
         return modules_to_save
 
+    def _get_exclude_modules(self, pipe):
+        from diffusers.utils.peft_utils import _derive_exclude_modules
+
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
+        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
+        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
+        pipe.unload_lora_weights()
+        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
+        exclude_modules = _derive_exclude_modules(
+            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
+        )
+        return exclude_modules
+
     def check_if_adapters_added_correctly(
         self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
     ):
@@ -2308,6 +2323,50 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
             np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
         )
 
+    @require_peft_version_greater("0.13.2")
+    def test_lora_exclude_modules(self):
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # for now, only supported for the denoiser
+        pipe_cp = copy.deepcopy(pipe)
+        pipe_cp, _ = self.check_if_adapters_added_correctly(
+            pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
+        pipe_cp.to("cpu")
+        del pipe_cp
+
+        denoiser_lora_config.exclude_modules = denoiser_exclude_modules
+        pipe, _ = self.check_if_adapters_added_correctly(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(tmpdir)
+
+        output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertTrue(
+            not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+            "LoRA should change outputs.",
+        )
+        self.assertTrue(
+            np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+            "LoRA outputs should match.",
+        )
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
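The new `_get_exclude_modules` test helper mirrors the load-time behavior: it collects the denoiser's plain state dict (after `unload_lora_weights()`) and the denoiser's LoRA state dict, then hands both to `_derive_exclude_modules`. Roughly, with toy state dicts standing in for a real denoiser (keys are hypothetical):

    from diffusers.utils.peft_utils import _derive_exclude_modules

    denoiser_state_dict = {  # base weights only, i.e. after unload_lora_weights()
        "blocks.0.attn.to_q.weight": None,
        "proj_out.weight": None,
    }
    denoiser_lora_state_dict = {  # what save_lora_weights() serializes for the denoiser
        "blocks.0.attn.to_q.lora_A.weight": None,
        "blocks.0.attn.to_q.lora_B.weight": None,
    }
    exclude_modules = _derive_exclude_modules(
        denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
    )
    print(exclude_modules)  # ['proj_out']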
--- src/diffusers/utils/peft_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 6551602b3f8f..dbb0b5a8d684 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -404,7 +404,6 @@ def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
         all_modules.add(module_name)
 
     target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
-    print(f"{target_modules_set=}")
     exclude_modules = list(all_modules - target_modules_set)
 
     return exclude_modules

From 651b807d31a96bf9c5ba61f0f9928147f6531531 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 27 Jun 2025 18:16:36 +0530
Subject: [PATCH 5/5] describe the test

---
 tests/lora/utils.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 059b925f15ad..4a6d0f1e2c19 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -2325,6 +2325,14 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
 
     @require_peft_version_greater("0.13.2")
     def test_lora_exclude_modules(self):
+        """
+        Checks that `exclude_modules` works as expected. The test first creates a pipeline with LoRA
+        adapters added to it, then derives the set of modules to exclude by comparing the denoiser
+        state dict with the denoiser LoRA state dict.
+
+        It then creates a new LoRA config that carries the derived `exclude_modules` and runs the
+        usual inference and serialization checks with it.
+        """
         scheduler_cls = self.scheduler_classes[0]
         components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
         pipe = self.pipeline_class(**components).to(torch_device)
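Taken together, the series lets a LoRA whose state dict only covers a subset of a model's modules load without manual configuration, since the uncovered modules are excluded from injection automatically. A rough end-user sketch of the motivating scenario from the code comments, loading a FusionX-style Wan T2V LoRA into Wan VACE (the checkpoint and LoRA ids are placeholders, not verified references):

    import torch
    from diffusers import WanVACEPipeline

    pipe = WanVACEPipeline.from_pretrained("<wan-vace-checkpoint>", torch_dtype=torch.bfloat16)
    # The LoRA carries no weights for VACE-only modules (e.g. `vace_blocks.*`);
    # deriving `exclude_modules` keeps them out of LoRA injection entirely.
    pipe.load_lora_weights("<fusionx-style-lora-repo>")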