21 changes: 21 additions & 0 deletions src/diffusers/__init__.py
@@ -13,6 +13,7 @@
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
is_nunchaku_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
@@ -99,6 +100,18 @@
else:
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")

try:
if not is_torch_available() and not is_accelerate_available() and not is_nunchaku_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_nunchaku_objects

_import_structure["utils.dummy_nunchaku_objects"] = [
name for name in dir(dummy_nunchaku_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("NunchakuConfig")

try:
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
@@ -791,6 +804,14 @@
else:
from .quantizers.quantization_config import QuantoConfig

try:
if not is_nunchaku_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_nunchaku_objects import *
else:
from .quantizers.quantization_config import NunchakuConfig

try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
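A minimal sketch of what the lazy-import wiring above enables once nunchaku is installed alongside torch and accelerate; constructing NunchakuConfig with no arguments is an assumption here, since its signature is defined elsewhere in this PR:

import diffusers

config = diffusers.NunchakuConfig()  # resolves to the dummy placeholder (which raises) if nunchaku is missing
print(config.quant_method)  # expected to report the "nunchaku" quantization method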
56 changes: 54 additions & 2 deletions src/diffusers/loaders/single_file_model.py
@@ -23,6 +23,7 @@

from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..quantizers.quantization_config import NunchakuConfig
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
@@ -42,6 +43,7 @@
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
convert_mochi_transformer_checkpoint_to_diffusers,
convert_nunchaku_flux_to_diffusers,
convert_sana_transformer_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
@@ -190,6 +192,23 @@ def _get_mapping_function_kwargs(mapping_fn, **kwargs):
return mapping_kwargs


def _maybe_determine_modules_to_not_convert(quantization_config, state_dict):
if quantization_config is None:
return None
if quantization_config.quant_method != "nunchaku":
return None
# modules whose checkpoint entries still end in a plain ".weight" (rather than ".qweight") were not quantized
no_qweight = set()
for key in state_dict:
if key.endswith(".weight"):
# module name is everything except the last piece after "."
module_name = ".".join(key.split(".")[:-1])
no_qweight.add(module_name)
return sorted(no_qweight)

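# Illustrative sketch of the helper above (not part of this diff); `nunchaku_config` stands for any config whose
# quant_method is "nunchaku". Modules whose entries still end in a plain ".weight" are reported as not-to-convert:
#
#     toy_state_dict = {
#         "transformer_blocks.0.attn.to_qkv.qweight": ...,
#         "transformer_blocks.0.norm1.linear.weight": ...,
#     }
#     _maybe_determine_modules_to_not_convert(nunchaku_config, toy_state_dict)
#     # -> ["transformer_blocks.0.norm1.linear"]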

class FromOriginalModelMixin:
"""
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
@@ -404,8 +423,14 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
model = cls.from_config(diffusers_model_config)

checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)

if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
model_state_dict = model.state_dict()
# TODO: only Flux nunchaku checkpoints are supported for now. Unify with how checkpoint mappers are done.
# For `nunchaku` checkpoints, `modules_to_not_convert` is also determined from the checkpoint further below.
if quantization_config is not None and quantization_config.quant_method == "nunchaku":
diffusers_format_checkpoint = convert_nunchaku_flux_to_diffusers(
checkpoint, model_state_dict=model_state_dict
)
elif _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
@@ -416,6 +441,27 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)

# This step is better off here than above because `diffusers_format_checkpoint` holds the keys we expect.
# We can move it to a separate function as well.
if quantization_config is not None:
original_modules_to_not_convert = quantization_config.modules_to_not_convert or []
determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert(
quantization_config, checkpoint
)
if determined_modules_to_not_convert:
determined_modules_to_not_convert.extend(original_modules_to_not_convert)
determined_modules_to_not_convert = list(set(determined_modules_to_not_convert))
logger.debug(
f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {determined_modules_to_not_convert}."
)
modified_quant_config = quantization_config.to_dict()
modified_quant_config["modules_to_not_convert"] = determined_modules_to_not_convert
# TODO: figure out a better way than rebuilding the config and overwriting it on the quantizer.
modified_quant_config = NunchakuConfig.from_dict(modified_quant_config)
hf_quantizer.quantization_config = modified_quant_config

# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -443,6 +489,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
unexpected_keys = [
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
]
# temporary debugging for this draft: surface key mismatches in the first single transformer block
for k in unexpected_keys:
if "single_transformer_blocks.0" in k:
logger.debug(f"Unexpected {k=}")
for k in empty_state_dict:
if "single_transformer_blocks.0" in k:
logger.debug(f"model {k=}")
device_map = {"": param_device}
load_model_dict_into_meta(
model,
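A minimal end-to-end sketch of the single-file loading path the changes above enable, assuming a locally available nunchaku Flux transformer checkpoint (the file name below is hypothetical) and that NunchakuConfig can be constructed with defaults:

import torch

from diffusers import FluxTransformer2DModel, NunchakuConfig

transformer = FluxTransformer2DModel.from_single_file(
    "svdq-int4-flux.1-dev.safetensors",  # hypothetical nunchaku single-file checkpoint
    quantization_config=NunchakuConfig(),
    torch_dtype=torch.bfloat16,
)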
99 changes: 99 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
@@ -2189,6 +2189,105 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
return converted_state_dict


# Adapted from https://github.com/nunchaku-tech/nunchaku/blob/3ec299f439f9986a69ded320798cab4e258c871d/nunchaku/models/transformers/transformer_flux_v2.py#L395
def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs):
from .single_file_utils_nunchaku import _unpack_qkv_state_dict

_SMOOTH_ORIG_RE = re.compile(r"\.smooth_orig(\.|$)")
_SMOOTH_RE = re.compile(r"\.smooth(\.|$)")

new_state_dict = {}
model_state_dict = kwargs["model_state_dict"]

ckpt_keys = list(checkpoint.keys())
for k in ckpt_keys:
if "qweight" in k:
# only the shape information of this tensor is needed
v = checkpoint[k]
# if the layer has a `qweight` but no low-rank branch, add empty placeholder low-rank tensors
for t in ["lora_up", "lora_down"]:
new_k = k.replace(".qweight", f".{t}")
if new_k not in ckpt_keys:
oc, ic = v.shape
# `qweight` packs two quantized values per int8 element, so the real in_features is twice the packed dim
ic = ic * 2
checkpoint[new_k] = torch.zeros(
(0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
)

for k, v in checkpoint.items():
new_k = k # start with original, then apply independent replacements

if k.startswith("single_transformer_blocks."):
# attention / qkv / norms
new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
new_k = new_k.replace(".out_proj.", ".proj_out.")
new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
new_k = new_k.replace(".norm_q.", ".attn.norm_q.")

# mlp heads
new_k = new_k.replace(".mlp_fc1.", ".proj_mlp.")
new_k = new_k.replace(".mlp_fc2.", ".proj_out.")

# smooth params (use regex to avoid substring collisions)
new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)

# lora -> proj
new_k = new_k.replace(".lora_down", ".proj_down")
new_k = new_k.replace(".lora_up", ".proj_up")

elif k.startswith("transformer_blocks."):
# feed-forward (context & base)
new_k = new_k.replace(".mlp_context_fc1.", ".ff_context.net.0.proj.")
new_k = new_k.replace(".mlp_context_fc2.", ".ff_context.net.2.")
new_k = new_k.replace(".mlp_fc1.", ".ff.net.0.proj.")
new_k = new_k.replace(".mlp_fc2.", ".ff.net.2.")

# attention projections
new_k = new_k.replace(".qkv_proj_context.", ".attn.add_qkv_proj.")
new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
new_k = new_k.replace(".out_proj.", ".attn.to_out.0.")
new_k = new_k.replace(".out_proj_context.", ".attn.to_add_out.")

# norms
new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
new_k = new_k.replace(".norm_added_k.", ".attn.norm_added_k.")
new_k = new_k.replace(".norm_added_q.", ".attn.norm_added_q.")

# smooth params
new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)

# lora -> proj
new_k = new_k.replace(".lora_down", ".proj_down")
new_k = new_k.replace(".lora_up", ".proj_up")

new_state_dict[new_k] = v

new_state_dict = _unpack_qkv_state_dict(new_state_dict)

# drop leftover companions of `qweight` (e.g. `wzeros`) that have no diffusers counterpart
new_sd_keys = list(new_state_dict.keys())
for k in new_sd_keys:
if "qweight" in k:
no_qweight_k = ".".join(k.split(".qweight")[:-1])
for suffix in ["wzeros"]:
unexpected_k = f"{no_qweight_k}.{suffix}"
if unexpected_k in new_sd_keys:
_ = new_state_dict.pop(unexpected_k)
# fill any model keys missing from the checkpoint with placeholder tensors (CPU device for now)
for k in model_state_dict:
if k not in new_state_dict:
new_state_dict[k] = torch.ones_like(model_state_dict[k], device="cpu")

# temporary debugging for this draft
for k in new_state_dict:
if "single_transformer_blocks.0" in k and k.endswith(".weight"):
logger.debug(f"{k=}")

return new_state_dict

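# Illustrative key remapping performed by the converter above (not part of this diff); left-hand names are
# nunchaku checkpoint keys, right-hand names are the diffusers layout they are renamed to:
#
#     transformer_blocks.0.qkv_proj.qweight        -> transformer_blocks.0.attn.to_qkv.qweight
#                                                     (then split into to_q/to_k/to_v by _unpack_qkv_state_dict)
#     transformer_blocks.0.mlp_fc1.wscales         -> transformer_blocks.0.ff.net.0.proj.wscales
#     single_transformer_blocks.0.out_proj.lora_up -> single_transformer_blocks.0.proj_out.proj_up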

def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
104 changes: 104 additions & 0 deletions src/diffusers/loaders/single_file_utils_nunchaku.py
@@ -0,0 +1,104 @@
import re

import torch


_QKV_ANCHORS_NUNCHAKU = ("to_qkv", "add_qkv_proj")
_ALLOWED_SUFFIXES_NUNCHAKU = {
"bias",
"proj_down",
"proj_up",
"qweight",
"smooth_factor",
"smooth_factor_orig",
"wscales",
}

_QKV_NUNCHAKU_REGEX = re.compile(
rf"^(?P<prefix>.*)\.(?:{'|'.join(map(re.escape, _QKV_ANCHORS_NUNCHAKU))})\.(?P<suffix>.+)$"
)


def _pick_split_dim(t: torch.Tensor, suffix: str) -> int:
"""
Choose which dimension to split by 3. Heuristics:
- 1D -> dim 0
- 2D -> prefer dim=1 for 'qweight' (common layout [*, 3*out_features]),
otherwise prefer dim=0 (common layout [3*out_features, *]).
- If preferred dim isn't divisible by 3, try the other; else error.
"""
shape = list(t.shape)
if len(shape) == 0:
raise ValueError("Cannot split a scalar into Q/K/V.")

if len(shape) == 1:
dim = 0
if shape[dim] % 3 == 0:
return dim
raise ValueError(f"1D tensor of length {shape[0]} not divisible by 3.")

# len(shape) >= 2
preferred = 1 if suffix == "qweight" else 0
other = 0 if preferred == 1 else 1

if shape[preferred] % 3 == 0:
return preferred
if shape[other] % 3 == 0:
return other

# Fall back: any dim divisible by 3
for d, s in enumerate(shape):
if s % 3 == 0:
return d

raise ValueError(f"None of the dims {shape} are divisible by 3 for suffix '{suffix}'.")

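# Illustrative behaviour of the heuristic above (not part of this file):
#
#     _pick_split_dim(torch.empty(3 * 128), "bias")         -> 0  (1D, length divisible by 3)
#     _pick_split_dim(torch.empty(64, 3 * 128), "qweight")  -> 1  (qweight prefers the packed dim 1)
#     _pick_split_dim(torch.empty(3 * 128, 64), "proj_up")  -> 0  (other suffixes prefer dim 0)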

def _split_qkv(t: torch.Tensor, dim: int):
return torch.tensor_split(t, 3, dim=dim)


def _unpack_qkv_state_dict(
state_dict: dict, anchors=_QKV_ANCHORS_NUNCHAKU, allowed_suffixes=_ALLOWED_SUFFIXES_NUNCHAKU
):
"""
Convert fused QKV entries (e.g., '...to_qkv.bias', '...qkv_proj.wscales') into separate Q/K/V entries:
'...to_q.bias', '...to_k.bias', '...to_v.bias' and '...to_q.wscales', '...to_k.wscales', '...to_v.wscales'.
Returns a NEW dict; the original is not modified.

Only keys whose suffix is in `allowed_suffixes` are processed. Tensors whose split dimension is not divisible by 3 raise a ValueError.
"""
anchors = tuple(anchors)
allowed_suffixes = set(allowed_suffixes)

new_sd: dict = {}
sd_keys = list(state_dict.keys())
for k in sd_keys:
m = _QKV_NUNCHAKU_REGEX.match(k)
v = state_dict[k]  # do not pop; the docstring promises the input dict is left unmodified
if m:
suffix = m.group("suffix")
if suffix not in allowed_suffixes:
# keep as-is if it's not one of the targeted suffixes
new_sd[k] = v
continue

prefix = m.group("prefix") # everything before .to_qkv/.qkv_proj
# Decide split axis
split_dim = _pick_split_dim(v, suffix)
q, k_, vv = _split_qkv(v, dim=split_dim)

# Build new keys
base_q = f"{prefix}.to_q.{suffix}"
base_k = f"{prefix}.to_k.{suffix}"
base_v = f"{prefix}.to_v.{suffix}"

# Write into result dict
new_sd[base_q] = q
new_sd[base_k] = k_
new_sd[base_v] = vv
else:
# not a fused qkv key
new_sd[k] = v

return new_sd
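
# A toy check of the fused-QKV unpacking above; a sketch, not part of this file, kept under a __main__ guard:
if __name__ == "__main__":
    fused = {"transformer_blocks.0.attn.to_qkv.bias": torch.arange(24, dtype=torch.float32)}
    unpacked = _unpack_qkv_state_dict(fused)
    # the single fused bias of length 3 * 8 becomes three length-8 tensors under to_q/to_k/to_v
    assert sorted(unpacked) == [
        "transformer_blocks.0.attn.to_k.bias",
        "transformer_blocks.0.attn.to_q.bias",
        "transformer_blocks.0.attn.to_v.bias",
    ]
    assert unpacked["transformer_blocks.0.attn.to_q.bias"].shape == (8,)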
7 changes: 7 additions & 0 deletions src/diffusers/models/model_loading_utils.py
@@ -297,6 +297,13 @@ def load_model_dict_into_meta(
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
# The check below may look counter-intuitive: we check whether the param (or its module) is quantized and, if
# so, proceed to create a quantized param. This follows from how pretrained models are loaded: they are
# initialized on the "meta" device, where the quantization layers are injected first. Hence, for a model that is
# either pre-quantized or supplied with a `quantization_config` during `from_pretrained`, we expect
# `check_if_quantized_param` to return True. Depending on the quantization backend in use, the actual
# quantization step then runs in `create_quantized_param`.
elif is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
):
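A condensed sketch of the dispatch the comment above describes (simplified: the surrounding function also handles dtype casting, offloading, and device maps; the non-quantized branch shown here is an approximation):

if is_quantized and hf_quantizer.check_if_quantized_param(
    model, param, param_name, state_dict, param_device=param_device
):
    # backend-specific packing/creation of the quantized weight happens inside the quantizer
    hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
else:
    # plain parameter: materialize it on the target device as usual
    set_module_tensor_to_device(model, param_name, param_device, value=param)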
6 changes: 5 additions & 1 deletion src/diffusers/quantizers/auto.py
@@ -21,9 +21,11 @@

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .nunchaku import NunchakuQuantizer
from .quantization_config import (
BitsAndBytesConfig,
GGUFQuantizationConfig,
NunchakuConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -39,6 +41,7 @@
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
"nunchaku": NunchakuQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,12 +50,13 @@
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
"nunchaku": NunchakuConfig,
}


class DiffusersAutoQuantizer:
"""
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
`DiffusersQuantizer` given the `QuantizationConfig`.
"""

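# Usage sketch (not part of this diff), assuming `from_config` accepts a plain dict as with the existing
# backends; the "nunchaku" entries added to the mappings above should resolve to NunchakuQuantizer:
#
#     from diffusers.quantizers import DiffusersAutoQuantizer
#
#     quantizer = DiffusersAutoQuantizer.from_config({"quant_method": "nunchaku"})
#     assert isinstance(quantizer, NunchakuQuantizer)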