diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index e968ef9628c2..6f8fb80fc298 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -198,46 +198,13 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_ unet_state_dict[diffusers_name] = state_dict.pop(key) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif lora_name.startswith("lora_te_"): - diffusers_name = key.replace("lora_te_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("text.model", "text_model") - diffusers_name = diffusers_name.replace("self.attn", "self_attn") - diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - if "self_attn" in diffusers_name: - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "mlp" in diffusers_name: - # Be aware that this is the new diffusers convention and the rest of the code might - # not utilize it yet. - diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - - # (sayakpaul): Duplicate code. Needs to be cleaned. - elif lora_name.startswith("lora_te1_"): - diffusers_name = key.replace("lora_te1_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("text.model", "text_model") - diffusers_name = diffusers_name.replace("self.attn", "self_attn") - diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - if "self_attn" in diffusers_name: - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "mlp" in diffusers_name: - # Be aware that this is the new diffusers convention and the rest of the code might - # not utilize it yet. - diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): + if lora_name.startswith(("lora_te_", "lora_te1_")): + key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_" + else: + key_to_replace = "lora_te2_" - # (sayakpaul): Duplicate code. Needs to be cleaned. - elif lora_name.startswith("lora_te2_"): - diffusers_name = key.replace("lora_te2_", "").replace("_", ".") + diffusers_name = key.replace(key_to_replace, "").replace("_", ".") diffusers_name = diffusers_name.replace("text.model", "text_model") diffusers_name = diffusers_name.replace("self.attn", "self_attn") diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") @@ -245,14 +212,22 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") if "self_attn" in diffusers_name: - te2_state_dict[diffusers_name] = state_dict.pop(key) - te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + if lora_name.startswith(("lora_te_", "lora_te1_")): + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + else: + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) elif "mlp" in diffusers_name: # Be aware that this is the new diffusers convention and the rest of the code might # not utilize it yet. diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") - te2_state_dict[diffusers_name] = state_dict.pop(key) - te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + if lora_name.startswith(("lora_te_", "lora_te1_")): + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + else: + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) # Rename the alphas so that they can be mapped appropriately. if lora_name_alpha in state_dict: