Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
strint committed Jul 9, 2024
2 parents f756010 + e9477f4 commit ee5fe70
Show file tree
Hide file tree
Showing 14 changed files with 370 additions and 45 deletions.
11 changes: 9 additions & 2 deletions benchmarks/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from diffusers.utils import load_image

from onediffx import compile_pipe, quantize_pipe # quantize_pipe currently only supports the nexfort backend.
from onediff.infer_compiler import oneflow_compile


def parse_args():
Expand Down Expand Up @@ -244,7 +245,13 @@ def main():
pass
elif args.compiler == "oneflow":
print("Oneflow backend is now active...")
pipe = compile_pipe(pipe)
# Note: The compile_pipe() based on the oneflow backend is incompatible with T5EncoderModel.
# pipe = compile_pipe(pipe)
if hasattr(pipe, "unet"):
pipe.unet = oneflow_compile(pipe.unet)
if hasattr(pipe, "transformer"):
pipe.transformer = oneflow_compile(pipe.transformer)
pipe.vae.decoder = oneflow_compile(pipe.vae.decoder)
elif args.compiler == "nexfort":
print("Nexfort backend is now active...")
if args.quantize:
Expand All @@ -267,7 +274,7 @@ def main():
options = json.loads(args.compiler_config)
else:
# config with string
options = '{"mode": "max-optimize:max-autotune:freezing", "memory_format": "channels_last"}'
options = '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}'
pipe = compile_pipe(
pipe, backend="nexfort", options=options, fuse_qkv_projections=True
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
{
"3": {
"inputs": {
"seed": 87631619688518,
"steps": 20,
"cfg": 8,
"sampler_name": "euler_ancestral",
"scheduler": "normal",
"denoise": 1,
"model": [
"31",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"5": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "(blue colour lighting:1.3),photorealistic,masterpiece:1.5, spot light, exquisite gentle eyes,().\\n,portrait,masterpiece,breast focus:1.2,(),multicolored hair:1.4,wavy hair,3D face,black hair,short hair:1.2, sidelocks,1girl:1.3,blue eyes,tareme:1.5,(cowboy shot:1.5),(light smile:1.3),(stand:1.3),\\nhead tilt:1.3,(Shoulderless sundress:1.2),\\n(flat chest:1.4),cute face,(A balanced body,Model Body Type),\\n(Dark Background:1.3)、\\nslender Body:1.3,shiny hair, shiny skin,niji",
"clip": [
"31",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": " Multiple people,bad body,long body,(fat:1.2),long neck,deformed,mutated,malformed limbs,missing limb,acnes,skin spots,skin blemishes,poorly drawn face",
"clip": [
"31",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"31",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"31": {
"inputs": {
"ckpt_name": "Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
{
"3": {
"inputs": {
"seed": 87631619688518,
"steps": 20,
"cfg": 8,
"sampler_name": "euler_ancestral",
"scheduler": "normal",
"denoise": 1,
"model": [
"31",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"5": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "(blue colour lighting:1.3),photorealistic,masterpiece:1.5, spot light, exquisite gentle eyes,().\\n,portrait,masterpiece,breast focus:1.2,(),multicolored hair:1.4,wavy hair,3D face,black hair,short hair:1.2, sidelocks,1girl:1.3,blue eyes,tareme:1.5,(cowboy shot:1.5),(light smile:1.3),(stand:1.3),\\nhead tilt:1.3,(Shoulderless sundress:1.2),\\n(flat chest:1.4),cute face,(A balanced body,Model Body Type),\\n(Dark Background:1.3)、\\nslender Body:1.3,shiny hair, shiny skin,niji",
"clip": [
"31",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": " Multiple people,bad body,long body,(fat:1.2),long neck,deformed,mutated,malformed limbs,missing limb,acnes,skin spots,skin blemishes,poorly drawn face",
"clip": [
"31",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"31",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"31": {
"inputs": {
"ckpt_name": "Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
"vae_speedup": "disable"
},
"class_type": "OneDiffCheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint - OneDiff"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ python3 scripts/text_to_image.py \
python3 scripts/text_to_image.py \
--comfy-port $COMFY_PORT \
-w $WORKFLOW_DIR/lora_speedup.json $WORKFLOW_DIR/lora_multiple_speedup.json \
--baseline-dir $STANDARD_OUTPUT/test_lora_speedup
--baseline-dir $STANDARD_OUTPUT/test_lora_speedup \
--ssim-threshold 0.6

# # Baseline
# python3 scripts/text_to_image.py \
Expand All @@ -28,3 +29,9 @@ python3 scripts/text_to_image.py \
-w $WORKFLOW_DIR/ComfyUI_IPAdapter_plus/ipadapter_advanced.json \
--baseline-dir $STANDARD_OUTPUT/test_ipa
# --output-images \

python3 scripts/text_to_image.py \
--comfy-port $COMFY_PORT \
-w $WORKFLOW_DIR/txt2img.json \
--ssim-threshold 0.6 \
--baseline-dir $STANDARD_OUTPUT/txt2img/imgs # --output-images
4 changes: 3 additions & 1 deletion onediff_comfy_nodes/benchmarks/scripts/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def process_image(self, image_data: bytes, index: int) -> None:
baseline_image = Image.open(baseline_image_path)
ssim_value = calculate_ssim(pil_image, baseline_image)
self.logger.info(f"SSIM value with baseline: {ssim_value}")
assert ssim_value > self.ssim_threshold
assert (
ssim_value > self.ssim_threshold
), f"SSIM value {ssim_value} is not greater than the threshold {self.ssim_threshold}"


def run_workflow(
Expand Down
30 changes: 30 additions & 0 deletions onediff_comfy_nodes/benchmarks/src/input_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@
WORKFLOW_DIR = "resources/workflows"
FACE_IMAGE_DIR = "/share_nfs/hf_models/comfyui_resources/input/faces"
POSE_IMAGE_DIR = "/share_nfs/hf_models/comfyui_resources/input/poses"
SDXL_MODELS = [
"Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
"Pony_Realism.safetensors",
"sdxl/dreamshaperXL_v21TurboDPMSDE.safetensors",
]
SD1_5_MODELS = [
"sd15/020.realisticVisionV51_v51VAE.safetensors",
"sd15/majicmixRealistic_v7.safetensors",
"sd15/v1-5-pruned-emaonly.ckpt",
"sd15/helloyoung25d_V10f.safetensors",
"sd15/RealCartoonSpecialPruned.safetensors",
]


class InputParams(NamedTuple):
graph: ComfyGraph
Expand Down Expand Up @@ -46,6 +59,23 @@ def _(workflow_path, *args, **kwargs):
yield InputParams(graph=graph)


@register_generator(
[f"{WORKFLOW_DIR}/baseline/txt2img.json", f"{WORKFLOW_DIR}/oneflow/txt2img.json"]
)
def _(workflow_path, *args, **kwargs):
with open(workflow_path, "r") as fp:
workflow = json.load(fp)
graph = ComfyGraph(graph=workflow, sampler_nodes=["3"])
for sdxl_model in SDXL_MODELS:
graph.set_image_size(height=1024, width=1024)
graph.graph["31"]["inputs"]["ckpt_name"] = sdxl_model
yield InputParams(graph)
for sd1_5_model in SD1_5_MODELS:
graph.set_image_size(height=768, width=512)
graph.graph["31"]["inputs"]["ckpt_name"] = sd1_5_model
yield InputParams(graph)


SD3_WORKFLOWS = [
f"{WORKFLOW_DIR}/baseline/sd3_baseline.json",
f"{WORKFLOW_DIR}/nexfort/sd3_unet_speedup.json",
Expand Down
30 changes: 11 additions & 19 deletions onediff_comfy_nodes/modules/booster_cache.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import torch
import traceback
from collections import OrderedDict
from comfy.model_patcher import ModelPatcher
from functools import singledispatch
from comfy.sd import VAE
from onediff.torch_utils.module_operations import get_sub_module
from onediff.utils.import_utils import is_oneflow_available
from .._config import is_disable_oneflow_backend


@singledispatch
Expand All @@ -16,14 +13,16 @@ def switch_to_cached_model(new_model, cached_model):

@switch_to_cached_model.register
def _(new_model: ModelPatcher, cached_model):
assert type(new_model.model) == type(
cached_model
), f"Model type mismatch: expected {type(cached_model)}, got {type(new_model.model)}"
for k, v in new_model.model.state_dict().items():
cached_v: torch.Tensor = get_sub_module(cached_model, k)
assert v.dtype == cached_v.dtype
cached_v.copy_(v)
new_model.model = cached_model
if type(new_model.model) != type(cached_model):
raise TypeError(
f"Model type mismatch: expected {type(cached_model)}, got {type(new_model.model)}"
)

cached_model.diffusion_model.load_state_dict(
new_model.model.diffusion_model.state_dict(), strict=True
)
new_model.model.diffusion_model = cached_model.diffusion_model
new_model.weight_inplace_update = True
return new_model


Expand All @@ -46,12 +45,6 @@ def get_cached_model(model):

@get_cached_model.register
def _(model: ModelPatcher):
if is_oneflow_available() and not is_disable_oneflow_backend():
from .oneflow.utils.booster_utils import is_using_oneflow_backend

if is_using_oneflow_backend(model):
return None

return model.model


Expand Down Expand Up @@ -83,8 +76,7 @@ def get_cached_model(self, key, model):
try:
return switch_to_cached_model(model, cached_model)
except Exception as e:
print("An exception occurred when switching to cached model:")
print(traceback.format_exc())
print(f"An exception occurred when switching to cached model:")
del self._cache[key]
torch.cuda.empty_cache()

Expand Down
1 change: 1 addition & 0 deletions onediff_comfy_nodes/modules/oneflow/booster_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _(self, model: ModelPatcher, ckpt_name: Optional[str] = None, **kwargs):
return model

compiled_model = oneflow_compile(torch_model)

model.model.diffusion_model = compiled_model

graph_file = generate_graph_path(f"{type(model).__name__}", model=model.model)
Expand Down
Loading

0 comments on commit ee5fe70

Please sign in to comment.