diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml
index 9e02c49c7..abc6efa5c 100644
--- a/.github/workflows/examples.yml
+++ b/.github/workflows/examples.yml
@@ -231,6 +231,8 @@ jobs:
         run: docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_turbo.py --compile true --base /share_nfs/hf_models/sdxl-turbo
       - if: matrix.test-suite == 'diffusers_examples'
         run: docker exec -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 -m pytest -v onediff_diffusers_extensions/tests/test_lora.py
+      - if: matrix.test-suite == 'diffusers_examples'
+        run: docker exec -w /src/onediff/onediff_diffusers_extensions -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_reuse_pipe.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --new_base /share_nfs/hf_models/dataautogpt3-OpenDalleV1.1
       - name: Shutdown docker for ComfyUI Test
         if: matrix.test-suite == 'comfy'
diff --git a/onediff_diffusers_extensions/README.md b/onediff_diffusers_extensions/README.md
index ed6c88eb4..fd2d308d1 100644
--- a/onediff_diffusers_extensions/README.md
+++ b/onediff_diffusers_extensions/README.md
@@ -430,6 +430,32 @@ We tested the performance of `set_adapters`, still using the five LoRA models me
 - While traversing the submodules of the model, we observed that the `getattr` time overhead of OneDiff's `DeployableModule` is high. Because the parameters of DeployableModule share the same address as the PyTorch module it wraps, we choose to traverse `DeployableModule._torch_module`, greatly improving traversal efficiency.
 
+## Compiled graph re-use
+
+When switching models, if the new model has the same structure as the old one, you can re-use the previously compiled graph. The new model does not need to be compiled again, which significantly reduces the time it takes to switch models.
+
+Here is pseudo code; for detailed usage, please refer to [text_to_image_sdxl_reuse_pipe](./examples/text_to_image_sdxl_reuse_pipe.py):
+
+```python
+base = StableDiffusionPipeline(...)
+compiled_unet = oneflow_compile(base.unet)
+base.unet = compiled_unet
+# This step needs some time to compile the UNet
+base(prompt)
+
+new_base = StableDiffusionPipeline(...)
+# Re-use the compiled graph by loading the new state dict into the `_torch_module` member of the object returned by `oneflow_compile`
+compiled_unet._torch_module.load_state_dict(new_base.unet.state_dict())
+# After the new state dict is loaded into `compiled_unet._torch_module`, the weights of `compiled_unet` are updated as well
+new_base.unet = compiled_unet
+# This step doesn't need additional time to compile the UNet again because
+# new_base.unet is already compiled
+new_base(prompt)
+```
+
+> Note: Please make sure that your PyTorch version is **at least 2.1.0** and that the environment variable `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` is set to **0**. This feature is not supported for quantized models.
+
 
 ## Quantization
 
 **Note**: Quantization feature is only supported by **OneDiff Enterprise**.
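Graph re-use only works when the two models are structurally identical. As a sanity check before swapping state dicts, a helper along these lines could verify that parameter names and shapes line up (a minimal sketch; `same_structure` is a hypothetical helper, not part of the OneDiff API):

```python
import torch

def same_structure(a: torch.nn.Module, b: torch.nn.Module) -> bool:
    # True if both modules expose identical state_dict keys and shapes,
    # i.e. a graph compiled for `a` can be fed `b`'s weights.
    sa, sb = a.state_dict(), b.state_dict()
    return sa.keys() == sb.keys() and all(sa[k].shape == sb[k].shape for k in sa)
```

Calling `same_structure(base.unet, new_base.unet)` before the `load_state_dict` step would catch a mismatched architecture early instead of failing inside the compiled graph.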
diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py new file mode 100644 index 000000000..3fbbebd1d --- /dev/null +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -0,0 +1,190 @@ +import os +import argparse + +import torch + +from onediff.infer_compiler import oneflow_compile +from onediff.infer_compiler import oneflow_compiler_config +from onediff.schedulers import EulerDiscreteScheduler +from diffusers import StableDiffusionXLPipeline +# import diffusers +# diffusers.logging.set_verbosity_info() + +parser = argparse.ArgumentParser() +parser.add_argument( + "--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" +) +parser.add_argument( + "--new_base", type=str, default="dataautogpt3/OpenDalleV1.1", +) +parser.add_argument("--variant", type=str, default="fp16") +parser.add_argument( + "--prompt", + type=str, + default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", +) +parser.add_argument("--height", type=int, default=1024) +parser.add_argument("--width", type=int, default=1024) +parser.add_argument("--n_steps", type=int, default=30) +parser.add_argument("--guidance_scale", type=float, default=7.5) +parser.add_argument("--saved_image", type=str, required=False, default="sdxl-out.png") +parser.add_argument("--seed", type=int, default=1) +parser.add_argument( + "--compile_unet", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=True, +) +parser.add_argument( + "--compile_vae", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=True, +) +parser.add_argument( + "--run_multiple_resolutions", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=True, +) +args = parser.parse_args() + +# Normal SDXL pipeline init. 
+OUTPUT_TYPE = "pil"
+
+# SDXL base: StableDiffusionXLPipeline
+scheduler = EulerDiscreteScheduler.from_pretrained(args.base, subfolder="scheduler")
+base = StableDiffusionXLPipeline.from_pretrained(
+    args.base,
+    scheduler=scheduler,
+    torch_dtype=torch.float16,
+    variant=args.variant,
+    use_safetensors=True,
+)
+base.to("cuda")
+
+
+oneflow_compiler_config.mlir_enable_inference_optimization = False
+# Compile unet with oneflow
+if args.compile_unet:
+    print("Compiling unet with oneflow.")
+    compiled_unet = oneflow_compile(base.unet)
+    base.unet = compiled_unet
+
+# Compile vae with oneflow
+if args.compile_vae:
+    print("Compiling vae with oneflow.")
+    compiled_decoder = oneflow_compile(base.vae.decoder)
+    base.vae.decoder = compiled_decoder
+
+# Warmup run
+# Compilation happens during the first run
+print("Warmup with running graphs...")
+torch.manual_seed(args.seed)
+image = base(
+    prompt=args.prompt,
+    height=args.height,
+    width=args.width,
+    num_inference_steps=args.n_steps,
+    generator=torch.manual_seed(0),
+    output_type=OUTPUT_TYPE,
+    guidance_scale=args.guidance_scale,
+).images
+del base
+
+torch.cuda.empty_cache()
+
+print("Loading new base")
+if str(args.new_base).endswith(".safetensors"):
+    new_base = StableDiffusionXLPipeline.from_single_file(
+        args.new_base,
+        scheduler=scheduler,
+        torch_dtype=torch.float16,
+        variant=args.variant,
+        use_safetensors=True,
+    )
+else:
+    new_base = StableDiffusionXLPipeline.from_pretrained(
+        args.new_base,
+        scheduler=scheduler,
+        torch_dtype=torch.float16,
+        variant=args.variant,
+        use_safetensors=True,
+    )
+new_base.to("cuda")
+
+print("New base running with torch backend")
+image = new_base(
+    prompt=args.prompt,
+    height=args.height,
+    width=args.width,
+    num_inference_steps=args.n_steps,
+    generator=torch.manual_seed(0),
+    output_type=OUTPUT_TYPE,
+    guidance_scale=args.guidance_scale,
+).images
+image[0].save(f"new_base_without_graph_h{args.height}-w{args.width}-{args.saved_image}")
+image_eager = image[0]
+
+
+# Update the unet and vae
+# load_state_dict(state_dict, strict=True, assign=False): with assign=False, the loaded tensors are copied in place into the module’s current parameters and buffers.
+# Reference: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict
+print("Loading state_dict of new base into compiled graph")
+compiled_unet._torch_module.load_state_dict(new_base.unet.state_dict())
+compiled_decoder._torch_module.load_state_dict(new_base.vae.decoder.state_dict())
+
+new_base.unet = compiled_unet
+new_base.vae.decoder = compiled_decoder
+
+torch.cuda.empty_cache()
+
+# Normal SDXL run
+print("Re-use the compiled graph")
+image = new_base(
+    prompt=args.prompt,
+    height=args.height,
+    width=args.width,
+    num_inference_steps=args.n_steps,
+    generator=torch.manual_seed(0),
+    output_type=OUTPUT_TYPE,
+    guidance_scale=args.guidance_scale,
+).images
+image[0].save(f"new_base_reuse_graph_h{args.height}-w{args.width}-{args.saved_image}")
+image_graph = image[0]
+
+from skimage.metrics import structural_similarity
+import numpy as np
+
+ssim = structural_similarity(
+    np.array(image_eager), np.array(image_graph), channel_axis=-1, data_range=255
+)
+print(f"ssim between naive torch and re-used graph is {ssim}")
+
+
+# Should trigger no compilation for these new input shapes
+if args.run_multiple_resolutions:
+    print("Test run with multiple resolutions...")
+    sizes = [960, 720, 896, 768]
+    if "CI" in os.environ:
+        sizes = [360]
+    for h in sizes:
+        for w in sizes:
+            image = new_base(
+                prompt=args.prompt,
+                height=h,
+                width=w,
+                num_inference_steps=args.n_steps,
+                generator=torch.manual_seed(0),
+                output_type=OUTPUT_TYPE,
+            ).images
+
+
+# print("Test run with another uncommon resolution...")
+# if args.run_multiple_resolutions:
+#     h = 544
+#     w = 408
+#     image = base(
+#         prompt=args.prompt,
+#         height=h,
+#         width=w,
+#         num_inference_steps=args.n_steps,
+#         output_type=OUTPUT_TYPE,
+#     ).images
diff --git a/onediff_diffusers_extensions/setup.py b/onediff_diffusers_extensions/setup.py
index ba6d36b82..a14e275f0 100644
--- a/onediff_diffusers_extensions/setup.py
+++ b/onediff_diffusers_extensions/setup.py
@@ -26,6 +26,7 @@ def get_version():
         "accelerate",
         "torch",
         "onefx",
+        "omegaconf",
     ],
     classifiers=[
         "Development Status :: 5 - Production/Stable",
diff --git a/src/onediff/infer_compiler/with_oneflow_compile.py b/src/onediff/infer_compiler/with_oneflow_compile.py
index d94bcc3c3..b3e4d2906 100644
--- a/src/onediff/infer_compiler/with_oneflow_compile.py
+++ b/src/onediff/infer_compiler/with_oneflow_compile.py
@@ -205,6 +205,7 @@ def __init__(
             get_mixed_dual_module(torch_module.__class__)(torch_module, oneflow_module),
         )
         object.__setattr__(self, "_modules", torch_module._modules)
+        object.__setattr__(self, "_torch_module", torch_module)
         self._deployable_module_use_graph = use_graph
         self._deployable_module_enable_dynamic = dynamic
         self._deployable_module_options = options
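The one-line `_torch_module` addition above is what makes graph re-use possible: the attribute points at the wrapped PyTorch module, whose parameters share storage with the compiled graph. Below is a minimal PyTorch-only sketch (no OneDiff involved; the module names are illustrative) of why an in-place `load_state_dict` call propagates to every module aliasing the same parameters:

```python
import torch

src = torch.nn.Linear(4, 4)
alias = torch.nn.Linear(4, 4)
# Alias the parameters, mimicking how DeployableModule shares storage
# with the torch module it wraps.
alias.weight = src.weight
alias.bias = src.bias

new_state = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}
# assign=False (the default) copies the tensors in place, so the storage
# seen through `alias` is updated too -- no rebuild of `alias` required.
src.load_state_dict(new_state)
assert torch.equal(alias.weight, new_state["weight"])
print("aliased module sees the new weights")
```

With `assign=True` the copy would be replaced by rebinding, breaking the aliasing, which is why the example relies on the default in-place behavior.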