diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/txt2img.json b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/txt2img.json new file mode 100644 index 000000000..61f402d28 --- /dev/null +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/txt2img.json @@ -0,0 +1,107 @@ +{ + "3": { + "inputs": { + "seed": 87631619688518, + "steps": 20, + "cfg": 8, + "sampler_name": "euler_ancestral", + "scheduler": "normal", + "denoise": 1, + "model": [ + "31", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "(blue colour lighting:1.3),photorealistic,masterpiece:1.5, spot light, exquisite gentle eyes,().\\n,portrait,masterpiece,breast focus:1.2,(),multicolored hair:1.4,wavy hair,3D face,black hair,short hair:1.2, sidelocks,1girl:1.3,blue eyes,tareme:1.5,(cowboy shot:1.5),(light smile:1.3),(stand:1.3),\\nhead tilt:1.3,(Shoulderless sundress:1.2),\\n(flat chest:1.4),cute face,(A balanced body,Model Body Type),\\n(Dark Background:1.3)、\\nslender Body:1.3,shiny hair, shiny skin,niji", + "clip": [ + "31", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": " Multiple people,bad body,long body,(fat:1.2),long neck,deformed,mutated,malformed limbs,missing limb,acnes,skin spots,skin blemishes,poorly drawn face", + "clip": [ + "31", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "31", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "31": { + "inputs": { + "ckpt_name": "Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + } +} \ No newline at end of file diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/txt2img.json b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/txt2img.json new file mode 100644 index 000000000..bd7ce14a4 --- /dev/null +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/txt2img.json @@ -0,0 +1,108 @@ +{ + "3": { + "inputs": { + "seed": 87631619688518, + "steps": 20, + "cfg": 8, + "sampler_name": "euler_ancestral", + "scheduler": "normal", + "denoise": 1, + "model": [ + "31", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "(blue colour lighting:1.3),photorealistic,masterpiece:1.5, spot light, exquisite gentle eyes,().\\n,portrait,masterpiece,breast focus:1.2,(),multicolored hair:1.4,wavy hair,3D face,black hair,short hair:1.2, sidelocks,1girl:1.3,blue eyes,tareme:1.5,(cowboy shot:1.5),(light smile:1.3),(stand:1.3),\\nhead tilt:1.3,(Shoulderless sundress:1.2),\\n(flat chest:1.4),cute face,(A balanced body,Model Body Type),\\n(Dark Background:1.3)、\\nslender Body:1.3,shiny hair, shiny skin,niji", + "clip": [ + "31", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": " Multiple people,bad body,long body,(fat:1.2),long neck,deformed,mutated,malformed limbs,missing limb,acnes,skin spots,skin blemishes,poorly drawn face", + "clip": [ + "31", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "31", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "31": { + "inputs": { + "ckpt_name": "Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors", + "vae_speedup": "disable" + }, + "class_type": "OneDiffCheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint - OneDiff" + } + } +} \ No newline at end of file diff --git a/onediff_comfy_nodes/benchmarks/scripts/run_oneflow_case_ci.sh b/onediff_comfy_nodes/benchmarks/scripts/run_oneflow_case_ci.sh index 29971e623..43ffe5e63 100644 --- a/onediff_comfy_nodes/benchmarks/scripts/run_oneflow_case_ci.sh +++ b/onediff_comfy_nodes/benchmarks/scripts/run_oneflow_case_ci.sh @@ -16,7 +16,8 @@ python3 scripts/text_to_image.py \ python3 scripts/text_to_image.py \ --comfy-port $COMFY_PORT \ -w $WORKFLOW_DIR/lora_speedup.json $WORKFLOW_DIR/lora_multiple_speedup.json \ - --baseline-dir $STANDARD_OUTPUT/test_lora_speedup + --baseline-dir $STANDARD_OUTPUT/test_lora_speedup \ + --ssim-threshold 0.6 # # Baseline # python3 scripts/text_to_image.py \ @@ -28,3 +29,9 @@ python3 scripts/text_to_image.py \ -w $WORKFLOW_DIR/ComfyUI_IPAdapter_plus/ipadapter_advanced.json \ --baseline-dir $STANDARD_OUTPUT/test_ipa # --output-images \ + +python3 scripts/text_to_image.py \ + --comfy-port $COMFY_PORT \ + -w $WORKFLOW_DIR/txt2img.json \ + --ssim-threshold 0.6 \ + --baseline-dir $STANDARD_OUTPUT/txt2img/imgs # --output-images diff --git a/onediff_comfy_nodes/benchmarks/scripts/text_to_image.py b/onediff_comfy_nodes/benchmarks/scripts/text_to_image.py index 8da40ec3f..3ddf1e1b9 100644 --- a/onediff_comfy_nodes/benchmarks/scripts/text_to_image.py +++ b/onediff_comfy_nodes/benchmarks/scripts/text_to_image.py @@ -93,7 +93,9 @@ def process_image(self, image_data: bytes, index: int) -> None: baseline_image = Image.open(baseline_image_path) ssim_value = calculate_ssim(pil_image, baseline_image) self.logger.info(f"SSIM value with baseline: {ssim_value}") - assert ssim_value > self.ssim_threshold + assert ( + ssim_value > self.ssim_threshold + ), f"SSIM value {ssim_value} is not greater than the threshold {self.ssim_threshold}" def run_workflow( diff --git a/onediff_comfy_nodes/benchmarks/src/input_registration.py b/onediff_comfy_nodes/benchmarks/src/input_registration.py index aa82b79d5..6d7e7b01d 100644 --- a/onediff_comfy_nodes/benchmarks/src/input_registration.py +++ b/onediff_comfy_nodes/benchmarks/src/input_registration.py @@ -7,6 +7,19 @@ WORKFLOW_DIR = "resources/workflows" FACE_IMAGE_DIR = "/share_nfs/hf_models/comfyui_resources/input/faces" POSE_IMAGE_DIR = "/share_nfs/hf_models/comfyui_resources/input/poses" +SDXL_MODELS = [ + "Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors", + "Pony_Realism.safetensors", + "sdxl/dreamshaperXL_v21TurboDPMSDE.safetensors", +] +SD1_5_MODELS = [ + "sd15/020.realisticVisionV51_v51VAE.safetensors", + "sd15/majicmixRealistic_v7.safetensors", + "sd15/v1-5-pruned-emaonly.ckpt", + "sd15/helloyoung25d_V10f.safetensors", + "sd15/RealCartoonSpecialPruned.safetensors", +] + class InputParams(NamedTuple): graph: ComfyGraph @@ -46,6 +59,23 @@ def _(workflow_path, *args, **kwargs): yield InputParams(graph=graph) +@register_generator( + [f"{WORKFLOW_DIR}/baseline/txt2img.json", f"{WORKFLOW_DIR}/oneflow/txt2img.json"] +) +def _(workflow_path, *args, **kwargs): + with open(workflow_path, "r") as fp: + workflow = json.load(fp) + graph = ComfyGraph(graph=workflow, sampler_nodes=["3"]) + for sdxl_model in SDXL_MODELS: + graph.set_image_size(height=1024, width=1024) + graph.graph["31"]["inputs"]["ckpt_name"] = sdxl_model + yield InputParams(graph) + for sd1_5_model in SD1_5_MODELS: + graph.set_image_size(height=768, width=512) + graph.graph["31"]["inputs"]["ckpt_name"] = sd1_5_model + yield InputParams(graph) + + SD3_WORKFLOWS = [ f"{WORKFLOW_DIR}/baseline/sd3_baseline.json", f"{WORKFLOW_DIR}/nexfort/sd3_unet_speedup.json", diff --git a/onediff_comfy_nodes/modules/booster_cache.py b/onediff_comfy_nodes/modules/booster_cache.py index 6d2f307dc..b086d9bf6 100644 --- a/onediff_comfy_nodes/modules/booster_cache.py +++ b/onediff_comfy_nodes/modules/booster_cache.py @@ -1,12 +1,9 @@ import torch -import traceback from collections import OrderedDict from comfy.model_patcher import ModelPatcher from functools import singledispatch from comfy.sd import VAE from onediff.torch_utils.module_operations import get_sub_module -from onediff.utils.import_utils import is_oneflow_available -from .._config import is_disable_oneflow_backend @singledispatch @@ -16,14 +13,16 @@ def switch_to_cached_model(new_model, cached_model): @switch_to_cached_model.register def _(new_model: ModelPatcher, cached_model): - assert type(new_model.model) == type( - cached_model - ), f"Model type mismatch: expected {type(cached_model)}, got {type(new_model.model)}" - for k, v in new_model.model.state_dict().items(): - cached_v: torch.Tensor = get_sub_module(cached_model, k) - assert v.dtype == cached_v.dtype - cached_v.copy_(v) - new_model.model = cached_model + if type(new_model.model) != type(cached_model): + raise TypeError( + f"Model type mismatch: expected {type(cached_model)}, got {type(new_model.model)}" + ) + + cached_model.diffusion_model.load_state_dict( + new_model.model.diffusion_model.state_dict(), strict=True + ) + new_model.model.diffusion_model = cached_model.diffusion_model + new_model.weight_inplace_update = True return new_model @@ -46,12 +45,6 @@ def get_cached_model(model): @get_cached_model.register def _(model: ModelPatcher): - if is_oneflow_available() and not is_disable_oneflow_backend(): - from .oneflow.utils.booster_utils import is_using_oneflow_backend - - if is_using_oneflow_backend(model): - return None - return model.model @@ -83,8 +76,7 @@ def get_cached_model(self, key, model): try: return switch_to_cached_model(model, cached_model) except Exception as e: - print("An exception occurred when switching to cached model:") - print(traceback.format_exc()) + print(f"An exception occurred when switching to cached model:") del self._cache[key] torch.cuda.empty_cache() diff --git a/onediff_comfy_nodes/modules/oneflow/booster_basic.py b/onediff_comfy_nodes/modules/oneflow/booster_basic.py index 9db31327b..78a3bc798 100644 --- a/onediff_comfy_nodes/modules/oneflow/booster_basic.py +++ b/onediff_comfy_nodes/modules/oneflow/booster_basic.py @@ -42,6 +42,7 @@ def _(self, model: ModelPatcher, ckpt_name: Optional[str] = None, **kwargs): return model compiled_model = oneflow_compile(torch_model) + model.model.diffusion_model = compiled_model graph_file = generate_graph_path(f"{type(model).__name__}", model=model.model) diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_utils.py b/onediff_comfy_nodes/modules/oneflow/hijack_utils.py index 4a4f25c5a..b2e8034e5 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_utils.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_utils.py @@ -1,14 +1,19 @@ """hijack ComfyUI/comfy/utils.py""" import torch from comfy.utils import copy_to_param +from onediff.infer_compiler.backends.oneflow.param_utils import ( + update_graph_related_tensor, +) from ..sd_hijack_utils import Hijacker +@torch.no_grad() def copy_to_param_of(org_fn, obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) if prev.data.dtype == torch.int8 and prev.data.dtype != value.dtype: @@ -16,6 +21,9 @@ def copy_to_param_of(org_fn, obj, attr, value): prev.data.copy_(value) + if isinstance(obj, torch.nn.Conv2d): + update_graph_related_tensor(obj) + def cond_func(orig_func, *args, **kwargs): return True diff --git a/src/onediff/infer_compiler/backends/oneflow/deployable_module.py b/src/onediff/infer_compiler/backends/oneflow/deployable_module.py index c09883261..1fe7b8099 100644 --- a/src/onediff/infer_compiler/backends/oneflow/deployable_module.py +++ b/src/onediff/infer_compiler/backends/oneflow/deployable_module.py @@ -15,7 +15,12 @@ from .dual_module import DualModule, get_mixed_dual_module from .oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled from .args_tree_util import input_output_processor -from .param_utils import parse_device, check_device, generate_constant_folding_info +from .param_utils import ( + parse_device, + check_device, + generate_constant_folding_info, + update_graph_with_constant_folding_info, +) from .graph_management_utils import graph_file_management from .online_quantization_utils import quantize_and_deploy_wrapper from ..env_var import OneflowCompileOptions @@ -195,6 +200,7 @@ def load_graph(self, file_path, device=None, run_warmup=True, *, state_dict=None file_path, device, run_warmup, state_dict=state_dict ) generate_constant_folding_info(self) + update_graph_with_constant_folding_info(self) def save_graph(self, file_path, *, process_state_dict=lambda x: x): self.get_graph().save_graph(file_path, process_state_dict=process_state_dict) diff --git a/src/onediff/infer_compiler/backends/oneflow/param_utils.py b/src/onediff/infer_compiler/backends/oneflow/param_utils.py index c5f53440f..f972686e7 100644 --- a/src/onediff/infer_compiler/backends/oneflow/param_utils.py +++ b/src/onediff/infer_compiler/backends/oneflow/param_utils.py @@ -1,4 +1,5 @@ import re +import types import torch import oneflow as flow from typing import List, Dict, Any, Union @@ -108,9 +109,26 @@ def convert_var_name(s: str, prefix="variable_transpose_"): for k, v in zip(*graph._c_nn_graph.get_runtime_var_states()) if k.startswith("variable_transpose_") and v.ndim == 4 } + setattr(deployable_module, CONSTANT_FOLDING_INFO_ATTR, result) + set_constant_folded_conv_attr(deployable_module, result) + def make_custom_copy_(module): + def custom_copy_(self, src, non_blocking=False): + torch.Tensor.copy_(self, src, non_blocking) + # Update graph related tensors + update_graph_related_tensor(module) + + return custom_copy_ + + from onediff.torch_utils.module_operations import get_sub_module + + torch_model: torch.nn.Module = deployable_module._torch_module + for k in result.keys(): + module = get_sub_module(torch_model, removesuffix(k, ".weight")) + module.weight.copy_ = types.MethodType(make_custom_copy_(module), module.weight) + def update_graph_with_constant_folding_info( module: torch.nn.Module, info: Dict[str, flow.Tensor] = None diff --git a/tests/comfy-docker-compose.yml b/tests/comfy-docker-compose.yml index 9731618bd..99f4d7f4d 100644 --- a/tests/comfy-docker-compose.yml +++ b/tests/comfy-docker-compose.yml @@ -26,8 +26,6 @@ services: HF_HUB_OFFLINE: "1" ONEFLOW_MLIR_ENABLE_TIMING: "1" ONEFLOW_MLIR_PRINT_STATS: "1" - ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION: "0" - COMFYUI_ONEDIFF_SAVE_GRAPH_DIR: "/share_nfs/hf_models/comfyui_resources/input" CI: "1" SILICON_ONEDIFF_LICENSE_KEY: ${SILICON_ONEDIFF_LICENSE_KEY} volumes: