-
Notifications
You must be signed in to change notification settings - Fork 6.1k
[lora]feat: use exclude modules to loraconfig. #11806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
802bd1a
5ecf05a
388a539
76356ea
04d6ddb
37fab37
eb8d78a
651b807
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None): | |
module.set_scale(adapter_name, 1.0) | ||
|
||
|
||
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): | ||
def get_peft_kwargs( | ||
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None | ||
): | ||
rank_pattern = {} | ||
alpha_pattern = {} | ||
r = lora_alpha = list(rank_dict.values())[0] | ||
|
@@ -180,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True | |
else: | ||
lora_alpha = set(network_alpha_dict.values()).pop() | ||
|
||
# layer names without the Diffusers specific | ||
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) | ||
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) | ||
# for now we know that the "bias" keys are only associated with `lora_B`. | ||
|
@@ -195,6 +196,16 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True | |
"use_dora": use_dora, | ||
"lora_bias": lora_bias, | ||
} | ||
|
||
# Example: try load FusionX LoRA into Wan VACE | ||
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name) | ||
if exclude_modules: | ||
if not is_peft_version(">=", "0.14.0"): | ||
msg = "It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft` version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U peft`." | ||
logger.warning(msg) | ||
else: | ||
lora_config_kwargs.update({"exclude_modules": exclude_modules}) | ||
|
||
return lora_config_kwargs | ||
|
||
|
||
|
@@ -294,19 +305,20 @@ def check_peft_version(min_version: str) -> None: | |
|
||
|
||
def _create_lora_config( | ||
state_dict, | ||
network_alphas, | ||
metadata, | ||
rank_pattern_dict, | ||
is_unet: bool = True, | ||
state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None | ||
): | ||
from peft import LoraConfig | ||
|
||
if metadata is not None: | ||
lora_config_kwargs = metadata | ||
else: | ||
lora_config_kwargs = get_peft_kwargs( | ||
rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet | ||
rank_pattern_dict, | ||
network_alpha_dict=network_alphas, | ||
peft_state_dict=state_dict, | ||
is_unet=is_unet, | ||
model_state_dict=model_state_dict, | ||
adapter_name=adapter_name, | ||
) | ||
|
||
_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs) | ||
|
@@ -371,3 +383,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): | |
|
||
if warn_msg: | ||
logger.warning(warn_msg) | ||
|
||
|
||
def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None): | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the | ||
`model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it | ||
doesn't exist in `peft_state_dict`. | ||
""" | ||
if model_state_dict is None: | ||
return | ||
all_modules = set() | ||
string_to_replace = f"{adapter_name}." if adapter_name else "" | ||
|
||
for name in model_state_dict.keys(): | ||
if string_to_replace: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe the if-statement is not needed here because string_to_replace will be an empty string, but no problem keeping as micro optimization |
||
name = name.replace(string_to_replace, "") | ||
if "." in name: | ||
module_name = name.rsplit(".", 1)[0] | ||
all_modules.add(module_name) | ||
|
||
target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()} | ||
exclude_modules = list(all_modules - target_modules_set) | ||
|
||
return exclude_modules |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import copy | ||
import inspect | ||
import os | ||
import re | ||
|
@@ -290,6 +291,20 @@ def _get_modules_to_save(self, pipe, has_denoiser=False): | |
|
||
return modules_to_save | ||
|
||
def _get_exclude_modules(self, pipe): | ||
from diffusers.utils.peft_utils import _derive_exclude_modules | ||
|
||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) | ||
denoiser = "unet" if self.unet_kwargs is not None else "transformer" | ||
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser} | ||
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"] | ||
pipe.unload_lora_weights() | ||
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict() | ||
exclude_modules = _derive_exclude_modules( | ||
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default" | ||
) | ||
return exclude_modules | ||
|
||
def check_if_adapters_added_correctly( | ||
self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default" | ||
): | ||
|
@@ -2308,6 +2323,58 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): | |
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." | ||
) | ||
|
||
@require_peft_version_greater("0.13.2") | ||
def test_lora_exclude_modules(self): | ||
""" | ||
Test to check if `exclude_modules` works or not. It works in the following way: | ||
we first create a pipeline and insert LoRA config into it. We then derive a `set` | ||
of modules to exclude by investigating its denoiser state dict and denoiser LoRA | ||
state dict. | ||
|
||
We then create a new LoRA config to include the `exclude_modules` and perform tests. | ||
""" | ||
scheduler_cls = self.scheduler_classes[0] | ||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) | ||
pipe = self.pipeline_class(**components).to(torch_device) | ||
_, _, inputs = self.get_dummy_inputs(with_generator=False) | ||
|
||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] | ||
self.assertTrue(output_no_lora.shape == self.output_shape) | ||
|
||
# only supported for `denoiser` now | ||
pipe_cp = copy.deepcopy(pipe) | ||
pipe_cp, _ = self.check_if_adapters_added_correctly( | ||
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config | ||
) | ||
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm afraid I don't fully understand how this test works but that shouldn't be a blocker. If this tests the exact same condition we're facing in the FusionX lora, then it should be good :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, not the cleanest test this one. I added a description in the test, see if that helps? Currently, we don't have:
I will think of a way to include a test for that, too. |
||
pipe_cp.to("cpu") | ||
del pipe_cp | ||
|
||
denoiser_lora_config.exclude_modules = denoiser_exclude_modules | ||
pipe, _ = self.check_if_adapters_added_correctly( | ||
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config | ||
) | ||
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) | ||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save) | ||
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) | ||
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) | ||
pipe.unload_lora_weights() | ||
pipe.load_lora_weights(tmpdir) | ||
|
||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] | ||
|
||
self.assertTrue( | ||
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), | ||
"LoRA should change outputs.", | ||
) | ||
self.assertTrue( | ||
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), | ||
"Lora outputs should match.", | ||
) | ||
|
||
def test_inference_load_delete_load_adapters(self): | ||
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." | ||
for scheduler_cls in self.scheduler_classes: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wondering: The
exclude_modules
are derived from all modules - modules ofpeft_state_dict
, right? So in practice,exlcude_modules
would never be empty. Thus, users with old PEFT versions will always see this warning, even though in the vast majority of cases, there is nothing to worry about. I wonderexclude_modules
should just be ignored without any warning in that case.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or a
debug()
info with this message with an amendment saying "For majority use cases, this should be okay. But if you notice anything unexpected, please file an issue."There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense.