
Commit

fix lora for ci
ccssu committed Jul 8, 2024
1 parent ce16117 commit 563f8cd
Showing 3 changed files with 15 additions and 6 deletions.
@@ -16,7 +16,8 @@ python3 scripts/text_to_image.py \
 python3 scripts/text_to_image.py \
 --comfy-port $COMFY_PORT \
 -w $WORKFLOW_DIR/lora_speedup.json $WORKFLOW_DIR/lora_multiple_speedup.json \
---baseline-dir $STANDARD_OUTPUT/test_lora_speedup
+--baseline-dir $STANDARD_OUTPUT/test_lora_speedup \
+--ssim-threshold 0.8

 # # Baseline
 # python3 scripts/text_to_image.py \
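The change to the CI script just adds an --ssim-threshold 0.8 flag, so the LoRA speedup workflows are compared to the stored baseline images by structural similarity instead of requiring an exact match. A minimal sketch of such a gate, assuming uint8 RGB arrays and scikit-image; the actual check is presumably implemented inside scripts/text_to_image.py, which is not part of this diff:

import numpy as np
from skimage.metrics import structural_similarity

# Hypothetical helper (not the repo's implementation): gate a generated image
# against a baseline image using SSIM with the threshold from the CLI flag.
def passes_ssim_gate(baseline: np.ndarray, output: np.ndarray, threshold: float = 0.8) -> bool:
    score = structural_similarity(baseline, output, channel_axis=-1, data_range=255)
    return score >= threshold

A threshold of 0.8 leaves headroom for the small numeric differences that compiled LoRA runs can introduce while still catching clearly broken outputs.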
10 changes: 5 additions & 5 deletions onediff_comfy_nodes/modules/booster_cache.py
@@ -17,11 +17,11 @@ def _(new_model: ModelPatcher, cached_model):
     assert type(new_model.model) == type(
         cached_model
     ), f"Model type mismatch: expected {type(cached_model)}, got {type(new_model.model)}"
-    for k, v in new_model.model.state_dict().items():
-        cached_v: torch.Tensor = get_sub_module(cached_model, k)
-        assert v.dtype == cached_v.dtype
-        cached_v.copy_(v)
-    new_model.model = cached_model
+    cached_model.diffusion_model.load_state_dict(
+        new_model.model.diffusion_model.state_dict(), strict=True
+    )
+    new_model.model.diffusion_model = cached_model.diffusion_model
+    new_model.weight_inplace_update = True
     return new_model

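Instead of walking new_model.model.state_dict() and copying every tensor into the cached model, the cache update now loads only the diffusion_model weights into the cached diffusion_model (in place, strict=True), attaches that cached module to the incoming ModelPatcher, and sets weight_inplace_update so later weight patches such as LoRA are also applied in place. A minimal sketch, using a plain nn.Linear as a stand-in for the cached diffusion model, of why the in-place load keeps existing tensor references (for example, ones captured by a compiled graph) valid:

import torch

# Sketch (not repo code): load_state_dict copies values into the existing
# parameter tensors, so any external reference to them keeps pointing at
# live, updated data instead of a stale, replaced tensor.
cached = torch.nn.Linear(4, 4)
fresh = torch.nn.Linear(4, 4)

weight_ref = cached.weight                      # stand-in for a captured reference
cached.load_state_dict(fresh.state_dict(), strict=True)

assert weight_ref is cached.weight              # same tensor object survives
assert torch.equal(weight_ref, fresh.weight)    # ...with the new values copied in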
8 changes: 8 additions & 0 deletions onediff_comfy_nodes/modules/oneflow/hijack_utils.py
@@ -1,21 +1,29 @@
"""hijack ComfyUI/comfy/utils.py"""
import torch
from comfy.utils import copy_to_param
from onediff.infer_compiler.backends.oneflow.param_utils import (
update_graph_related_tensor,
)
from ..sd_hijack_utils import Hijacker


@torch.no_grad()
def copy_to_param_of(org_fn, obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)

prev = getattr(obj, attrs[-1])

if prev.data.dtype == torch.int8 and prev.data.dtype != value.dtype:
return

prev.data.copy_(value)

if isinstance(obj, torch.nn.Conv2d):
update_graph_related_tensor(obj)


def cond_func(orig_func, *args, **kwargs):
return True
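The additions here appear to be the param_utils import and the early-return guard: when the existing parameter is int8 (i.e. quantized) and the incoming value has a different dtype, the hijacked copy_to_param skips the in-place copy, since copying a float value (e.g. a LoRA-patched weight) into an int8 buffer would silently truncate it. A small standalone sketch of the dtype behaviour the guard avoids:

import torch

# Sketch of the guard above: copying a float tensor into an int8 parameter
# drops the fractional parts, corrupting quantized weights, so the update is
# skipped when the dtypes disagree.
prev = torch.zeros(3, dtype=torch.int8)
value = torch.tensor([0.9, 1.4, -0.3])

if prev.dtype == torch.int8 and prev.dtype != value.dtype:
    pass  # skip the update, keep the quantized weights intact
else:
    prev.copy_(value)  # would store [0, 1, 0]: fractional parts are dropped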
