Skip to content

Commit

Permalink
[Single File] Add single file support for AutoencoderDC (huggingface#…
Browse files Browse the repository at this point in the history
…10183)

* update

* update

* update
  • Loading branch information
DN6 authored Dec 11, 2024
1 parent d041dd5 commit ad40e26
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 0 deletions.
20 changes: 20 additions & 0 deletions docs/source/en/api/models/autoencoder_dc.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ from diffusers import AutoencoderDC
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
```

## Load a model in Diffusers via `from_single_file`

```python
from difusers import AutoencoderDC

ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
model = AutoencoderDC.from_single_file(ckpt_path)

```

The `AutoencoderDC` model has `in` and `mix` single file checkpoint variants that have matching checkpoint keys, but use different scaling factors. It is not possible for Diffusers to automatically infer the correct config file to use with the model based on just the checkpoint and will default to configuring the model using the `mix` variant config file. To override the automatically determined config, please use the `config` argument when using single file loading with `in` variant checkpoints.

```python
from diffusers import AutoencoderDC

ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0/blob/main/model.safetensors"
model = AutoencoderDC.from_single_file(ckpt_path, config="mit-han-lab/dc-ae-f128c512-in-1.0-diffusers")
```


## AutoencoderDC

[[autodoc]] AutoencoderDC
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_ldm_unet_checkpoint,
Expand Down Expand Up @@ -82,6 +83,7 @@
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
}


Expand Down
95 changes: 95 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
Expand Down Expand Up @@ -138,6 +140,10 @@
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
}

# Use to configure model sample size when original config is provided
Expand Down Expand Up @@ -564,6 +570,23 @@ def infer_diffusers_model_type(checkpoint):
model_type = "flux-dev"
else:
model_type = "flux-schnell"

elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
encoder_key = "encoder.project_in.conv.conv.bias"
decoder_key = "decoder.project_in.main.conv.weight"

if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
model_type = "autoencoder-dc-f32c32-sana"

elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
model_type = "autoencoder-dc-f32c32"

elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
model_type = "autoencoder-dc-f64c128"

else:
model_type = "autoencoder-dc-f128c512"

else:
model_type = "v1"

Expand Down Expand Up @@ -2198,3 +2221,75 @@ def swap_scale_shift(weight):
)

return converted_state_dict


def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

def remap_qkv_(key: str, state_dict):
qkv = state_dict.pop(key)
q, k, v = torch.chunk(qkv, 3, dim=0)
parent_module, _, _ = key.rpartition(".qkv.conv.weight")
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()

def remap_proj_conv_(key: str, state_dict):
parent_module, _, _ = key.rpartition(".proj.conv.weight")
state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()

AE_KEYS_RENAME_DICT = {
# common
"main.": "",
"op_list.": "",
"context_module": "attn",
"local_module": "conv_out",
# NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
# If there were more scales, there would be more layers, so a loop would be better to handle this
"aggreg.0.0": "to_qkv_multiscale.0.proj_in",
"aggreg.0.1": "to_qkv_multiscale.0.proj_out",
"depth_conv.conv": "conv_depth",
"inverted_conv.conv": "conv_inverted",
"point_conv.conv": "conv_point",
"point_conv.norm": "norm",
"conv.conv.": "conv.",
"conv1.conv": "conv1",
"conv2.conv": "conv2",
"conv2.norm": "norm",
"proj.norm": "norm_out",
# encoder
"encoder.project_in.conv": "encoder.conv_in",
"encoder.project_out.0.conv": "encoder.conv_out",
"encoder.stages": "encoder.down_blocks",
# decoder
"decoder.project_in.conv": "decoder.conv_in",
"decoder.project_out.0": "decoder.norm_out",
"decoder.project_out.2.conv": "decoder.conv_out",
"decoder.stages": "decoder.up_blocks",
}

AE_F32C32_F64C128_F128C512_KEYS = {
"encoder.project_in.conv": "encoder.conv_in.conv",
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_SPECIAL_KEYS_REMAP = {
"qkv.conv.weight": remap_qkv_,
"proj.conv.weight": remap_proj_conv_,
}
if "encoder.project_in.conv.bias" not in converted_state_dict:
AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)

for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)

for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)

return converted_state_dict
126 changes: 126 additions & 0 deletions tests/single_file/test_model_autoencoder_dc_single_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

from diffusers import (
AutoencoderDC,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_hf_numpy,
numpy_cosine_similarity_distance,
require_torch_accelerator,
slow,
torch_device,
)


enable_full_determinism()


@slow
@require_torch_accelerator
class AutoencoderDCSingleFileTests(unittest.TestCase):
model_class = AutoencoderDC
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
main_input_name = "sample"
base_precision = 1e-2

def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)

def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return image

def test_single_file_inference_same_as_pretrained(self):
model_1 = self.model_class.from_pretrained(self.repo_id).to(torch_device)
model_2 = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id).to(torch_device)

image = self.get_sd_image(33)

with torch.no_grad():
sample_1 = model_1(image).sample
sample_2 = model_2(image).sample

assert sample_1.shape == sample_2.shape

output_slice_1 = sample_1.flatten().float().cpu()
output_slice_2 = sample_2.flatten().float().cpu()

assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4

def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id)
model_single_file = self.model_class.from_single_file(self.ckpt_path)

PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between pretrained loading and single file loading"

def test_single_file_in_type_variant_components(self):
# `in` variant checkpoints require passing in a `config` parameter
# in order to set the scaling factor correctly.
# `in` and `mix` variants have the same keys and we cannot automatically infer a scaling factor.
# We default to using teh `mix` config
repo_id = "mit-han-lab/dc-ae-f128c512-in-1.0-diffusers"
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0/blob/main/model.safetensors"

model = self.model_class.from_pretrained(repo_id)
model_single_file = self.model_class.from_single_file(ckpt_path, config=repo_id)

PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between pretrained loading and single file loading"

def test_single_file_mix_type_variant_components(self):
repo_id = "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0/blob/main/model.safetensors"

model = self.model_class.from_pretrained(repo_id)
model_single_file = self.model_class.from_single_file(ckpt_path, config=repo_id)

PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between pretrained loading and single file loading"

0 comments on commit ad40e26

Please sign in to comment.