(Results on Stable Diffusion v1.5. Left: 50 PLMS steps. Right: 2.3x acceleration over the same 50 PLMS steps)
DeepCache: Accelerating Diffusion Models for Free
Xinyin Ma, Gongfan Fang, Xinchao Wang
Learning and Vision Lab, National University of Singapore
🥯[Arxiv]🎄[Project Page]
- 🚀 Training-free and almost lossless
- 🚀 Support Stable Diffusion, Stable Diffusion XL, Stable Video Diffusion, DDPM
- 🚀 Compatible with sampling algorithms like DDIM and PLMS
- December 21, 2023: Release the code for Stable Video Diffusion. The upper row shows the original videos, and the lower row shows the videos accelerated by DeepCache.
- December 20, 2023: Release the code for DDPM. See here for the experimental code and instructions.
- December 6, 2023: Release the code for Stable Diffusion XL. The results of stabilityai/stable-diffusion-xl-base-1.0 are shown in the figure below, with the same prompts as in the first figure.
We introduce DeepCache, a novel training-free and almost lossless paradigm that accelerates diffusion models from the perspective of model architecture. Exploiting the structure of the U-Net, we cache and reuse the high-level features across adjacent denoising steps while updating the low-level features at very low cost. DeepCache accelerates Stable Diffusion v1.5 by 2.3x with only a 0.05 decline in CLIP Score, and LDM-4-G (ImageNet) by 4.1x with a 0.22 decrease in FID.
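As a rough illustration of the caching idea described above, the toy loop below recomputes an expensive "deep" branch only every `cache_interval` steps and reuses its cached output in between, while a cheap "shallow" branch runs at every step. The names `run_with_feature_cache`, `shallow`, `deep`, and `merge` are hypothetical stand-ins chosen for this sketch; they are not part of the DeepCache code, and a real U-Net is far more involved.

```python
from typing import Callable, Optional, Tuple


def run_with_feature_cache(
    x: float,
    num_steps: int,
    cache_interval: int,
    shallow: Callable[[float], float],
    deep: Callable[[float], float],
    merge: Callable[[float, float], float],
) -> Tuple[float, int]:
    """Toy denoising loop (illustrative only, not the DeepCache implementation).

    Returns the final value and the number of times the deep branch ran.
    """
    cached_high: Optional[float] = None
    deep_calls = 0
    for step in range(num_steps):
        low = shallow(x)  # low-level features: cheap, recomputed every step
        if step % cache_interval == 0:
            cached_high = deep(low)  # high-level features: expensive, refreshed on "full" steps
            deep_calls += 1
        # On the other steps the cached high-level features are reused as-is.
        x = merge(low, cached_high)
    return x, deep_calls


if __name__ == "__main__":
    _, calls = run_with_feature_cache(
        1.0,
        num_steps=50,
        cache_interval=5,
        shallow=lambda v: 0.9 * v,
        deep=lambda v: 0.1 * v,
        merge=lambda low, high: low + high,
    )
    print(f"deep branch evaluated {calls} times out of 50 steps")
```

With 50 steps and an interval of 5, the deep branch runs on steps 0, 5, ..., 45, so most steps only pay for the shallow branch.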
pip install transformers diffusers
python stable_diffusion_xl.py --model stabilityai/stable-diffusion-xl-base-1.0
Output:
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 6.62it/s]
2023-12-06 01:44:28,578 - INFO - Running baseline...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00, 2.93it/s]
2023-12-06 01:44:46,095 - INFO - Baseline: 17.52 seconds
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 8.06it/s]
2023-12-06 01:45:02,865 - INFO - Running DeepCache...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:06<00:00, 8.01it/s]
2023-12-06 01:45:09,573 - INFO - DeepCache: 6.71 seconds
2023-12-06 01:45:10,678 - INFO - Saved to output.png. Done!
You can add --refine at the end of the command to activate the refiner model for SDXL.
python stable_diffusion.py --model runwayml/stable-diffusion-v1-5
Output:
2023-12-03 16:18:13,636 - INFO - Loaded safety_checker as StableDiffusionSafetyChecker from `safety_checker` subfolder of runwayml/stable-diffusion-v1-5.
2023-12-03 16:18:13,699 - INFO - Loaded vae as AutoencoderKL from `vae` subfolder of runwayml/stable-diffusion-v1-5.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 5.88it/s]
2023-12-03 16:18:22,837 - INFO - Running baseline...
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:03<00:00, 15.33it/s]
2023-12-03 16:18:26,174 - INFO - Baseline: 3.34 seconds
2023-12-03 16:18:26,174 - INFO - Running DeepCache...
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 34.06it/s]
2023-12-03 16:18:27,718 - INFO - DeepCache: 1.54 seconds
2023-12-03 16:18:27,935 - INFO - Saved to output.png. Done!
python stable_diffusion.py --model stabilityai/stable-diffusion-2-1
Output:
2023-12-03 16:21:17,858 - INFO - Loaded feature_extractor as CLIPImageProcessor from `feature_extractor` subfolder of stabilityai/stable-diffusion-2-1.
2023-12-03 16:21:17,864 - INFO - Loaded scheduler as DDIMScheduler from `scheduler` subfolder of stabilityai/stable-diffusion-2-1.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████| 6/6 [00:01<00:00, 5.35it/s]
2023-12-03 16:21:49,770 - INFO - Running baseline...
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:14<00:00, 3.42it/s]
2023-12-03 16:22:04,551 - INFO - Baseline: 14.78 seconds
2023-12-03 16:22:04,551 - INFO - Running DeepCache...
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:08<00:00, 6.10it/s]
2023-12-03 16:22:12,911 - INFO - DeepCache: 8.36 seconds
2023-12-03 16:22:13,417 - INFO - Saved to output.png. Done!
Currently, our code supports the models that can be loaded by StableDiffusionPipeline. You can specify the model name with the argument --model, which defaults to runwayml/stable-diffusion-v1-5.
python stable_video_diffusion.py
Output:
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 8.36it/s]
2023-12-21 04:56:47,329 - INFO - Running baseline...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [01:27<00:00, 3.49s/it]
2023-12-21 04:58:26,121 - INFO - Origin: 98.66 seconds
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10.59it/s]
2023-12-21 04:58:27,202 - INFO - Running DeepCache...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:49<00:00, 1.96s/it]
2023-12-21 04:59:26,607 - INFO - DeepCache: 59.31 seconds
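The wall-clock timings reported in the logs above can be turned into speedup ratios with a few lines of Python. The sketch below is only a convenience for reading the logs: the helper names `parse_timings` and `speedup` are made up for this example, and the regular expression assumes the `INFO - <label>: <seconds> seconds` line shape seen in the output above.

```python
import re

# Matches lines such as "... - INFO - Baseline: 17.52 seconds"
TIMING_PATTERN = re.compile(r"INFO - (?P<label>\w+): (?P<seconds>\d+(?:\.\d+)?) seconds")


def parse_timings(log_text: str) -> dict:
    """Collect {label: seconds} from log lines that report a timing."""
    return {m.group("label"): float(m.group("seconds")) for m in TIMING_PATTERN.finditer(log_text)}


def speedup(baseline_seconds: float, accelerated_seconds: float) -> float:
    """Ratio of baseline time to accelerated time."""
    return baseline_seconds / accelerated_seconds


# (baseline seconds, DeepCache seconds) copied from the logs above
runs = {
    "stable-diffusion-xl-base-1.0": (17.52, 6.71),
    "stable-diffusion-v1-5": (3.34, 1.54),
    "stable-diffusion-2-1": (14.78, 8.36),
    "stable-video-diffusion": (98.66, 59.31),
}

for name, (baseline, accelerated) in runs.items():
    print(f"{name}: {speedup(baseline, accelerated):.2f}x")
```

These ratios reflect single runs on the authors' hardware, so they are indicative rather than a benchmark.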
Please check here for the experimental code of DDPM. The code for LDM will be released soon.
# Stable Diffusion XL
import torch
from DeepCache import StableDiffusionXLPipeline as DeepCacheStableDiffusionXLPipeline

pipe = DeepCacheStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda:0")

prompt = "A photo of a cat. Focus light and create sharp, defined edges."
deepcache_output = pipe(
    prompt,
    cache_interval=3, cache_layer_id=0, cache_block_id=0,
    output_type='pt', return_dict=True
).images

# Stable Diffusion v1.5
import torch
from DeepCache import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to("cuda:0")

prompt = "a photo of an astronaut on a moon"
deepcache_output = pipe(
    prompt,
    cache_interval=5, cache_layer_id=0, cache_block_id=0,
    uniform=True,  # pow=1.4, center=15,  # only for uniform=False
    output_type='pt', return_dict=True
).images
Arguments:
- cache_interval: the interval (N in the 1:N strategy) of cache update. Larger intervals bring more significant speedup.
- cache_layer_id & cache_block_id: the block/layer ID of the selected skip branch.
- uniform: whether to enable the uniform caching strategy.
- pow & center: the hyperparameters for non-uniform 1:N strategy.
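To make the uniform 1:N strategy concrete, the sketch below lists which denoising steps would refresh the cache for a given `cache_interval`, assuming a full update on step 0 and on every N-th step after it. The function names are hypothetical and this is an illustration of the idea rather than the library's internal scheduling code; the non-uniform variant controlled by `pow` and `center` is not covered here.

```python
from typing import List


def uniform_cache_steps(num_steps: int, cache_interval: int) -> List[int]:
    """Steps on which the cache is refreshed under a uniform 1:N schedule (assumed form)."""
    if cache_interval < 1:
        raise ValueError("cache_interval must be >= 1")
    return list(range(0, num_steps, cache_interval))


def full_pass_fraction(num_steps: int, cache_interval: int) -> float:
    """Fraction of steps that run the full network instead of reusing cached features."""
    return len(uniform_cache_steps(num_steps, cache_interval)) / num_steps


if __name__ == "__main__":
    for interval in (1, 3, 5):
        steps = uniform_cache_steps(50, interval)
        print(f"cache_interval={interval}: {len(steps)} full steps out of 50")
```

A larger interval means fewer full passes, which is why it yields a larger speedup; the actual gain also depends on how cheap the partial pass is for the chosen skip branch.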
import torch
from diffusers.utils import load_image, export_to_video
from DeepCache.svd.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline as DeepCacheStableVideoDiffusionPipeline

deepcache_pipe = DeepCacheStableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16",
)
deepcache_pipe.enable_model_cpu_offload()

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
frames = deepcache_pipe(
    image,
    decode_chunk_size=8, generator=generator,
    cache_interval=3, cache_branch=0,
).frames[0]
Arguments:
- cache_interval: the interval (N in the 1:N strategy) of cache update. Larger intervals bring more significant speedup, but may degrade the quality of the video.
- cache_branch: the selected skip branch. Choose a value of cache_branch from 0 to 10.
Images in the upper row are the baselines, and images in the lower row are accelerated by DeepCache.
More results can be found in our paper.
We sincerely thank the authors listed below who implemented DeepCache in plugins or other contexts.
- OneDiff Integration: https://github.com/Oneflow-Inc/onediff?tab=readme-ov-file#easy-to-use by @Oneflow-Inc. OneDiff also has implementations of DeepCache for SVD; check this for details.
- ComfyUI: https://gist.github.com/laksjdjf/435c512bc19636e9c9af4ee7bea9eb86 by @laksjdjf
- Colab & Gradio: https://github.com/camenduru/DeepCache-colab by @camenduru
- WebUI: AUTOMATIC1111/stable-diffusion-webui#14210 by @aria1th
We warmly welcome contributions from everyone. Please feel free to reach out to us.
@article{ma2023deepcache,
  title={DeepCache: Accelerating Diffusion Models for Free},
  author={Ma, Xinyin and Fang, Gongfan and Wang, Xinchao},
  journal={arXiv preprint arXiv:2312.00858},
  year={2023}
}