From ad40e265156e18964245ed943943fdeb7d8cf61a Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Wed, 11 Dec 2024 16:57:36 +0530
Subject: [PATCH] [Single File] Add single file support for AutoencoderDC
 (#10183)

* update

* update

* update
---
 docs/source/en/api/models/autoencoder_dc.md   |  20 +++
 src/diffusers/loaders/single_file_model.py    |   2 +
 src/diffusers/loaders/single_file_utils.py    |  95 +++++++++++++
 .../test_model_autoencoder_dc_single_file.py  | 126 ++++++++++++++++++
 4 files changed, 243 insertions(+)
 create mode 100644 tests/single_file/test_model_autoencoder_dc_single_file.py

diff --git a/docs/source/en/api/models/autoencoder_dc.md b/docs/source/en/api/models/autoencoder_dc.md
index f9931e099254..667f0de678f6 100644
--- a/docs/source/en/api/models/autoencoder_dc.md
+++ b/docs/source/en/api/models/autoencoder_dc.md
@@ -37,6 +37,26 @@ from diffusers import AutoencoderDC
 ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
 ```

+## Load a model in Diffusers via `from_single_file`
+
+```python
+from diffusers import AutoencoderDC
+
+ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
+model = AutoencoderDC.from_single_file(ckpt_path)
+```
+
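+`from_single_file` also accepts the usual loading keyword arguments. As a minimal sketch (assuming a CUDA device is available), the same checkpoint can be loaded in a specific dtype, mirroring the `from_pretrained` example above:
+
+```python
+import torch
+from diffusers import AutoencoderDC
+
+ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
+# torch_dtype controls the dtype the weights are loaded in;
+# .to("cuda") then moves the model onto the GPU.
+model = AutoencoderDC.from_single_file(ckpt_path, torch_dtype=torch.float32).to("cuda")
+```
+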
"black-forest-labs/FLUX.1-dev"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, + "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"}, + "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, + "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, + "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"}, } # Use to configure model sample size when original config is provided @@ -564,6 +570,23 @@ def infer_diffusers_model_type(checkpoint): model_type = "flux-dev" else: model_type = "flux-schnell" + + elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint: + encoder_key = "encoder.project_in.conv.conv.bias" + decoder_key = "decoder.project_in.main.conv.weight" + + if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint: + model_type = "autoencoder-dc-f32c32-sana" + + elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32: + model_type = "autoencoder-dc-f32c32" + + elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128: + model_type = "autoencoder-dc-f64c128" + + else: + model_type = "autoencoder-dc-f128c512" + else: model_type = "v1" @@ -2198,3 +2221,75 @@ def swap_scale_shift(weight): ) return converted_state_dict + + +def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + def remap_qkv_(key: str, state_dict): + qkv = state_dict.pop(key) + q, k, v = torch.chunk(qkv, 3, dim=0) + parent_module, _, _ = key.rpartition(".qkv.conv.weight") + state_dict[f"{parent_module}.to_q.weight"] = q.squeeze() + state_dict[f"{parent_module}.to_k.weight"] = k.squeeze() + state_dict[f"{parent_module}.to_v.weight"] = v.squeeze() + + def remap_proj_conv_(key: str, state_dict): + parent_module, _, _ = key.rpartition(".proj.conv.weight") + state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze() + + AE_KEYS_RENAME_DICT = { + # common + "main.": "", + "op_list.": "", + "context_module": "attn", + "local_module": "conv_out", + # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1 + # If there were more scales, there would be more layers, so a loop would be better to handle this + "aggreg.0.0": "to_qkv_multiscale.0.proj_in", + "aggreg.0.1": "to_qkv_multiscale.0.proj_out", + "depth_conv.conv": "conv_depth", + "inverted_conv.conv": "conv_inverted", + "point_conv.conv": "conv_point", + "point_conv.norm": "norm", + "conv.conv.": "conv.", + "conv1.conv": "conv1", + "conv2.conv": "conv2", + "conv2.norm": "norm", + "proj.norm": "norm_out", + # encoder + "encoder.project_in.conv": "encoder.conv_in", + "encoder.project_out.0.conv": "encoder.conv_out", + "encoder.stages": "encoder.down_blocks", + # decoder + "decoder.project_in.conv": "decoder.conv_in", + "decoder.project_out.0": "decoder.norm_out", + "decoder.project_out.2.conv": "decoder.conv_out", + "decoder.stages": "decoder.up_blocks", + } + + AE_F32C32_F64C128_F128C512_KEYS = { + "encoder.project_in.conv": "encoder.conv_in.conv", + "decoder.project_out.2.conv": "decoder.conv_out.conv", + } + + AE_SPECIAL_KEYS_REMAP = { + "qkv.conv.weight": remap_qkv_, + "proj.conv.weight": remap_proj_conv_, + } + if "encoder.project_in.conv.bias" not in converted_state_dict: + 
+    if "encoder.project_in.conv.bias" not in converted_state_dict:
+        AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
diff --git a/tests/single_file/test_model_autoencoder_dc_single_file.py b/tests/single_file/test_model_autoencoder_dc_single_file.py
new file mode 100644
index 000000000000..b1faeb78776b
--- /dev/null
+++ b/tests/single_file/test_model_autoencoder_dc_single_file.py
@@ -0,0 +1,126 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+
+from diffusers import (
+    AutoencoderDC,
+)
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    load_hf_numpy,
+    numpy_cosine_similarity_distance,
+    require_torch_accelerator,
+    slow,
+    torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@slow
+@require_torch_accelerator
+class AutoencoderDCSingleFileTests(unittest.TestCase):
+    model_class = AutoencoderDC
+    ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
+    repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
+    main_input_name = "sample"
+    base_precision = 1e-2
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def get_file_format(self, seed, shape):
+        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
+
+    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
+        dtype = torch.float16 if fp16 else torch.float32
+        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
+        return image
+
+    def test_single_file_inference_same_as_pretrained(self):
+        model_1 = self.model_class.from_pretrained(self.repo_id).to(torch_device)
+        model_2 = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id).to(torch_device)
+
+        image = self.get_sd_image(33)
+
+        with torch.no_grad():
+            sample_1 = model_1(image).sample
+            sample_2 = model_2(image).sample
+
+        assert sample_1.shape == sample_2.shape
+
+        output_slice_1 = sample_1.flatten().float().cpu()
+        output_slice_2 = sample_2.flatten().float().cpu()
+
+        assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
+
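+    # The component tests below share a pattern: load the converted repo with
+    # `from_pretrained` and the original checkpoint with `from_single_file`, then check
+    # that every config value matches, ignoring attributes that only describe how the
+    # model was loaded rather than its architecture.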
+    def test_single_file_components(self):
+        model = self.model_class.from_pretrained(self.repo_id)
+        model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+        for param_name, param_value in model_single_file.config.items():
+            if param_name in PARAMS_TO_IGNORE:
+                continue
+            assert (
+                model.config[param_name] == param_value
+            ), f"{param_name} differs between pretrained loading and single file loading"
+
+    def test_single_file_in_type_variant_components(self):
+        # `in` variant checkpoints require passing a `config` parameter in order to set
+        # the scaling factor correctly: the `in` and `mix` variants share the same keys,
+        # so a scaling factor cannot be inferred from the checkpoint and loading
+        # defaults to the `mix` config.
+        repo_id = "mit-han-lab/dc-ae-f128c512-in-1.0-diffusers"
+        ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0/blob/main/model.safetensors"
+
+        model = self.model_class.from_pretrained(repo_id)
+        model_single_file = self.model_class.from_single_file(ckpt_path, config=repo_id)
+
+        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+        for param_name, param_value in model_single_file.config.items():
+            if param_name in PARAMS_TO_IGNORE:
+                continue
+            assert (
+                model.config[param_name] == param_value
+            ), f"{param_name} differs between pretrained loading and single file loading"
+
+    def test_single_file_mix_type_variant_components(self):
+        repo_id = "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"
+        ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0/blob/main/model.safetensors"
+
+        model = self.model_class.from_pretrained(repo_id)
+        model_single_file = self.model_class.from_single_file(ckpt_path, config=repo_id)
+
+        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+        for param_name, param_value in model_single_file.config.items():
+            if param_name in PARAMS_TO_IGNORE:
+                continue
+            assert (
+                model.config[param_name] == param_value
+            ), f"{param_name} differs between pretrained loading and single file loading"