diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 211f9da61947..4ade3374d80e 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -244,13 +244,20 @@ def load_lora_adapter( k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys } - # create LoraConfig - lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank) - # adapter_name if adapter_name is None: adapter_name = get_adapter_name(self) + # create LoraConfig + lora_config = _create_lora_config( + state_dict, + network_alphas, + metadata, + rank, + model_state_dict=self.state_dict(), + adapter_name=adapter_name, + ) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ def get_peft_kwargs( -def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): +def get_peft_kwargs( + rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None +): @@ def get_peft_kwargs( + exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name) + if exclude_modules: + if not is_peft_version(">=", "0.14.0"): + msg = """ +It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft` +version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U +peft`. For most cases, this can be completely ignored. 
+But if it seems unexpected, please file an issue - +https://github.com/huggingface/diffusers/issues/new + """ + logger.debug(msg) + else: + lora_config_kwargs.update({"exclude_modules": exclude_modules}) + return lora_config_kwargs @@ -294,11 +310,7 @@ def check_peft_version(min_version: str) -> None: def _create_lora_config( - state_dict, - network_alphas, - metadata, - rank_pattern_dict, - is_unet: bool = True, + state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None ): from peft import LoraConfig @@ -306,7 +318,12 @@ def _create_lora_config( lora_config_kwargs = metadata else: lora_config_kwargs = get_peft_kwargs( - rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet + rank_pattern_dict, + network_alpha_dict=network_alphas, + peft_state_dict=state_dict, + is_unet=is_unet, + model_state_dict=model_state_dict, + adapter_name=adapter_name, ) _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs) @@ -371,3 +388,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): if warn_msg: logger.warning(warn_msg) + + +def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None): + """ + Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the + `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it + doesn't exist in `peft_state_dict`. + """ + if model_state_dict is None: + return + all_modules = set() + string_to_replace = f"{adapter_name}." if adapter_name else "" + + for name in model_state_dict.keys(): + if string_to_replace: + name = name.replace(string_to_replace, "") + if "." in name: 
+ module_name = name.rsplit(".", 1)[0] + all_modules.add(module_name) + + target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()} + exclude_modules = list(all_modules - target_modules_set) + + return exclude_modules diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 95ec44b2bf41..fe26a56e77cf 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -24,7 +24,11 @@ WanPipeline, WanTransformer3DModel, ) -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps +from diffusers.utils.testing_utils import ( + floats_tensor, + require_peft_backend, + skip_mps, +) sys.path.append(".") diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 8180f92245a0..91ca188137e7 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import inspect import os import re @@ -291,6 +292,20 @@ def _get_modules_to_save(self, pipe, has_denoiser=False): return modules_to_save + def _get_exclude_modules(self, pipe): + from diffusers.utils.peft_utils import _derive_exclude_modules + + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + denoiser = "unet" if self.unet_kwargs is not None else "transformer" + modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser} + denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"] + pipe.unload_lora_weights() + denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict() + exclude_modules = _derive_exclude_modules( + denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default" + ) + return exclude_modules + def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"): if text_lora_config is not None: if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -2326,6 +2341,58 @@ def test_lora_unload_add_adapter(self): ) _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + @require_peft_version_greater("0.13.2") + def test_lora_exclude_modules(self): + """ + Test to check if `exclude_modules` works or not. It works in the following way: + we first create a pipeline and insert LoRA config into it. We then derive a `set` + of modules to exclude by investigating its denoiser state dict and denoiser LoRA + state dict. + + We then create a new LoRA config to include the `exclude_modules` and perform tests. 
+ """ + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components).to(torch_device) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(output_no_lora.shape == self.output_shape) + + # only supported for `denoiser` now + pipe_cp = copy.deepcopy(pipe) + pipe_cp, _ = self.add_adapters_to_pipeline( + pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + denoiser_exclude_modules = self._get_exclude_modules(pipe_cp) + pipe_cp.to("cpu") + del pipe_cp + + denoiser_lora_config.exclude_modules = denoiser_exclude_modules + pipe, _ = self.add_adapters_to_pipeline( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdir: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdir) + + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue( + not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), + "LoRA should change outputs.", + ) + self.assertTrue( + np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), + "Lora outputs should match.", + ) + def test_inference_load_delete_load_adapters(self): "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." for scheduler_cls in self.scheduler_classes: