diff --git a/.gitignore b/.gitignore index 4226ce31..d157d7ec 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,5 @@ dist/ .clangd CMakeUserPresets.json tmp_autoround/ +outputs/ +models/ldm/stable-diffusion-v1/*.ckpt diff --git a/auto_round_diff/__init__.py b/auto_round_diff/__init__.py new file mode 100644 index 00000000..0f70962f --- /dev/null +++ b/auto_round_diff/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .autoround import AutoRound, AutoRoundAdam, AutoRoundOPT +from auto_round.utils import LazyImport + +def __getattr__(name): + if name == 'AutoHfQuantizer': + from auto_round.inference.auto_quantizer import AutoHfQuantizer + return AutoHfQuantizer + if name == 'AutoRoundConfig': + from auto_round.inference.auto_quantizer import AutoRoundConfig + return AutoRoundConfig + + raise AttributeError(f"auto-round has no attribute '{name}'") + +from .version import __version__ diff --git a/auto_round_diff/__main__.py b/auto_round_diff/__main__.py new file mode 100644 index 00000000..b7558f83 --- /dev/null +++ b/auto_round_diff/__main__.py @@ -0,0 +1,84 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +sys.path.append('.') +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '2' + +# def run_eval(): +# from auto_round.script.llm import setup_eval_parser +# args = setup_eval_parser() +# if args.eval_task_by_task: +# from auto_round.script.llm import eval_task_by_task +# eval_task_by_task( +# model=args.model, +# device=args.device, +# tasks=args.tasks, +# batch_size=args.eval_bs, +# trust_remote_code=not args.disable_trust_remote_code, +# eval_model_dtype=args.eval_model_dtype +# ) +# else: +# from auto_round.script.llm import eval +# eval(args) + + +# def run(): +# if "--eval" in sys.argv: +# sys.argv.remove("--eval") +# run_eval() +# else: +# from auto_round.script.llm import setup_parser, tune +# args = setup_parser() +# tune(args) + +# def run_best(): +# from auto_round.script.llm import setup_best_parser, tune +# args = setup_best_parser() +# tune(args) + +# def run_light(): +# from auto_round.script.llm import setup_light_parser, tune +# args = setup_light_parser() +# tune(args) + +# def run_fast(): +# from auto_round.script.llm import setup_fast_parser, tune +# args = setup_fast_parser() +# tune(args) + +def run_diffusion(): + if "--eval" in sys.argv: + from auto_round_diff.script.diffusion import setup_lmeval_parser, eval + sys.argv.remove("--eval") + args = setup_lmeval_parser() + eval(args) + elif "--sample" in sys.argv: + pass + else: + from auto_round_diff.script.diffusion import setup_parser, tune + args = setup_parser() + tune(args) + +def switch(): + if "--dm" in sys.argv: + sys.argv.remove("--dm") + run_diffusion() + else: + pass + # run() + +if __name__ == '__main__': + switch() + diff --git a/auto_round_diff/autoround.py b/auto_round_diff/autoround.py new file mode 100644 index 00000000..d1f411ca --- /dev/null +++ b/auto_round_diff/autoround.py @@ -0,0 +1,5796 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import sys + +import torch +import copy +import time +from typing import Optional, Union, List +from transformers import set_seed +from torch import autocast +from tqdm import tqdm +import accelerate +from .wrapper import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer +from .utils import ( + CpuInfo, + block_forward, + check_is_cpu, + check_to_quantized, + collect_best_params, + convert_dtype_str2torch, + detect_device, + get_block_names, + get_module, + htcore, + is_optimum_habana_available, + logger, + to_device, + to_dtype, + get_layer_names_in_block, + mv_module_from_gpu, + unsupport_meta_device, clear_memory, + compile_func, + find_matching_blocks, is_debug_mode, + TORCH_VERSION_AT_LEAST_2_6, + supported_layer_types, + get_layer_features, + set_module, + llm_load_model, + reset_params, + init_cache, check_skippable_keywords, get_shared_keys, supported_dtypes, infer_bits_by_data_type, +) + + +class AutoRound(object): + """For more information, please refer to Cheng, Wenhua, et al. 
"Optimize weight rounding via signed gradient descent + for the quantization of llms." arXiv preprint arXiv:2309.05516 (2023). + + Args: + model: The PyTorch model to be quantized. + tokenizer: An optional tokenizer for processing input data. If none is provided, a dataloader must be supplied. + bits (int): Number of bits for quantization (default is 4). + group_size (int): Size of the quantization group (default is 128). + sym (bool): Whether symmetric quantization is to be used (default is True). + layer_config (dict): Configuration for weight quantization (default is None). + layer_config={ + 'layer1':##layer_name + { + 'data_type': 'int', + 'bits': 4, + 'group_size': 128, + 'sym': True + 'act_data_type': None, + 'act_bits': 16, + 'act_group_size': None, + 'act_sym': None, + + } + ... + } + batch_size (int): Batch size for training (default is 8). + amp (bool): Whether to use automatic mixed precision (default is True). + device: The device to be used for tuning (default is "auto"). + lr_scheduler: The learning rate scheduler to be used. + dataset (str): The default dataset name (default is "partiprompts"). + enable_quanted_input (bool): Whether to use the output of the previous quantized block as + the input for the current block (default is True). + enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True). + lr (float): The learning rate (default is None, will be set to 1.0/iters). + minmax_lr (float): The learning rate for min-max tuning (default is None, it will be set to lr automatically). + low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True). + low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False). + iters (int): Number of iterations (default is 200). + seqlen (int): Data length of the sequence for tuning (default is 2048). + nsamples (int): Number of samples (default is 128). + sampler (str): The sampling method (default is "rand"). + seed (int): The random seed (default is 42). + nblocks (int): Number of blocks (default is 1). + gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1). + not_use_best_mse (bool): Whether to use mean squared error (default is False). + dynamic_max_gap (int): The dynamic maximum gap (default is -1). + data_type (str): The data type to be used (default is "int"). + scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels + have different choices. + act_bits (int): Number of bits for activation quantization. Default is 16. + act_group_size (int): Group size for activation quantization. Default is None. + act_sym (bool): Whether to use symmetric activation quantization. Default is None. + act_data_type (str): Specifies the data type for activations. + Defaults to None, in which case it inherits the weight data type. + act_dynamic (bool): Whether to use dynamic activation quantization. Default is True. + to_quant_block_names (str|list): A string or list whose elements are list of + block's layer names to be quantized. + enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning + enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True. + device_map (str|dict): device map for each block + Returns: + The quantized model. 
+ """ + + def __init__( + self, + model: torch.nn.Module, + tokenizer, + bits: int = 4, + group_size: int = 128, + sym: bool = True, + layer_config: dict = None, + batch_size: int = 8, + amp: bool = True, + device: str = None, + lr_scheduler=None, + dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", + enable_quanted_input: bool = True, + enable_minmax_tuning: bool = True, + lr: float = None, + minmax_lr: float = None, + low_gpu_mem_usage: bool = False, + low_cpu_mem_usage: bool = False, + iters: int = 200, + seqlen: int = 2048, + nsamples: int = 128, + sampler: str = "rand", + seed: int = 42, + nblocks: int = 1, + gradient_accumulate_steps: int = 1, + not_use_best_mse: bool = False, + dynamic_max_gap: int = -1, + data_type: str = "int", + scale_dtype: str = "fp16", + act_bits: int = 16, + act_group_size: int = None, + act_sym: bool = None, + act_data_type: str = None, + act_dynamic: bool = True, + to_quant_block_names: Union[str, list] = None, + enable_norm_bias_tuning: bool = False, + enable_torch_compile: bool = False, + device_map: Union[str, dict] = None, + super_bits: int = None, + super_group_size: int = None, + model_kwargs: dict = None, + **kwargs, + ): + self.quantized = False + self.model_orig_dtype = model.dtype + self.seed = seed + set_seed(self.seed) + assert not unsupport_meta_device(model), ( + "AutoRound does not support for params on meta device." + " Please use more gpus by setting `--device 0,1,2,3` or just use one gpu") + + ## important tuning hype-parameters + self.amp = amp + self.enable_quanted_input = enable_quanted_input + self.enable_minmax_tuning = enable_minmax_tuning + self.nsamples = nsamples + self.bits = bits + self.enable_norm_bias_tuning = enable_norm_bias_tuning + self.group_size = group_size + self.sym = sym + + self.low_gpu_mem_usage = low_gpu_mem_usage + self.low_cpu_mem_usage = low_cpu_mem_usage + self.layer_config = {} if layer_config is None else layer_config + self.seqlen = seqlen + self.batch_size, self.gradient_accumulate_steps = batch_size, gradient_accumulate_steps + self.nblocks = nblocks + self.dataset = dataset + self.iters = iters + if self.iters < 0: + logger.warning("`iters` must be non-negative, reset it to 200") + self.iters = 200 + if self.iters == 0: + self.lr = 5e-3 + else: + self.lr = lr or (1.0 / self.iters) ##must after iter setting + self.minmax_lr = minmax_lr or self.lr + + self.data_type = data_type + tmp_bits = infer_bits_by_data_type(self.data_type) + if tmp_bits<16 and tmp_bits!=bits: + logger.warning( + f"bits set in 'data_type' do not match the specified 'bits' setting. 
Resetting 'bits' to {tmp_bits}.") + self.bits = tmp_bits + self.supported_types = supported_layer_types + self.model = model.eval() + self.tokenizer = tokenizer + self.device = detect_device(device) + self.scale_dtype = convert_dtype_str2torch(scale_dtype) + self.set_amp_dtype() + self.to_quant_block_names = to_quant_block_names + if not hasattr(self, 'quant_block_list'): + all_blocks = get_block_names(model) + self.quant_block_list = find_matching_blocks(model, all_blocks, self.to_quant_block_names) + self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device + + ##activation + self.act_group_size = act_group_size if not (act_group_size is None) else self.group_size + self.act_bits = act_bits if not (act_bits is None) else self.bits + self.act_sym = act_sym if not (act_sym is None) else self.sym + self.act_dynamic = act_dynamic + self.act_data_type = act_data_type + if self.act_data_type is None: + if data_type in supported_dtypes and self.act_bits <= 16: + self.act_data_type = data_type + logger.info(f"activation adopts {data_type}") + else: + self.act_data_type = "float" + + tmp_act_bits = infer_bits_by_data_type(self.act_data_type) + if tmp_act_bits < 16: + self.act_bits = tmp_act_bits + + self.sampler = sampler + self.not_use_best_mse = not_use_best_mse + self.dynamic_max_gap = dynamic_max_gap + self.lr_scheduler = lr_scheduler + self.optimizer = self.get_optimizer(None) + self.batch_dim = None + self.infer_bs_coeff = 1 + + self.super_bits = super_bits + self.super_group_size = super_group_size + + torch.set_printoptions(precision=3, sci_mode=True) + self.check_configs() + if self.act_bits <= 8 and self.amp_dtype == torch.float16: + logger.warning("force to use bf16 to for quantization tuning when enabling activation quantization") + self.amp_dtype = torch.bfloat16 + self.model = self.model.to(torch.bfloat16) + else: + logger.info(f"using {self.model.dtype} for quantization tuning") + + self.enable_torch_compile = enable_torch_compile + if not self.enable_torch_compile and TORCH_VERSION_AT_LEAST_2_6 and self.act_bits > 8 and not is_debug_mode() \ + and self.low_cpu_mem_usage != True and "fp8" not in self.data_type and "fp8" not in self.act_data_type: + logger.info("'enable_torch_compile' is set to `False` by default. 
" \ + "Enabling it can reduce tuning cost by 20%, but it might throw an exception.") + + if self.act_bits <= 8 and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as activation quantization is enabled") + + if self.low_cpu_mem_usage == True and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled") + + if is_debug_mode() and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as debug mode is enabled") + + if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as fp8 is enabled") + + if is_optimum_habana_available(): + logger.info("Optimum Habana is available, import htcore explicitly.") + import habana_frameworks.torch.core as htcore # pylint: disable=E0401 + import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] + self.device_map = device_map + + self.set_device_map_in_blocks(self.device_map) + + self.is_packing_immediate = False ## whether to pack the layer immediately after tuning + + self.serialization_keys = [ + "bits", + "group_size", + "sym", + "data_type", + "enable_quanted_input", + "enable_minmax_tuning", + "data_type", + "seqlen", + "batch_size", + "scale_dtype", + "lr", + "minmax_lr", + "gradient_accumulate_steps", + "iters", + "amp", + "nsamples", + "low_gpu_mem_usage", + "to_quant_block_names", + "enable_norm_bias_tuning", + "act_bits", + "act_group_size", + "act_sym", + "act_dynamic", + "act_data_type", + "super_bits", + "super_group_size" + ] + + self.has_qlayer_outside_block = self.set_layerwise_config(self.layer_config) ##better place in the end + self.shared_cache_keys = get_shared_keys(self.model) + + def set_device_map_in_blocks(self, device_map): + """Sets the device map for specific blocks in the model. + + Args: + device_map (Union[str, dict]): A mapping of module names to devices. + If provided as a string, it should be in the format + "module_name:device,module_name:device". Devices can be integers + (GPU IDs) or strings (e.g., 'cpu', 'cuda:0'). 
+ """ + if self.device_map is None or len(self.device_map) == 0: + self.device_map = None + if not device_map: + return + if isinstance(device_map, str): + device_map = device_map.replace(" ", "") + infos = device_map.split(",") + device_map_dict = {} + for info in infos: + index = info.find(':') + key = info[:index] + value = info[index + 1:] + device_map_dict[key] = value + device_map = device_map_dict + + names = [n for n, m in self.model.named_modules() if len(list(m.children())) == 0] + + for key, device in device_map.items(): + if isinstance(device, str) and device.isdigit(): + device = int(device) + device = detect_device(device) + try: + module = get_module(self.model, key) + module.tuning_device = device + except: + matching_names = [name for name in names if re.match(key, name)] + if len(matching_names) > 0: + for name in matching_names: + self._set_device_for_matching_module(name, device) + else: + for name in names: + if key in name: + self._set_device_for_matching_module(name, device) + + def _set_device_for_matching_module(self, name, device): + module = get_module(self.model, name) + if hasattr(module, "tuning_device") and module.tuning_device != device: + logger.warning( + f"Multiple devices have been set for layer {name}, keeping original device {module.tuning_device}") + else: + module.tuning_device = device + + def _dq_check(self): + """Reset the default value of super_bits and super_group_size""" + from auto_round.export.export_to_gguf.config import GGUF_CONFIG + if self.data_type.endswith("_dq"): + gguf_config = GGUF_CONFIG[f"gguf:q{self.bits}_k_s"] + self.super_bits = gguf_config["super_bits"] if self.super_bits is None else self.super_bits + self.super_group_size = gguf_config["super_group_size"] \ + if self.super_group_size is None else self.super_group_size + + def check_configs(self): + + """Checks if the configurations are valid. + + Raises: + AssertionError: If any of the configurations are invalid. + """ + assert isinstance(self.model, torch.nn.Module) + assert self.bits > 0, "bits must be positive" + assert self.act_bits > 0, "bits must be positive" + assert self.group_size == -1 or self.group_size >= 1, "only supports positive group_size or -1(per channel)" + assert self.act_group_size == -1 or self.act_group_size >= 1, \ + "only supports positive group_size or -1(per channel)" + assert self.batch_size > 0, "batch size must be positive" + assert self.iters >= 0, "iters must be non-negative" + assert self.seqlen > 0, "seqlen must be positive" + assert self.nblocks > 0, "nblocks must be positive" + assert self.gradient_accumulate_steps > 0, "gradient accumulate step must be positive" + # assert self.tokenizer != None or self.dataloader != None + if self.act_bits <= 8: + logger.warning( + "activation quantization is an experimental feature with limited support and a complex API. 
" + "And please save the quantized model to fake format as real deployment is not supported currently") + + if "mx_fp" in self.data_type: + logger.warning( + "please save the quantized model to fake format " + "as real deployment is not supported for mx_fp datatype currently") + + if "mx_fp" in self.data_type and self.group_size != 32: + logger.warning("mx_fp should only support group_size of 32 in real deployment") + + if self.nsamples < self.gradient_accumulate_steps * self.batch_size: + if self.batch_size > self.nsamples: + logger.warning(f"reset batch_size to {self.nsamples} as nsamples({self.nsamples})" + f" is smaller than batch_size({self.batch_size})") + self.batch_size = self.nsamples + if self.gradient_accumulate_steps > self.nsamples // self.batch_size: + self.gradient_accumulate_steps = self.nsamples // self.batch_size + logger.warning( + f"reset gradient_accumulate_steps to {self.gradient_accumulate_steps}" + f" as nsamples must equal or greater" + f" than gradient_accumulate_steps * batch_size") + self._dq_check() + + # def _check_format_compatibility(self, format): ##TODO + # ##check lm_head, mixed_bits, bits, each layer supporting, etc + # pass + + def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "auto_round", inplace=True, **kwargs): + """Quantizes the model and saves it in the specified format(s). + + This function checks the validity of the requested format(s), quantizes + the model accordingly, and saves it to the specified output directory. + If multiple formats are provided, the model is saved separately for each format. + + Args: + output_dir (str, optional): The directory where the quantized model + will be saved. Defaults to "tmp_autoround". + format (str, optional): The quantization format(s) to use, separated + by commas if multiple. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place if only + one format is used. Defaults to True. + **kwargs: Additional arguments for the quantization and saving process. + + Returns: + model: A qdq model or packed model based on the configurations + folders: The folder paths where the quantized models are saved. + + Raises: + ValueError: If an unsupported format is specified. + """ + # Validate and process the specified formats + formats = format.replace(' ', '').split(',') + from auto_round.utils import supported_formats + for format_ in formats: + if format_ not in supported_formats: + logger.error(f"Unsupported format {format_}, please choose from {supported_formats}") + exit(-1) + + # only support to export afp8 + if self.act_bits <= 8: + if "fp8" not in self.act_data_type: + if len(formats) > 1 or "fake" not in formats: + logger.warning( + f"Currently only support to export auto_round format quantized model" + " with fp8 dtype activation for activation quantization." + " Change format to fake and save." 
+ ) + formats = ["fake"] + else: + if len(formats) > 1 or "auto_round" not in formats: + logger.warning( + f"Currently only support to export auto_round format for W{self.bits}AFP8 model," + " change format to auto_round" + ) + formats = ["auto_round"] + + # If multiple formats are specified, enforce inplace=False + if len(formats) > 1: + inplace = False + inplace = kwargs.get("inplace", inplace) + kwargs.pop("inplace", None) + + # Determine if immediate packing is required + if (len(formats) == 1 and + ("awq" in formats[0] or "gptq" in formats[0] or "auto_round" in formats[0]) and + not self.has_qlayer_outside_block and inplace): # TODO: Support more formats + self.is_packing_immediate = True + + # Adjust format settings based on compatibility + for index in range(len(formats)): + format = formats[index] + if "auto_round" in format: + if (self.sym and ("gptq" not in format and "awq" not in format)) or self.bits == 3: + format = format.replace('auto_round', 'auto_round:auto_gptq') + formats[index] = format + + # Remove duplicates from formats list + def remove_duplicates(lst): + seen = set() + return [x for x in lst if not (x in seen or seen.add(x))] + + formats = remove_duplicates(formats) + self.formats = formats + + # # Check format compatibility + # self._check_format_compatibility(formats) + + # Perform model quantization + model, _ = self.quantize() + + # Save the quantized model in the specified formats + folders = [] + for format in formats: + if "gptq" in format and not self.sym: + logger.warning( + "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop," + " particularly for 2-bit quantization and smaller models." + " We recommend exporting to either the AutoAWQ format ( only 4 bits) or " + "the AutoRound format(2/4/8 bits)." + ) + save_format_ = format.replace(":", "-").replace("_", "-") + save_folder = os.path.join(output_dir, save_format_) if len(formats) > 1 else output_dir + self.save_quantized(save_folder, format=format, inplace=inplace, **kwargs) + + folders.append(save_folder) + + return model, folders + + @torch.inference_mode + def quantize_rtn(self): + if self.amp: + self.model.to(self.amp_dtype) + self.model.to("cpu") + all_to_quantized_module_names = [] + for n, m in self.model.named_modules(): + if check_to_quantized(m): + all_to_quantized_module_names.append(n) + pbar = tqdm(all_to_quantized_module_names) + + for name in pbar: + pbar.set_description(f"Quantizing {name}") + m = get_module(self.model, name) + + m.to(self.device) + m = WrapperLinear(m, enable_minmax_tuning=False, enable_norm_bias_tuning=False, enable_round_tuning=False) + m = m.unwrapper({}) + m.to("cpu") + if self.low_gpu_mem_usage: + clear_memory() + if self.is_packing_immediate: + from auto_round.export import PACKING_LAYER_WITH_FORMAT + if check_to_quantized(m): + target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0] + PACKING_LAYER_WITH_FORMAT[target_backend](name, self.model, self.formats[0]) + if self.low_gpu_mem_usage: + clear_memory() + else: + set_module(self.model, name, m) + + self.quantized = True + return self.model, self.layer_config + + def quantize(self): + """Quantize the model and return the quantized model along with layer configurations. + the entry of AutoRound. + + Returns: + The quantized model and layer configurations. 
+ """ + if self.iters == 0: + return self.quantize_rtn() + + if bool(self.quant_block_list): + all_blocks = self.quant_block_list + else: + all_blocks = get_block_names(self.model) + + if len(all_blocks) == 0: + logger.warning("could not find blocks, exit with original model") + return self.model, self.layer_config + + if self.amp: + self.model = self.model.to(self.amp_dtype) + + layer_names = self.get_quantized_layer_names_outside_blocks() + self.start_time = time.time() + all_first_block_names = [block[0] for block in all_blocks] + logger.info("start to cache block inputs") + all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names) + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: + accelerate.hooks.remove_hook_from_submodules(self.model) ##self.model.hf_device_map has not been changed + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + logger.info("caching done") + pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks)) + + for block_names in all_blocks: + inputs = all_inputs[block_names[0]] + all_inputs.pop(block_names[0]) + keys = inputs.keys() + input_id_str = [key for key in keys if key.startswith('hidden_state')] + if len(input_id_str) != 1: + raise RuntimeError(f"hidden_states arg mismatch error," + "please raise an issue in https://github.com/intel/auto-round/issues") + inputs["input_ids"] = inputs.pop(input_id_str[0], None) + clear_memory(self.inputs) + + if "input_ids" in inputs.keys(): + total_samples = len(inputs["input_ids"]) + if total_samples < self.batch_size: + self.batch_size = total_samples + logger.warning(f"force the train batch size to {total_samples}") + + self.quant_blocks( + self.model, + inputs, + block_names, + nblocks=self.nblocks, + device=self.device, + pbar=pbar + ) + if self.is_packing_immediate: + assert len(self.formats) == 1 + + self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately + + end_time = time.time() + cost_time = end_time - self.start_time + logger.info(f"quantization tuning time {cost_time}") + + ## dump a summary + quantized_layers = [] + unquantized_layers = [] + for n, m in self.model.named_modules(): + if isinstance(m, tuple(self.supported_types)): + if check_to_quantized(m): + quantized_layers.append(n) + else: + unquantized_layers.append(n) + elif hasattr(m, "scales") or hasattr(m, "scale"): ##packing_immediately + quantized_layers.append(n) + summary_info = ( + f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model" + ) + if len(unquantized_layers) > 0: + summary_info += f", {unquantized_layers} have not been quantized" + logger.info(summary_info) + + self.quantized = True + return self.model, self.layer_config + + def quant_layers(self, layer_names, layer_inputs): + """Quantizes specified layers based on inputs and configuration. + + Args: + layer_names (list): List of layer names to quantize. + layer_inputs (dict): Dictionary mapping layer names to input data. 
+ + Returns: + None + """ + ##TODO currently we take all the layers outside blocks as post block layers which is not optimal + if len(layer_names) == 0: + return + q_layer_inputs = None + enable_quanted_input = self.enable_quanted_input + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1 and enable_quanted_input: + from accelerate.big_modeling import dispatch_model + + dispatch_model(self.model, self.model.hf_device_map) + + if enable_quanted_input: + q_layer_inputs = self.try_cache_inter_data_gpucpu([], self.nsamples, layer_names=layer_names) + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: + accelerate.hooks.remove_hook_from_submodules( + self.model) ##self.model.hf_device_map has not been changed + + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + clear_memory() + if self.enable_torch_compile: + quant_layer = compile_func(self.quant_layer, self.device) + else: + quant_layer = self.quant_layer + for layer_name in layer_names: + layer_input = layer_inputs[layer_name] + layer_input = to_device(layer_input, self.cache_device) + q_layer_input = q_layer_inputs[layer_name] if enable_quanted_input else None + q_layer_input = to_device(q_layer_input, self.cache_device) + quant_layer(layer_name, layer_input, q_layer_input, device=self.device) + del layer_input + clear_memory(q_layer_input) + + def set_layerwise_config(self, layer_config): + """ + Sets the layer-wise configuration based on the provided `layer_config`. + By default, only quantize layers in blocks. + + Args: + layer_config (dict): The configuration dictionary for each layer containing various configuration options. + + Returns: + bool: Returns True if there are quantized layers outside the blocks (e.g., lm-head), + otherwise returns False. 
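+            Example:
+                A sketch of accepted keys; the layer names are hypothetical.
+                Keys may be exact layer names or regular expressions, which are
+                expanded against the model's supported layers (a key matching
+                nothing raises ``ValueError``).
+
+                >>> ar.set_layerwise_config({
+                ...     "lm_head": {"bits": 16},   # exact name, leave at 16 bits
+                ...     "mlp": {"bits": 8},        # regex: every layer whose name contains "mlp"
+                ... })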
+ """ + # Get the names of layers in quantization blocks + layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list) + + ##process regex in layer_config + all_supported_layer_names = [] + # List of configuration keys + keys = self.serialization_keys + + for n, m in self.model.named_modules(): + # Delete previous configuration to avoid conflicts with prior tuning + for key in keys: + if hasattr(m, key): + delattr(m, key) + + # Skip unsupported types + if not isinstance(m, tuple(self.supported_types)): + continue + all_supported_layer_names.append(n) + + names_in_layer_config = list(layer_config.keys()) + for name in names_in_layer_config: + if name in all_supported_layer_names: + continue + matched_names = [] + for layer_name in all_supported_layer_names: + if re.search(re.compile(name), layer_name) is not None: + matched_names.append(layer_name) + if len(matched_names) > 0: + val = layer_config[name] + layer_config.pop(name) + for match_name in matched_names: + layer_config[match_name] = val + else: + raise ValueError(f"key {name} in layer_config is invalid, please have a double check") + + has_qlayer_outside_block = False # Flag to track if there are quantized layers outside blocks (e.g., lm-head) + + # Iterate through all modules in the model + for n, m in self.model.named_modules(): + + # Skip unsupported types + if not isinstance(m, tuple(self.supported_types)): + continue + + # If the layer is not in the config and is part of a quantization block, use default configuration + if n not in layer_config.keys() and n in layers_in_blocks: + layer_config[n] = {} + for key in keys: + layer_config[n][key] = getattr(self, key) + # If the layer is partially configured, fill in missing values + elif n in layer_config.keys(): + for key in keys: + if key not in layer_config[n].keys(): + layer_config[n][key] = getattr(self, key) + # If the layer is not in the config and not part of a quantization block, + # use default configuration and set specific values + else: + layer_config[n] = {} + for key in keys: + layer_config[n][key] = getattr(self, key) + layer_config[n]["bits"] = 16 + layer_config[n]["act_bits"] = 16 + + if n in layers_in_blocks: + layer_config[n]["in_blocks"] = True + else: + layer_config[n]["in_blocks"] = False + + # If the layer is outside a block and requires quantization, mark it as a quantized layer outside the block + if n not in layers_in_blocks and check_to_quantized(layer_config[n]): + has_qlayer_outside_block = True + + in_features, out_features = get_layer_features(m) + if in_features <= layer_config[n]["group_size"]: + layer_config[n]["group_size"] = -1 + + # Apply the configuration to the corresponding layer in the model + for key in keys: + setattr(m, key, layer_config[n][key]) + + # Return whether there are quantized layers outside the blocks + return has_qlayer_outside_block + + @torch.no_grad() + def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_device, save_output=True): + """Compute the output of a given block of the model for a given input. + + Args: + block: The block of the model. + input_ids: The input tensor containing tokenized input ids. + input_others: A dictionary containing additional input data. + bs: The batch size for computing the output. + device: The device for computation. + cache_device: The device for storing the output. + batch_dim: The batch dimension of the output tensor. + + Returns: + The output tensor of the block. 
+ """ + + output = [] + nsamples = len(input_ids) + for i in range(0, nsamples, bs): + end_index = min(nsamples, i + bs) + indices = torch.arange(i, end_index).to(torch.long) + tmp_input_ids, tmp_input_others = AutoRound.sampling_inputs( + input_ids, + input_others, + indices, + self.seqlen, + self.batch_dim, + share_cache_keys=self.shared_cache_keys + ) + tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to( + cache_device + ) + if save_output: + if self.batch_size == 1: + output.append(tmp_output) + else: + output.extend(list(torch.split(tmp_output, 1, dim=self.batch_dim))) + if self.low_gpu_mem_usage: + clear_memory() + + return output + + # @torch.no_grad() + # def calib(self, nsamples, bs): + # """Perform calibration for quantization. + + # This method calibrates the model for quantization by processing a specified + # number of samples from the calibration dataset. It ensures that the data is + # properly formatted and feeds it to the model. If the number of samples processed + # is less than the specified number, it logs a warning. If no samples are processed, + # it logs an error and exits. + # Args: + # nsamples (int): The number of samples to use for calibration. + # bs (int): The number of samples to use for calibration + # """ + # from .calib_dataset import get_dataloader + # if isinstance(self.dataset, str): + # dataset = self.dataset.replace(" ", "") ##remove all whitespaces + + # # slow here + # self.dataloader = get_dataloader( + # self.tokenizer, + # self.seqlen, + # dataset, + # self.seed, + # bs, + # self.nsamples, + # ) + # else: + # self.dataloader = self.dataset + # total_cnt = 0 + + # # load embed weight if use low_cpu_mem_usage + # if self.low_cpu_mem_usage: + # embed_layers = get_layers_before_block(self.model) + # for n, m in embed_layers: + # m = m.to(self.device) + + # for data in self.dataloader: + # if data is None: + # continue + # if isinstance(data, torch.Tensor): + # input_ids = data.to(self.device) + # data_new = input_ids + # elif isinstance(data, str): + # if self.tokenizer is None: + # logger.error("please provide tokenizer for string input") + # exit(-1) + # data = self.tokenizer(data, truncation=True, max_length=self.seqlen, return_tensors="pt").data + # data_new = {} + # for key in data.keys(): + # data_new[key] = data[key].to(self.device) + # input_ids = data_new["input_ids"] + # elif isinstance(data, tuple) or isinstance(data, list): + # data_new = data + # input_ids = data_new[0] + # else: + # data_new = {} + # for key in data.keys(): + # data_new[key] = to_device(data[key], self.model.device) + # if key == 'images': + # data_new[key] = to_dtype(data_new[key], self.model.dtype) + # input_ids = data_new["input_ids"] + # if input_ids.shape[-1] < self.seqlen: + # continue + # try: + # if isinstance(data_new, torch.Tensor): + # self.model(data_new) + # elif isinstance(data_new, tuple) or isinstance(data_new, list): + # self.model(*data_new) + # else: + # self.model(**data_new) + # except NotImplementedError: + # pass + # except RuntimeError as error: + # logger.warning("When quantization encounters tensor" \ + # " shape mismatch error, you can try to avoid it with batch_size=1") + # logger.error(error) + # pass + # except Exception as error: + # raise error + # total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1 + # if total_cnt >= nsamples: + # break + # if total_cnt == 0: + # logger.error( + # f"no data has been cached, please provide more data with sequence length >={self.seqlen} in 
the " + # f"dataset or decease the sequence length" + # ) + # exit(-1) + # elif total_cnt < nsamples: + # logger.warning( + # f"An insufficient number of samples likely reduces the accuracy of the quantized model." + # f"Target samples count is {nsamples}, while valid samples count is {total_cnt}" + # ) + + # # clean embed weight to save memory + # if self.low_cpu_mem_usage: + # for n, m in embed_layers: + # m = m.to("meta") + + @torch.no_grad() + def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, last_cache_name=None): + """Attempts to cache intermediate data on GPU, if failed, then using CPU. + + Args: + block_names (list): List of block names to cache data for. + nsamples (int): Number of samples to use for caching. + layer_names (list, optional): List of layer names to cache data for. Defaults to []. + last_cache_name (str, optional): Name of the last cache. Defaults to None. + + Returns: + all_inputs: Cached intermediate data. + + Raises: + Exception: If caching on GPU fails, switches to CPU and caches there. + """ + if layer_names is None: + layer_names = [] + try: + if not self.model.device.type == "meta": + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: + pass + else: + self.model = self.model.to(self.device) + all_inputs = self.cache_inter_data( + block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name + ) + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + clear_memory() + except RuntimeError as e: + if "CUDA out of memory" in str(e) or "MODULE:PT_DEVMEM" in str(e): + logger.info("switch to cpu to cache block inputs") + if (("lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] < 16) or + self.__class__.__name__ == "AutoRoundMLLM"): + logger.warning(f"we strongly recommend using additional CUDA/HPU devices,e.g. " + f"set `--device '0,1'` in our cmd line usage or " + f"load the model with `device_mapping=auto`," + f" for optimal performance during calibration " + f"Otherwise, the process may be significantly slower.") + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + clear_memory() + all_inputs = self.cache_inter_data( + block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name + ) + else: + raise + return all_inputs + + @torch.no_grad() + def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_name=None): + """Save the inputs of block_name for calibration. + + This method temporarily replaces the forward method of the model to capture + the inputs passing through the specified block. It then calibrates the model + using a specified number of samples. Finally, it restores the original forward + method and returns the inputs for the specified block. + Args: + block_names (list): The names of the blocks for which inputs are to be saved. + layer_names (list):The names of the layers for which inputs are to be saved. + nsamples (int): The number of samples to use for calibration. + last_cache_name (str, optional): The name of the last layer to be cached, + we could break the forward in this layer to save time + + Returns: + dict: A dictionary containing the inputs for the specified block. 
+ """ + if layer_names is None: + layer_names = [] + self.inputs = {} + self.to_cached_layers = block_names + layer_names + tmp_dtype = None + ## have bug if block name is not the first block + if (len(block_names) > 1 or len(layer_names) > 0) and self.low_gpu_mem_usage: + tmp_dtype = self.model.dtype + self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32) ##model on cpu + + self.last_cache_name = last_cache_name + if last_cache_name is None and len(block_names) + len(layer_names) == 1: + self.last_cache_name = block_names[0] if len(block_names) == 1 else layer_names[0] + # do not set last_cache_name for multimodal models + calib_bs = self.batch_size + self.hook_handles = [] + self._replace_forward() + self.calib(nsamples, calib_bs) + self._recover_forward() + res = self.inputs + del self.last_cache_name + del self.to_cached_layers + if tmp_dtype is not None: + self.model = self.model.to(tmp_dtype) + + return res + + @torch.no_grad() + def get_block_forward_func(self, name): + """Gets the forward function. + + Args: + name (str): The name of the function. + Returns: + function: The forward function. + """ + + def post_process_cache_data(batch_size, data, data_name): + """ + Processes store data for batch handling, reshaping if necessary. + + Args: + batch_size (int): The size of the batch. + data: The data value to store, potentially for caching. + data_name (str): Name of the data. + + Returns: + Processed data or None + """ + new_data = data + if batch_size <= 1: + return new_data + if data_name in self.shared_cache_keys: + return None + if "alibi" in data_name: + if isinstance(data, torch.Tensor): + alibi = data + alibi = alibi.reshape(batch_size, -1, alibi.shape[1], alibi.shape[2]) + new_data = alibi + return new_data + + def forward(m, hidden_states=None, *positional_inputs, **kwargs): + """Rewrite forward function, process and collect input data. + + Args: + hidden_states (torch.Tensor): The hidden states tensor. + *positional_inputs: Variable number of positional arguments. + **kwargs: Variable number of keyword arguments. + + Returns: + NotImplementedError: Getting the first layer inputs and then raise the error to save runtime. 
+ """ + if name not in self.inputs: + self.inputs[name] = {} + init_cache(positional_inputs, self.inputs[name]) + + if self.batch_dim is None: + self.batch_dim = 0 + if hidden_states is not None and self.batch_size > 1: + if hidden_states.shape[0] > self.batch_size: + self.batch_dim = 1 + if len(hidden_states.shape) > 1 and hidden_states.shape[1] > self.batch_size: + logger.error( + f"this model has not been supported, " + f"please raise an issue in https://github.com/intel/auto-round/issues" + f" or try to set the `batch_size` to 1 and " + f"`gradient_accumulate_steps` to your current batch size.") + exit(-1) + + if hidden_states is not None: + kwargs['hidden_states'] = hidden_states + + for key in kwargs.keys(): + if isinstance(kwargs[key], torch.Tensor) or isinstance(kwargs[key], list) \ + or isinstance(kwargs[key], tuple): + if key not in self.inputs[name].keys(): # initialization + data = to_device(kwargs[key], device=torch.device("cpu")) + if data is None or (self.batch_size > 1 and key in self.shared_cache_keys): + self.inputs[name][key] = data + continue + if self.batch_size <= 1: + self.inputs[name][key] = [data] + else: + data = post_process_cache_data(self.batch_size, data, key) + self.inputs[name][key] = list(torch.split(data, 1, dim=self.batch_dim)) + else: # append cache inputs + new_data = post_process_cache_data(self.batch_size, kwargs[key], key) + if new_data is None: # shareable args or NoneType + continue + new_data = to_device(new_data, device=torch.device("cpu")) + if self.batch_size <= 1: + self.inputs[name][key].append(new_data) + else: + self.inputs[name][key].extend(list(torch.split(new_data, 1, dim=self.batch_dim))) + elif isinstance(kwargs[key], (str, bool, type(None))): + if key not in self.inputs[name].keys(): + self.inputs[name][key] = kwargs[key] + else: + # Parameters not to be cached + if check_skippable_keywords(key): + logger.warning_once(f"Please note that '{key}' key" \ + " is not currently used in quantization fine-tuning.") + reset_params(self.inputs[name]) + if name == self.last_cache_name: + raise NotImplementedError + else: + if hidden_states is not None: + kwargs.pop('hidden_states') + return m.orig_forward(hidden_states, *positional_inputs, **kwargs) + else: + # Currently only for Llama-3.2-Vision-Instruct Series + return m.orig_forward(*positional_inputs, **kwargs) + + return forward + + @torch.no_grad() + def _get_cache_data_hook_for_layer(self, name): + """A forward hook to save input max of a module + :param name: the module name + :return: A hook function.""" + + def cache_input_hook(module, inputs, outputs): + input = inputs + if isinstance(inputs, tuple) or isinstance(input, list): + input = inputs[0] + if name in self.inputs: + self.inputs[name].extend(list(torch.split(input.to("cpu"), 1, dim=0))) + else: + self.inputs[name] = list(torch.split(input.to("cpu"), 1, dim=0)) + + return cache_input_hook + + def _recover_forward(self): + """Recovers the forward function.""" + for n, m in self.model.named_modules(): + if hasattr(m, "orig_forward"): + m.forward = m.orig_forward + delattr(m, "orig_forward") + for hook_handle in self.hook_handles: + hook_handle.remove() + self.hook_handles = [] + + def _replace_forward(self): + """Replaces the forward function.""" + from functools import partial + + for n, m in self.model.named_modules(): + if n in self.to_cached_layers and not isinstance(m, tuple(self.supported_types)): ##block + m.orig_forward = m.forward + m.forward = partial(self.get_block_forward_func(n), m) + elif n in self.to_cached_layers: 
##linear layer or conv1d layer + hook_func = self._get_cache_data_hook_for_layer(n) + hook_handle = m.register_forward_hook(hook_func) + self.hook_handles.append(hook_handle) + + def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cpu")): + """Quantize a specific layer of the model using the provided inputs. + + Args: + layer_name (str): The name of the layer to quantize. + inputs (torch.Tensor): Input data for quantization. + q_inputs (torch.Tensor, optional): Quantized input data. Defaults to None. + device (torch.device, optional): The device to use for quantization. Defaults to torch.device("cpu"). + + Returns: + None + """ + logger.info(f"quantizing layer {layer_name}") + layer = get_module(self.model, layer_name) + if hasattr(layer, "tuning_device"): + device = layer.tuning_device + + layer = layer.to(device) + for i in range(len(inputs)): + inputs[i] = inputs[i].to(layer.weight.dtype) + if q_inputs is not None: + q_inputs[i] = q_inputs[i].to(layer.weight.dtype) + + wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to( + device) + round_params = [] + minmax_params = [] + for key in wrapper_linear.params.keys(): + if "min" in key or "max" in key: + minmax_params.append(wrapper_linear.params[key]) + else: + round_params.append(wrapper_linear.value) + if self.enable_minmax_tuning: + optimizer = self.optimizer( + [{"params": round_params}, {"params": minmax_params, "lr": self.minmax_lr}], lr=self.lr, weight_decay=0 + ) + else: + optimizer = self.optimizer(round_params, lr=self.lr, weight_decay=0) + + if self.lr_scheduler is None: + lr_schedule = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.iters + ) + else: + lr_schedule = copy.deepcopy(self.lr_scheduler) + nsamples = len(inputs) + last_best_iter = 0 + best_loss = torch.finfo(torch.float).max + mse_loss = torch.nn.MSELoss().to(device) + scaler = self.get_scaler() # pylint: disable=assignment-from-none + init_loss = None + # best_v, best_min_scale, best_max_scale = torch.tensor(0), torch.tensor(1.0), torch.tensor(1.0) + gradient_accumulate_steps = self.batch_size ##Force to low gpu + batch_size = 1 ##Force to low gpu + pick_samples = batch_size * gradient_accumulate_steps + pick_samples = min(nsamples, pick_samples) + if self.sampler != "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + total_loss = 0 + num_elm = 1 + mse_reduction = "mean" + if gradient_accumulate_steps != 1: + mse_reduction = "sum" + mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device) + + for i in range(self.iters): + total_loss = 0 + if self.sampler == "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + if gradient_accumulate_steps != 1: + if q_inputs is not None: + current_input = [q_inputs[i] for i in whole_indices] + else: + current_input = [inputs[i] for i in whole_indices] + num_elm = sum(id.numel() for id in current_input) + for tmp_step in range(gradient_accumulate_steps): + indices = whole_indices[tmp_step * batch_size: (tmp_step + 1) * batch_size] + if q_inputs is not None: + current_input = [q_inputs[i] for i in indices] + current_input = torch.cat(current_input, dim=0).to(device) + org_input = [inputs[i] for i in indices] + org_input = torch.cat(org_input, dim=0).to(device) + else: + current_input = [inputs[i] for i in indices] + current_input = torch.cat(current_input, dim=0).to(device) + org_input = current_input + with torch.no_grad(): + current_output = layer(org_input) + + if 
self.amp: + with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): + output_q = wrapper_linear(current_input) # pylint: disable=not-callable + loss = mse_loss(output_q, current_output) # pylint: disable=not-callable + else: + output_q = wrapper_linear(current_input) # pylint: disable=not-callable + loss = mse_loss( # pylint: disable=not-callable + output_q.to(torch.float32), current_output.to(torch.float32) + ) + total_loss += loss.item() / num_elm + + self.scale_loss_and_backward(scaler, loss) + if i == 0: + init_loss = total_loss + + if total_loss < best_loss: + best_loss = total_loss + if not self.not_use_best_mse: + best_params = collect_best_params(wrapper_linear) + last_best_iter = i + if self.not_use_best_mse and i == self.iters - 1: + best_params = collect_best_params(wrapper_linear) + + if not self.not_use_best_mse: + if 0 < self.dynamic_max_gap <= i - last_best_iter: + break + self.step(scaler, optimizer, lr_schedule) + + last_loss = total_loss + best_iter = self.iters + if not self.not_use_best_mse: + last_loss = best_loss + best_iter = last_best_iter + with torch.no_grad(): + unwrapper_layer(self.model, wrapper_linear, layer_name, best_params) + mv_module_from_gpu(layer, self.low_cpu_mem_usage) + dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" + logger.info(dump_info) + + def register_act_max_hook(self, model): + def get_act_max_hook(module, input, output): + if isinstance(input, (tuple, list)): + input = input[0] + if not hasattr(module, "act_max"): + module.act_max = torch.abs(input).max().item() + else: + module.act_max = max(torch.abs(input).max().item(), module.act_max) + + hook_handles = [] + + for n, m in model.named_modules(): + if hasattr(m, "act_dynamic") and m.act_dynamic == False and check_to_quantized(m): + hook = m.register_forward_hook(get_act_max_hook) + hook_handles.append(hook) + return hook_handles + + def quant_block(self, block, input_ids, input_others, q_input=None, device=torch.device("cpu")): + """Quantize the weights of a given block of the model. + + Args: + block: The block of the model to be quantized. + input_ids: The input tensor containing tokenized input ids. + input_others: A dictionary containing additional input data. + q_input: The quantized input tensor. + device: The device for quantization. 
+ + Returns: + Tuple: (q_outputs, output) if self.enable_quanted_input is True, else (None, output) + """ + if self.device_map is not None: + from accelerate import dispatch_model + for n, m in block.named_modules(): + if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): + continue + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + hook = AlignDevicesHook(m.tuning_device, io_same_device=True) + add_hook_to_module(m, hook, True) + + if q_input is None: + hook_handles = self.register_act_max_hook(block) + + output = self.get_block_outputs(block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, + device, + self.cache_device) + + for handle in hook_handles: + handle.remove() + else: + output = self.get_block_outputs(block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, + device, + self.cache_device) + hook_handles = self.register_act_max_hook(block) + self.get_block_outputs(block, q_input, input_others, self.batch_size * self.infer_bs_coeff, + device, self.cache_device, save_output=False) + + for handle in hook_handles: + handle.remove() + + if q_input is not None: + if input_ids is not q_input: + clear_memory(input_ids) + else: + clear_memory() + input_ids = q_input + + quantized_layer_names, unquantized_layer_names = wrapper_block( + block, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device) + + round_params = [] + minmax_params = [] + for n, m in block.named_modules(): + if hasattr(m, "orig_layer"): + for key in m.params.keys(): + if "min" in key or "max" in key: + minmax_params.append(m.params[key]) + else: + round_params.append(m.params[key]) + + if self.enable_minmax_tuning: + optimizer = self.optimizer( + [{"params": round_params}, {"params": minmax_params, "lr": self.minmax_lr}], lr=self.lr, weight_decay=0 + ) + else: + optimizer = self.optimizer(round_params, lr=self.lr, weight_decay=0) + + if len(round_params) + len(minmax_params) <= 0: + dump_info = ( + f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " + f"layers in the block" + ) + logger.info(dump_info) + return output, output + + if self.lr_scheduler is None: + lr_schedule = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.iters + ) + else: + lr_schedule = copy.deepcopy(self.lr_scheduler) + + nsamples = len(input_ids) + pick_samples = self.batch_size * self.gradient_accumulate_steps + pick_samples = min(nsamples, pick_samples) + if self.sampler != "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + last_best_iter = 0 + best_loss = torch.finfo(torch.float).max + num_elm = 1 + mse_reduction = "mean" + if self.gradient_accumulate_steps != 1: + mse_reduction = "sum" + mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device) + scaler = self.get_scaler() # pylint: disable=assignment-from-none + init_loss = None + best_params = {} + total_loss = 0 + + for i in range(self.iters): + total_loss = 0 + if self.sampler == "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + ##we assume the block input and output shape is same + if self.gradient_accumulate_steps != 1: + current_input_ids = [input_ids[i] for i in whole_indices] + num_elm = sum(id.numel() for id in current_input_ids) + for tmp_step in range(self.gradient_accumulate_steps): + indices = whole_indices[tmp_step * self.batch_size: (tmp_step + 1) * self.batch_size] + current_input_ids, current_input_others = AutoRound.sampling_inputs( + input_ids, + 
input_others, + indices, + seqlen=self.seqlen, + batch_dim=self.batch_dim, + share_cache_keys=self.shared_cache_keys + ) + + current_output = [output[x] for x in indices] + current_output = torch.cat(current_output, dim=self.batch_dim) + + current_output = to_device(current_output, device) + + output_q = block_forward( + block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device + ) + if self.amp: + with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): + loss = mse_loss(output_q, current_output) # pylint: disable=not-callable + else: + loss = mse_loss( # pylint: disable=not-callable + output_q.to(torch.float32), current_output.to(torch.float32) + ) + + total_loss += loss.item() / num_elm + self.scale_loss_and_backward(scaler, loss) + + if i == 0: + init_loss = total_loss + + if total_loss < best_loss: + best_loss = total_loss + if not self.not_use_best_mse: + best_params = collect_best_params(block) + # print(f"get better result at iter {i}, the loss is {total_loss}", flush=True) + + last_best_iter = i + if self.not_use_best_mse and i == self.iters - 1: + best_params = collect_best_params(block) + + if not self.not_use_best_mse: + if 0 < self.dynamic_max_gap <= i - last_best_iter: + break + self.step(scaler, optimizer, lr_schedule) + + last_loss = total_loss + best_iter = self.iters + if not self.not_use_best_mse: + last_loss = best_loss + best_iter = last_best_iter + dump_info = ( + f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " + f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" + ) + logger.info(dump_info) + if len(unquantized_layer_names) != 0: + logger.info(f"{unquantized_layer_names} have not been quantized") + with torch.no_grad(): + unwrapper_block(block, best_params) + if self.enable_quanted_input: + if self.low_cpu_mem_usage: + block = block.to(device) + clear_memory() + q_outputs = self.get_block_outputs( + block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, + cache_device=self.cache_device + ) + if self.device_map is not None: + accelerate.hooks.remove_hook_from_submodules( + block) + mv_module_from_gpu(block, self.low_cpu_mem_usage) + clear_memory(input_ids) + + return q_outputs, output + + else: + if self.device_map is not None: + accelerate.hooks.remove_hook_from_submodules( + block) + mv_module_from_gpu(block, self.low_cpu_mem_usage) + clear_memory(input_ids) + return None, output + + def quant_blocks( + self, + model: torch.nn.Module, + inputs, + block_names, + nblocks=1, + device="cpu", + pbar=None + ): + """Quantize and dequantize the weights of the specified blocks in the model. + + Args: + model: The PyTorch model to be quantized. + inputs: The input data for quantization. + block_names: The names of the blocks to be quantized and dequantized. + nblocks: The number of blocks to quantize and dequantize. + device: The device for quantization and dequantization. 
+ + Returns: + None + """ + q_input = None + clear_memory() + for n, m in model.named_parameters(): + m.requires_grad_(False) + input_ids = inputs["input_ids"] + inputs.pop("input_ids", None) + input_others = inputs + clear_memory() + input_ids = to_device(input_ids, self.cache_device) + input_others = to_device(input_others, self.cache_device) + ## as in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage + tmp_dtype = self.amp_dtype if self.amp else torch.float32 + for i in range(len(input_ids)): + input_ids[i] = input_ids[i].to(tmp_dtype) + + for key in input_others.keys(): + if isinstance(input_others[key], torch.Tensor) and ( + input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16 + ): + input_others[key] = input_others[key].to(tmp_dtype) + elif isinstance(input_others[key], list): + for i in range(len(input_others[key])): + to_dtype(input_others[key][i], tmp_dtype) + if self.enable_torch_compile: + quant_block = compile_func(self.quant_block, device) + else: + quant_block = self.quant_block + + if pbar is None: + pbar = tqdm(range(0, len(block_names), nblocks)) + + for n, m in self.model.named_modules(): + if isinstance(m, tuple(self.supported_types)): + m.name = n + + for i in range(0, len(block_names), nblocks): + if i != 0: + pbar.update(1) + if nblocks == 1: + n = block_names[i] + pbar.set_description(f"Quantizing {n}") + m = get_module(model, n) + else: + names = block_names[i: min(i + nblocks, len(block_names))] + pbar.set_description(f"Quantizing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}") + modules = [get_module(model, n) for n in names] + m = WrapperMultiblock(modules) + + if not self.model.device.type == "meta" or self.low_cpu_mem_usage: + m = m.to(device) + + q_input, input_ids = quant_block( + m, + input_ids, + input_others, + q_input=q_input, + device=device, + ) + if self.is_packing_immediate: + from auto_round.export import PACKING_LAYER_WITH_FORMAT + for _, tmp_m in m.named_modules(): + if hasattr(tmp_m, "bits") and check_to_quantized(tmp_m): + target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0] + PACKING_LAYER_WITH_FORMAT[target_backend](tmp_m.name, self.model, self.formats[0]) + pbar.set_description(f"Quantizing done") + pbar.update(1) + pbar.close() + + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + for n, m in self.model.named_modules(): + if hasattr(m, "name"): + delattr(m, "name") + + del q_input + del input_ids + del input_others + del inputs + + clear_memory() + + def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs): + """Save the quantized model to the specified output directory in the specified format. + + Args: + output_dir (str, optional): The directory to save the quantized model. Defaults to None. + format (str, optional): The format in which to save the model. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place. Defaults to True. + **kwargs: Additional keyword arguments specific to the export format. + + Returns: + object: The compressed model object. + """ + # only support to export afp8 + if self.act_bits <= 8: + if "fp8" not in self.act_data_type or self.act_dynamic: + if format != "fake": + logger.warning( + f"Currently only support to export auto_round format quantized model" + " with fp8 dtype activation for activation quantization." + " Change format to fake and save." 
+                    )
+                    format = "fake"
+            else:
+                if format != "auto_round":
+                    logger.warning(
+                        f"Currently only support to export auto_round format for static W{self.bits}AFP8 model,"
+                        " change format to auto_round"
+                    )
+                    format = "auto_round"
+
+        if re.search(r"q\d_k", format) and not self.data_type.endswith("_dq"):
+            logger.error(
+                f"datatype<{self.data_type}> does not support exporting to the {format} format."
+                " Please change the export format or data_type."
+            )
+            sys.exit(-1)
+
+        if self.low_cpu_mem_usage:
+            self.model = self.model.to('cpu')
+
+        if not self.quantized:
+            logger.warning("please run autoround.quantize() first")
+            return
+        if format == "fake" or format == "qdq":  ##TODO fix act quantization later
+            self.model = self.model.to("cpu")
+            self.model.save_pretrained(output_dir)
+            if self.tokenizer is not None:
+                self.tokenizer.save_pretrained(output_dir)
+            processor = kwargs.get("processor", None)
+            if processor is not None:
+                processor.save_pretrained(output_dir)
+            return
+        if self.act_bits <= 8 and format == "qdq":
+            logger.warning(
+                "Support for exporting activation quantization is limited. "
+                "Please ensure that your configuration is supported.")
+        if format in ["gguf:q4_0", "gguf:q4_1"]:
+            if self.group_size != 32:
+                logger.error(f"{format} needs group_size=32, but it is {self.group_size}, cannot export.")
+                return
+            if format == "gguf:q4_0" and not self.sym:
+                logger.warning("incorrect format chosen, will reset to gguf:q4_1")
+            if format == "gguf:q4_1" and self.sym:
+                logger.warning("incorrect format chosen, will reset to gguf:q4_0")
+
+        from auto_round.export import EXPORT_FORMAT
+        backend = format
+        format = format.split(":")[0]
+        if format not in EXPORT_FORMAT:
+            logger.error(f"export format only supports {EXPORT_FORMAT.keys()}")
+            raise ValueError(f"export format only supports {EXPORT_FORMAT.keys()}, but got {format}")
+        save_quantized_as_format = EXPORT_FORMAT.get(format)
+        if "gptq" in format and not self.sym:
+            logger.warning(
+                "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop,"
+                " particularly for 2-bit quantization and smaller models."
+                " We recommend exporting to either the AutoAWQ format (4 bits only) or "
+                "the AutoRound format (2/4/8 bits)."
+ ) + if "awq" in format and not self.bits == 4: + raise ValueError("The AWQ format only supports W4 quantization ") + + if isinstance(self.dataset, str): + self.serialization_keys.append("dataset") + serialization_dict = {} + for key in self.serialization_keys: + serialization_dict[key] = getattr(self, key) + from .version import __version__ + + serialization_dict["autoround_version"] = __version__ + if "scale_dtype" in serialization_dict.keys(): + serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"]) + + compressed_model = save_quantized_as_format( ##TODO refine the code + output_dir, + model=self.model, + layer_config=self.layer_config, + inplace=inplace, + bits=self.bits, + group_size=self.group_size, + sym=self.sym, + iters=self.iters, + lr=self.lr, + minmax_lr=self.minmax_lr, + enable_minmax_tuning=self.enable_minmax_tuning, + enable_quanted_input=self.enable_quanted_input, + scale_dtype=self.scale_dtype, + tokenizer=self.tokenizer, + supported_types=self.supported_types, + data_type=self.data_type, + serialization_dict=serialization_dict, + backend=backend, + to_quant_block_names=self.to_quant_block_names, + quant_block_list=self.quant_block_list, + **kwargs + ) + return compressed_model + + def get_quantized_layer_names_outside_blocks(self): + """Gets the names of quantized layers outside blocks in the model. + + Returns: + list: List of layer names outside blocks. + """ + if self.layer_config is None or len(self.layer_config) == 0: + return [] + + layer_names = [] + all_layers_in_block = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list) + + for key in self.layer_config.keys(): + if key in all_layers_in_block: + continue + layer = get_module(self.model, key) + if layer is None: + logger.error(f"could not find layer {key} in the model, exit...") + exit(-1) + if isinstance(layer, tuple(self.supported_types)) and check_to_quantized(self.layer_config[key]): + layer_names.append(key) + + return layer_names + + def set_amp_dtype(self): + self.amp_dtype = torch.float16 + if self.model.dtype != torch.float32: + self.amp_dtype = self.model.dtype + if self.device == "cpu" or "hpu" in self.device: + self.amp_dtype = torch.bfloat16 + if self.amp: + if self.device == "cpu" and not CpuInfo().bf16: + self.amp = False + self.amp_dtype = torch.float32 + self.model = self.model.to(torch.float32) + logger.warning( + f"amp is set to FALSE as the current {self.device} device does not support the 'bf16' data type." + ) + else: + self.model = self.model.to(self.amp_dtype) + else: + self.amp_dtype = torch.float32 + self.model = self.model.to(torch.float32) + + def get_optimizer(self, optimizer): + """Returns the specified optimizer. In SignRound, we fix the optimizer. + + Args: + optimizer: The optimizer to be used. + + Returns: + The specified optimizer. + """ + from auto_round.sign_sgd import SignSGD + + return SignSGD + + def get_scaler(self): + """Returns scaler, in SignRound, no need to use scaler.""" + return None + + def scale_loss_and_backward(self, scaler, loss): + """Scales the loss and performs backward pass. + + Args: + scaler: The scaler to be used. + loss: The loss to be scaled. + + Returns: + The scaled loss. + """ + scale_loss = loss * 1000 + scale_loss.backward() + if is_optimum_habana_available(): + htcore.mark_step() + return scale_loss + + def step(self, scaler, optimizer, lr_schedule): + """Performs a step in the optimization process. + + Args: + scaler: The scaler to be used. + optimizer: The optimizer for the step. 
+ lr_schedule: The learning rate schedule. + + Returns: + None + """ + optimizer.step() + # for hpu + if is_optimum_habana_available(): + htcore.mark_step() + optimizer.zero_grad() + lr_schedule.step() + + @classmethod + @torch.no_grad() + def sampling_inputs(cls, input_ids, input_others, indices, seqlen, + batch_dim=0, share_cache_keys=()): + """Samples inputs based on the given indices and sequence length. + + Args: + input_ids: The list of input tensor containing input_ids. + input_others: A dictionary containing other input data. + indices: The indices to sample from the input. + seqlen: The sequence length. + + Returns: + current_input_ids: The sampled input IDs. + current_input_others: The sampled other input data. + """ + current_input_ids = [input_ids[i] for i in indices] + + current_input_ids = torch.cat(current_input_ids, dim=batch_dim) + + current_input_others = {"positional_inputs": input_others["positional_inputs"]} + for key in input_others.keys(): + if "positional_inputs" in key: + continue + if (key not in share_cache_keys or len(indices) == 1) \ + and not isinstance(input_others[key], (str, bool, type(None))): + current_input_others[key] = None + if input_others[key] is not None: + current_input_others[key] = [input_others[key][i] for i in indices] + if len(indices) == 1: + current_input_others[key] = current_input_others[key][0] + else: + try: + current_input_others[key] = torch.cat(current_input_others[key], dim=0) + except TypeError as err: + logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.") + else: + current_input_others[key] = input_others[key] + + return current_input_ids, current_input_others + +class AutoRoundDM(object): + """For more information, please refer to Cheng, Wenhua, et al. "Optimize weight rounding via signed gradient descent + for the quantization of llms." arXiv preprint arXiv:2309.05516 (2023). + + Args: + model: The PyTorch model to be quantized. + tokenizer: An optional tokenizer for processing input data. If none is provided, a dataloader must be supplied. + bits (int): Number of bits for quantization (default is 4). + group_size (int): Size of the quantization group (default is 128). + sym (bool): Whether symmetric quantization is to be used (default is True). + layer_config (dict): Configuration for weight quantization (default is None). + layer_config={ + 'layer1':##layer_name + { + 'data_type': 'int', + 'bits': 4, + 'group_size': 128, + 'sym': True + 'act_data_type': None, + 'act_bits': 16, + 'act_group_size': None, + 'act_sym': None, + + } + ... + } + batch_size (int): Batch size for training (default is 8). + amp (bool): Whether to use automatic mixed precision (default is True). + device: The device to be used for tuning (default is "auto"). + lr_scheduler: The learning rate scheduler to be used. + dataset (str): The default dataset name (default is "partiprompts"). + enable_quanted_input (bool): Whether to use the output of the previous quantized block as + the input for the current block (default is True). + enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True). + lr (float): The learning rate (default is None, will be set to 1.0/iters). + minmax_lr (float): The learning rate for min-max tuning (default is None, it will be set to lr automatically). + low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True). + low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False). + iters (int): Number of iterations (default is 200). 
+        seqlen (int): Data length of the sequence for tuning (default is 2048).
+        nsamples (int): Number of samples (default is 128).
+        sampler (str): The sampling method (default is "rand").
+        seed (int): The random seed (default is 42).
+        nblocks (int): Number of blocks (default is 1).
+        gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
+        not_use_best_mse (bool): Whether to skip tracking the iteration with the best MSE loss and use the
+            final iteration's parameters instead (default is False).
+        dynamic_max_gap (int): The dynamic maximum gap (default is -1).
+        data_type (str): The data type to be used (default is "int").
+        scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
+            have different choices.
+        act_bits (int): Number of bits for activation quantization. Default is 16.
+        act_group_size (int): Group size for activation quantization. Default is None.
+        act_sym (bool): Whether to use symmetric activation quantization. Default is None.
+        act_data_type (str): Specifies the data type for activations.
+            Defaults to None, in which case it inherits the weight data type.
+        act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
+        to_quant_block_names (str|list): A string or list whose elements are list of
+            block's layer names to be quantized.
+        enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning (default is False).
+        enable_torch_compile (bool): Whether to enable torch.compile to optimize quant_block/layer
+            (requires torch>=2.6, default is False).
+        device_map (str|dict): The device map for each block.
+    Returns:
+        The quantized model.
+    """
+    def __init__(
+            self,
+            model: torch.nn.Module,
+            tokenizer,
+            bits: int = 4,
+            group_size: int = 128,
+            sym: bool = True,
+            layer_config: dict = None,
+            batch_size: int = 8,
+            amp: bool = True,
+            device: str = None,
+            lr_scheduler=None,
+            dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "partiprompts",
+            enable_quanted_input: bool = True,
+            enable_minmax_tuning: bool = True,
+            lr: float = None,
+            minmax_lr: float = None,
+            low_gpu_mem_usage: bool = False,
+            low_cpu_mem_usage: bool = False,
+            iters: int = 200,
+            seqlen: int = 2048,
+            nsamples: int = 128,
+            sampler: str = "rand",
+            seed: int = 42,
+            nblocks: int = 1,
+            gradient_accumulate_steps: int = 1,
+            not_use_best_mse: bool = False,
+            dynamic_max_gap: int = -1,
+            data_type: str = "int",
+            scale_dtype: str = "fp16",
+            act_bits: int = 16,
+            act_group_size: int = None,
+            act_sym: bool = None,
+            act_data_type: str = None,
+            act_dynamic: bool = True,
+            to_quant_block_names: Union[str, list] = None,
+            enable_norm_bias_tuning: bool = False,
+            enable_torch_compile: bool = False,
+            device_map: Union[str, dict] = None,
+            super_bits: int = None,
+            super_group_size: int = None,
+            model_kwargs: dict = None,
+            **kwargs,
+    ):
+        self.quantized = False
+        self.model_orig_dtype = model.dtype
+        self.seed = seed
+        set_seed(self.seed)
+        assert not unsupport_meta_device(model), (
+            "AutoRound does not support params on meta device."
+ " Please use more gpus by setting `--device 0,1,2,3` or just use one gpu") + + ## important tuning hype-parameters + self.amp = amp + self.enable_quanted_input = enable_quanted_input + self.enable_minmax_tuning = enable_minmax_tuning + self.nsamples = nsamples + self.bits = bits + self.enable_norm_bias_tuning = enable_norm_bias_tuning + self.group_size = group_size + self.sym = sym + + self.low_gpu_mem_usage = low_gpu_mem_usage + self.low_cpu_mem_usage = low_cpu_mem_usage + self.layer_config = {} if layer_config is None else layer_config + self.seqlen = seqlen + self.batch_size, self.gradient_accumulate_steps = batch_size, gradient_accumulate_steps + self.nblocks = nblocks + self.dataset = dataset + self.iters = iters + if self.iters < 0: + logger.warning("`iters` must be non-negative, reset it to 200") + self.iters = 200 + if self.iters == 0: + self.lr = 5e-3 + else: + self.lr = lr or (1.0 / self.iters) ##must after iter setting + self.minmax_lr = minmax_lr or self.lr + + self.data_type = data_type + tmp_bits = infer_bits_by_data_type(self.data_type) + if tmp_bits<16 and tmp_bits!=bits: + logger.warning( + f"bits set in 'data_type' do not match the specified 'bits' setting. Resetting 'bits' to {tmp_bits}.") + self.bits = tmp_bits + self.supported_types = supported_layer_types + self.model = model.eval() + self.tokenizer = tokenizer + self.device = detect_device(device) + self.scale_dtype = convert_dtype_str2torch(scale_dtype) + self.set_amp_dtype() + self.to_quant_block_names = to_quant_block_names + if not hasattr(self, 'quant_block_list'): + all_blocks = get_block_names(model) + self.quant_block_list = find_matching_blocks(model, all_blocks, self.to_quant_block_names) + self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device + + ##activation + self.act_group_size = act_group_size if not (act_group_size is None) else self.group_size + self.act_bits = act_bits if not (act_bits is None) else self.bits + self.act_sym = act_sym if not (act_sym is None) else self.sym + self.act_dynamic = act_dynamic + self.act_data_type = act_data_type + if self.act_data_type is None: + if data_type in supported_dtypes and self.act_bits <= 16: + self.act_data_type = data_type + logger.info(f"activation adopts {data_type}") + else: + self.act_data_type = "float" + + tmp_act_bits = infer_bits_by_data_type(self.act_data_type) + if tmp_act_bits < 16: + self.act_bits = tmp_act_bits + + self.sampler = sampler + self.not_use_best_mse = not_use_best_mse + self.dynamic_max_gap = dynamic_max_gap + self.lr_scheduler = lr_scheduler + self.optimizer = self.get_optimizer(None) + self.batch_dim = None + self.infer_bs_coeff = 1 + + self.super_bits = super_bits + self.super_group_size = super_group_size + + torch.set_printoptions(precision=3, sci_mode=True) + self.check_configs() + if self.act_bits <= 8 and self.amp_dtype == torch.float16: + logger.warning("force to use bf16 to for quantization tuning when enabling activation quantization") + self.amp_dtype = torch.bfloat16 + self.model = self.model.to(torch.bfloat16) + else: + logger.info(f"using {self.model.dtype} for quantization tuning") + + self.enable_torch_compile = enable_torch_compile + if not self.enable_torch_compile and TORCH_VERSION_AT_LEAST_2_6 and self.act_bits > 8 and not is_debug_mode() \ + and self.low_cpu_mem_usage != True and "fp8" not in self.data_type and "fp8" not in self.act_data_type: + logger.info("'enable_torch_compile' is set to `False` by default. 
" \ + "Enabling it can reduce tuning cost by 20%, but it might throw an exception.") + + if self.act_bits <= 8 and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as activation quantization is enabled") + + if self.low_cpu_mem_usage == True and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled") + + if is_debug_mode() and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as debug mode is enabled") + + if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as fp8 is enabled") + + if is_optimum_habana_available(): + logger.info("Optimum Habana is available, import htcore explicitly.") + import habana_frameworks.torch.core as htcore # pylint: disable=E0401 + import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] + self.device_map = device_map + + self.set_device_map_in_blocks(self.device_map) + + self.is_packing_immediate = False ## whether to pack the layer immediately after tuning + + self.serialization_keys = [ + "bits", + "group_size", + "sym", + "data_type", + "enable_quanted_input", + "enable_minmax_tuning", + "data_type", + "seqlen", + "batch_size", + "scale_dtype", + "lr", + "minmax_lr", + "gradient_accumulate_steps", + "iters", + "amp", + "nsamples", + "low_gpu_mem_usage", + "to_quant_block_names", + "enable_norm_bias_tuning", + "act_bits", + "act_group_size", + "act_sym", + "act_dynamic", + "act_data_type", + "super_bits", + "super_group_size" + ] + + self.has_qlayer_outside_block = self.set_layerwise_config(self.layer_config) ##better place in the end + self.shared_cache_keys = get_shared_keys(self.model) + + def set_device_map_in_blocks(self, device_map): + """Sets the device map for specific blocks in the model. + + Args: + device_map (Union[str, dict]): A mapping of module names to devices. + If provided as a string, it should be in the format + "module_name:device,module_name:device". Devices can be integers + (GPU IDs) or strings (e.g., 'cpu', 'cuda:0'). 
+ """ + if self.device_map is None or len(self.device_map) == 0: + self.device_map = None + if not device_map: + return + if isinstance(device_map, str): + device_map = device_map.replace(" ", "") + infos = device_map.split(",") + device_map_dict = {} + for info in infos: + index = info.find(':') + key = info[:index] + value = info[index + 1:] + device_map_dict[key] = value + device_map = device_map_dict + + names = [n for n, m in self.model.named_modules() if len(list(m.children())) == 0] + + for key, device in device_map.items(): + if isinstance(device, str) and device.isdigit(): + device = int(device) + device = detect_device(device) + try: + module = get_module(self.model, key) + module.tuning_device = device + except: + matching_names = [name for name in names if re.match(key, name)] + if len(matching_names) > 0: + for name in matching_names: + self._set_device_for_matching_module(name, device) + else: + for name in names: + if key in name: + self._set_device_for_matching_module(name, device) + + def _set_device_for_matching_module(self, name, device): + module = get_module(self.model, name) + if hasattr(module, "tuning_device") and module.tuning_device != device: + logger.warning( + f"Multiple devices have been set for layer {name}, keeping original device {module.tuning_device}") + else: + module.tuning_device = device + + def _dq_check(self): + """Reset the default value of super_bits and super_group_size""" + from auto_round.export.export_to_gguf.config import GGUF_CONFIG + if self.data_type.endswith("_dq"): + gguf_config = GGUF_CONFIG[f"gguf:q{self.bits}_k_s"] + self.super_bits = gguf_config["super_bits"] if self.super_bits is None else self.super_bits + self.super_group_size = gguf_config["super_group_size"] \ + if self.super_group_size is None else self.super_group_size + + def check_configs(self): + + """Checks if the configurations are valid. + + Raises: + AssertionError: If any of the configurations are invalid. + """ + assert isinstance(self.model, torch.nn.Module) + assert self.bits > 0, "bits must be positive" + assert self.act_bits > 0, "bits must be positive" + assert self.group_size == -1 or self.group_size >= 1, "only supports positive group_size or -1(per channel)" + assert self.act_group_size == -1 or self.act_group_size >= 1, \ + "only supports positive group_size or -1(per channel)" + assert self.batch_size > 0, "batch size must be positive" + assert self.iters >= 0, "iters must be non-negative" + assert self.seqlen > 0, "seqlen must be positive" + assert self.nblocks > 0, "nblocks must be positive" + assert self.gradient_accumulate_steps > 0, "gradient accumulate step must be positive" + # assert self.tokenizer != None or self.dataloader != None + if self.act_bits <= 8: + logger.warning( + "activation quantization is an experimental feature with limited support and a complex API. 
" + "And please save the quantized model to fake format as real deployment is not supported currently") + + if "mx_fp" in self.data_type: + logger.warning( + "please save the quantized model to fake format " + "as real deployment is not supported for mx_fp datatype currently") + + if "mx_fp" in self.data_type and self.group_size != 32: + logger.warning("mx_fp should only support group_size of 32 in real deployment") + + if self.nsamples < self.gradient_accumulate_steps * self.batch_size: + if self.batch_size > self.nsamples: + logger.warning(f"reset batch_size to {self.nsamples} as nsamples({self.nsamples})" + f" is smaller than batch_size({self.batch_size})") + self.batch_size = self.nsamples + if self.gradient_accumulate_steps > self.nsamples // self.batch_size: + self.gradient_accumulate_steps = self.nsamples // self.batch_size + logger.warning( + f"reset gradient_accumulate_steps to {self.gradient_accumulate_steps}" + f" as nsamples must equal or greater" + f" than gradient_accumulate_steps * batch_size") + self._dq_check() + + # def _check_format_compatibility(self, format): ##TODO + # ##check lm_head, mixed_bits, bits, each layer supporting, etc + # pass + + def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "auto_round", inplace=True, **kwargs): + """Quantizes the model and saves it in the specified format(s). + + This function checks the validity of the requested format(s), quantizes + the model accordingly, and saves it to the specified output directory. + If multiple formats are provided, the model is saved separately for each format. + + Args: + output_dir (str, optional): The directory where the quantized model + will be saved. Defaults to "tmp_autoround". + format (str, optional): The quantization format(s) to use, separated + by commas if multiple. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place if only + one format is used. Defaults to True. + **kwargs: Additional arguments for the quantization and saving process. + + Returns: + model: A qdq model or packed model based on the configurations + folders: The folder paths where the quantized models are saved. + + Raises: + ValueError: If an unsupported format is specified. + """ + # Validate and process the specified formats + formats = format.replace(' ', '').split(',') + from auto_round.utils import supported_formats + for format_ in formats: + if format_ not in supported_formats: + logger.error(f"Unsupported format {format_}, please choose from {supported_formats}") + exit(-1) + + # only support to export afp8 + if self.act_bits <= 8: + if "fp8" not in self.act_data_type: + if len(formats) > 1 or "fake" not in formats: + logger.warning( + f"Currently only support to export auto_round format quantized model" + " with fp8 dtype activation for activation quantization." + " Change format to fake and save." 
+ ) + formats = ["fake"] + else: + if len(formats) > 1 or "auto_round" not in formats: + logger.warning( + f"Currently only support to export auto_round format for W{self.bits}AFP8 model," + " change format to auto_round" + ) + formats = ["auto_round"] + + # If multiple formats are specified, enforce inplace=False + if len(formats) > 1: + inplace = False + inplace = kwargs.get("inplace", inplace) + kwargs.pop("inplace", None) + + # Determine if immediate packing is required + if (len(formats) == 1 and + ("awq" in formats[0] or "gptq" in formats[0] or "auto_round" in formats[0]) and + not self.has_qlayer_outside_block and inplace): # TODO: Support more formats + self.is_packing_immediate = True + + # Adjust format settings based on compatibility + for index in range(len(formats)): + format = formats[index] + if "auto_round" in format: + if (self.sym and ("gptq" not in format and "awq" not in format)) or self.bits == 3: + format = format.replace('auto_round', 'auto_round:auto_gptq') + formats[index] = format + + # Remove duplicates from formats list + def remove_duplicates(lst): + seen = set() + return [x for x in lst if not (x in seen or seen.add(x))] + + formats = remove_duplicates(formats) + self.formats = formats + + # # Check format compatibility + # self._check_format_compatibility(formats) + + # Perform model quantization + model, _ = self.quantize() + + # Save the quantized model in the specified formats + folders = [] + for format in formats: + if "gptq" in format and not self.sym: + logger.warning( + "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop," + " particularly for 2-bit quantization and smaller models." + " We recommend exporting to either the AutoAWQ format ( only 4 bits) or " + "the AutoRound format(2/4/8 bits)." + ) + save_format_ = format.replace(":", "-").replace("_", "-") + save_folder = os.path.join(output_dir, save_format_) if len(formats) > 1 else output_dir + self.save_quantized(save_folder, format=format, inplace=inplace, **kwargs) + + folders.append(save_folder) + + return model, folders + + @torch.inference_mode + def quantize_rtn(self): + if self.amp: + self.model.to(self.amp_dtype) + self.model.to("cpu") + all_to_quantized_module_names = [] + for n, m in self.model.named_modules(): + if check_to_quantized(m): + all_to_quantized_module_names.append(n) + pbar = tqdm(all_to_quantized_module_names) + + for name in pbar: + pbar.set_description(f"Quantizing {name}") + m = get_module(self.model, name) + + m.to(self.device) + m = WrapperLinear(m, enable_minmax_tuning=False, enable_norm_bias_tuning=False, enable_round_tuning=False) + m = m.unwrapper({}) + m.to("cpu") + if self.low_gpu_mem_usage: + clear_memory() + if self.is_packing_immediate: + from auto_round.export import PACKING_LAYER_WITH_FORMAT + if check_to_quantized(m): + target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0] + PACKING_LAYER_WITH_FORMAT[target_backend](name, self.model, self.formats[0]) + if self.low_gpu_mem_usage: + clear_memory() + else: + set_module(self.model, name, m) + + self.quantized = True + return self.model, self.layer_config + + def quantize(self): + """Quantize the model and return the quantized model along with layer configurations. + the entry of AutoRound. + + Returns: + The quantized model and layer configurations. 
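+
+        Example (an illustrative sketch; it assumes `model` and `tokenizer` have already been loaded
+        elsewhere and that a calibration dataset is reachable):
+
+            >>> rounder = AutoRoundDM(model, tokenizer, bits=4, group_size=128, iters=200)
+            >>> quantized_model, layer_config = rounder.quantize()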
+ """ + if self.iters == 0: + return self.quantize_rtn() + + if bool(self.quant_block_list): + all_blocks = self.quant_block_list + else: + all_blocks = get_block_names(self.model) + + if len(all_blocks) == 0: + logger.warning("could not find blocks, exit with original model") + return self.model, self.layer_config + + if self.amp: + self.model = self.model.to(self.amp_dtype) + + layer_names = self.get_quantized_layer_names_outside_blocks() + self.start_time = time.time() + all_first_block_names = [block[0] for block in all_blocks] + logger.info("start to cache block inputs") + all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names) + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: + accelerate.hooks.remove_hook_from_submodules(self.model) ##self.model.hf_device_map has not been changed + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + logger.info("caching done") + pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks)) + + for block_names in all_blocks: + inputs = all_inputs[block_names[0]] + all_inputs.pop(block_names[0]) + keys = inputs.keys() + input_id_str = [key for key in keys if key.startswith('hidden_state')] + if len(input_id_str) != 1: + raise RuntimeError(f"hidden_states arg mismatch error," + "please raise an issue in https://github.com/intel/auto-round/issues") + inputs["input_ids"] = inputs.pop(input_id_str[0], None) + clear_memory(self.inputs) + + if "input_ids" in inputs.keys(): + total_samples = len(inputs["input_ids"]) + if total_samples < self.batch_size: + self.batch_size = total_samples + logger.warning(f"force the train batch size to {total_samples}") + + self.quant_blocks( + self.model, + inputs, + block_names, + nblocks=self.nblocks, + device=self.device, + pbar=pbar + ) + if self.is_packing_immediate: + assert len(self.formats) == 1 + + self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately + + end_time = time.time() + cost_time = end_time - self.start_time + logger.info(f"quantization tuning time {cost_time}") + + ## dump a summary + quantized_layers = [] + unquantized_layers = [] + for n, m in self.model.named_modules(): + if isinstance(m, tuple(self.supported_types)): + if check_to_quantized(m): + quantized_layers.append(n) + else: + unquantized_layers.append(n) + elif hasattr(m, "scales") or hasattr(m, "scale"): ##packing_immediately + quantized_layers.append(n) + summary_info = ( + f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model" + ) + if len(unquantized_layers) > 0: + summary_info += f", {unquantized_layers} have not been quantized" + logger.info(summary_info) + + self.quantized = True + return self.model, self.layer_config + + def quant_layers(self, layer_names, layer_inputs): + """Quantizes specified layers based on inputs and configuration. + + Args: + layer_names (list): List of layer names to quantize. + layer_inputs (dict): Dictionary mapping layer names to input data. 
+ + Returns: + None + """ + ##TODO currently we take all the layers outside blocks as post block layers which is not optimal + if len(layer_names) == 0: + return + q_layer_inputs = None + enable_quanted_input = self.enable_quanted_input + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1 and enable_quanted_input: + from accelerate.big_modeling import dispatch_model + + dispatch_model(self.model, self.model.hf_device_map) + + if enable_quanted_input: + q_layer_inputs = self.try_cache_inter_data_gpucpu([], self.nsamples, layer_names=layer_names) + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: + accelerate.hooks.remove_hook_from_submodules( + self.model) ##self.model.hf_device_map has not been changed + + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + clear_memory() + if self.enable_torch_compile: + quant_layer = compile_func(self.quant_layer, self.device) + else: + quant_layer = self.quant_layer + for layer_name in layer_names: + layer_input = layer_inputs[layer_name] + layer_input = to_device(layer_input, self.cache_device) + q_layer_input = q_layer_inputs[layer_name] if enable_quanted_input else None + q_layer_input = to_device(q_layer_input, self.cache_device) + quant_layer(layer_name, layer_input, q_layer_input, device=self.device) + del layer_input + clear_memory(q_layer_input) + + def set_layerwise_config(self, layer_config): + """ + Sets the layer-wise configuration based on the provided `layer_config`. + By default, only quantize layers in blocks. + + Args: + layer_config (dict): The configuration dictionary for each layer containing various configuration options. + + Returns: + bool: Returns True if there are quantized layers outside the blocks (e.g., lm-head), + otherwise returns False. 
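+
+        Example (an illustrative sketch; the regex key below is hypothetical and is expanded to every
+        matching supported layer before the missing per-layer defaults are filled in):
+
+            >>> rounder.layer_config = {r".*to_q$": {"bits": 8, "group_size": 128}}
+            >>> has_qlayer_outside_block = rounder.set_layerwise_config(rounder.layer_config)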
+ """ + # Get the names of layers in quantization blocks + layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list) + + ##process regex in layer_config + all_supported_layer_names = [] + # List of configuration keys + keys = self.serialization_keys + + for n, m in self.model.named_modules(): + # Delete previous configuration to avoid conflicts with prior tuning + for key in keys: + if hasattr(m, key): + delattr(m, key) + + # Skip unsupported types + if not isinstance(m, tuple(self.supported_types)): + continue + all_supported_layer_names.append(n) + + names_in_layer_config = list(layer_config.keys()) + for name in names_in_layer_config: + if name in all_supported_layer_names: + continue + matched_names = [] + for layer_name in all_supported_layer_names: + if re.search(re.compile(name), layer_name) is not None: + matched_names.append(layer_name) + if len(matched_names) > 0: + val = layer_config[name] + layer_config.pop(name) + for match_name in matched_names: + layer_config[match_name] = val + else: + raise ValueError(f"key {name} in layer_config is invalid, please have a double check") + + has_qlayer_outside_block = False # Flag to track if there are quantized layers outside blocks (e.g., lm-head) + + # Iterate through all modules in the model + for n, m in self.model.named_modules(): + + # Skip unsupported types + if not isinstance(m, tuple(self.supported_types)): + continue + + # If the layer is not in the config and is part of a quantization block, use default configuration + if n not in layer_config.keys() and n in layers_in_blocks: + layer_config[n] = {} + for key in keys: + layer_config[n][key] = getattr(self, key) + # If the layer is partially configured, fill in missing values + elif n in layer_config.keys(): + for key in keys: + if key not in layer_config[n].keys(): + layer_config[n][key] = getattr(self, key) + # If the layer is not in the config and not part of a quantization block, + # use default configuration and set specific values + else: + layer_config[n] = {} + for key in keys: + layer_config[n][key] = getattr(self, key) + layer_config[n]["bits"] = 16 + layer_config[n]["act_bits"] = 16 + + if n in layers_in_blocks: + layer_config[n]["in_blocks"] = True + else: + layer_config[n]["in_blocks"] = False + + # If the layer is outside a block and requires quantization, mark it as a quantized layer outside the block + if n not in layers_in_blocks and check_to_quantized(layer_config[n]): + has_qlayer_outside_block = True + + in_features, out_features = get_layer_features(m) + if in_features <= layer_config[n]["group_size"]: + layer_config[n]["group_size"] = -1 + + # Apply the configuration to the corresponding layer in the model + for key in keys: + setattr(m, key, layer_config[n][key]) + + # Return whether there are quantized layers outside the blocks + return has_qlayer_outside_block + + @torch.no_grad() + def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_device, save_output=True): + """Compute the output of a given block of the model for a given input. + + Args: + block: The block of the model. + input_ids: The input tensor containing tokenized input ids. + input_others: A dictionary containing additional input data. + bs: The batch size for computing the output. + device: The device for computation. + cache_device: The device for storing the output. + batch_dim: The batch dimension of the output tensor. + + Returns: + The output tensor of the block. 
+ """ + + output = [] + nsamples = len(input_ids) + for i in range(0, nsamples, bs): + end_index = min(nsamples, i + bs) + indices = torch.arange(i, end_index).to(torch.long) + tmp_input_ids, tmp_input_others = AutoRound.sampling_inputs( + input_ids, + input_others, + indices, + self.seqlen, + self.batch_dim, + share_cache_keys=self.shared_cache_keys + ) + tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to( + cache_device + ) + if save_output: + if self.batch_size == 1: + output.append(tmp_output) + else: + output.extend(list(torch.split(tmp_output, 1, dim=self.batch_dim))) + if self.low_gpu_mem_usage: + clear_memory() + + return output + + @torch.no_grad() + def calib(self, nsamples, bs): + """Perform calibration for quantization. + + This method calibrates the model for quantization by processing a specified + number of samples from the calibration dataset. It ensures that the data is + properly formatted and feeds it to the model. If the number of samples processed + is less than the specified number, it logs a warning. If no samples are processed, + it logs an error and exits. + Args: + nsamples (int): The number of samples to use for calibration. + bs (int): The number of samples to use for calibration + """ + from .calib_dataset import get_dataloader + if isinstance(self.dataset, str): + dataset = self.dataset.replace(" ", "") ##remove all whitespaces + + # slow here + self.dataloader = get_dataloader( + self.tokenizer, + self.seqlen, + dataset, + self.seed, + bs, + self.nsamples, + ) + else: + self.dataloader = self.dataset + total_cnt = 0 + + # load embed weight if use low_cpu_mem_usage + if self.low_cpu_mem_usage: + embed_layers = get_layers_before_block(self.model) + for n, m in embed_layers: + m = m.to(self.device) + + for data in self.dataloader: + if data is None: + continue + if isinstance(data, torch.Tensor): + input_ids = data.to(self.device) + data_new = input_ids + elif isinstance(data, str): + if self.tokenizer is None: + logger.error("please provide tokenizer for string input") + exit(-1) + data = self.tokenizer(data, truncation=True, max_length=self.seqlen, return_tensors="pt").data + data_new = {} + for key in data.keys(): + data_new[key] = data[key].to(self.device) + input_ids = data_new["input_ids"] + elif isinstance(data, tuple) or isinstance(data, list): + data_new = data + input_ids = data_new[0] + else: + data_new = {} + for key in data.keys(): + data_new[key] = to_device(data[key], self.model.device) + if key == 'images': + data_new[key] = to_dtype(data_new[key], self.model.dtype) + input_ids = data_new["input_ids"] + if input_ids.shape[-1] < self.seqlen: + continue + try: + if isinstance(data_new, torch.Tensor): + self.model(data_new) + elif isinstance(data_new, tuple) or isinstance(data_new, list): + self.model(*data_new) + else: + self.model(**data_new) + except NotImplementedError: + pass + except RuntimeError as error: + logger.warning("When quantization encounters tensor" \ + " shape mismatch error, you can try to avoid it with batch_size=1") + logger.error(error) + pass + except Exception as error: + raise error + total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1 + if total_cnt >= nsamples: + break + if total_cnt == 0: + logger.error( + f"no data has been cached, please provide more data with sequence length >={self.seqlen} in the " + f"dataset or decease the sequence length" + ) + exit(-1) + elif total_cnt < nsamples: + logger.warning( + f"An insufficient number of samples likely 
reduces the accuracy of the quantized model." + f"Target samples count is {nsamples}, while valid samples count is {total_cnt}" + ) + + # clean embed weight to save memory + if self.low_cpu_mem_usage: + for n, m in embed_layers: + m = m.to("meta") + + @torch.no_grad() + def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, last_cache_name=None): + """Attempts to cache intermediate data on GPU, if failed, then using CPU. + + Args: + block_names (list): List of block names to cache data for. + nsamples (int): Number of samples to use for caching. + layer_names (list, optional): List of layer names to cache data for. Defaults to []. + last_cache_name (str, optional): Name of the last cache. Defaults to None. + + Returns: + all_inputs: Cached intermediate data. + + Raises: + Exception: If caching on GPU fails, switches to CPU and caches there. + """ + if layer_names is None: + layer_names = [] + try: + if not self.model.device.type == "meta": + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: + pass + else: + self.model = self.model.to(self.device) + all_inputs = self.cache_inter_data( + block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name + ) + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + clear_memory() + except RuntimeError as e: + if "CUDA out of memory" in str(e) or "MODULE:PT_DEVMEM" in str(e): + logger.info("switch to cpu to cache block inputs") + if (("lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] < 16) or + self.__class__.__name__ == "AutoRoundMLLM"): + logger.warning(f"we strongly recommend using additional CUDA/HPU devices,e.g. " + f"set `--device '0,1'` in our cmd line usage or " + f"load the model with `device_mapping=auto`," + f" for optimal performance during calibration " + f"Otherwise, the process may be significantly slower.") + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + clear_memory() + all_inputs = self.cache_inter_data( + block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name + ) + else: + raise + return all_inputs + + @torch.no_grad() + def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_name=None): + """Save the inputs of block_name for calibration. + + This method temporarily replaces the forward method of the model to capture + the inputs passing through the specified block. It then calibrates the model + using a specified number of samples. Finally, it restores the original forward + method and returns the inputs for the specified block. + Args: + block_names (list): The names of the blocks for which inputs are to be saved. + layer_names (list):The names of the layers for which inputs are to be saved. + nsamples (int): The number of samples to use for calibration. + last_cache_name (str, optional): The name of the last layer to be cached, + we could break the forward in this layer to save time + + Returns: + dict: A dictionary containing the inputs for the specified block. 
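+
+        Example (an illustrative sketch; the block name is hypothetical and depends on the model architecture):
+
+            >>> inputs = rounder.cache_inter_data(["model.blocks.0"], nsamples=128, layer_names=[])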
+ """ + if layer_names is None: + layer_names = [] + self.inputs = {} + self.to_cached_layers = block_names + layer_names + tmp_dtype = None + ## have bug if block name is not the first block + if (len(block_names) > 1 or len(layer_names) > 0) and self.low_gpu_mem_usage: + tmp_dtype = self.model.dtype + self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32) ##model on cpu + + self.last_cache_name = last_cache_name + if last_cache_name is None and len(block_names) + len(layer_names) == 1: + self.last_cache_name = block_names[0] if len(block_names) == 1 else layer_names[0] + # do not set last_cache_name for multimodal models + calib_bs = self.batch_size + self.hook_handles = [] + self._replace_forward() + self.calib(nsamples, calib_bs) + self._recover_forward() + res = self.inputs + del self.last_cache_name + del self.to_cached_layers + if tmp_dtype is not None: + self.model = self.model.to(tmp_dtype) + + return res + + @torch.no_grad() + def get_block_forward_func(self, name): + """Gets the forward function. + + Args: + name (str): The name of the function. + Returns: + function: The forward function. + """ + + def post_process_cache_data(batch_size, data, data_name): + """ + Processes store data for batch handling, reshaping if necessary. + + Args: + batch_size (int): The size of the batch. + data: The data value to store, potentially for caching. + data_name (str): Name of the data. + + Returns: + Processed data or None + """ + new_data = data + if batch_size <= 1: + return new_data + if data_name in self.shared_cache_keys: + return None + if "alibi" in data_name: + if isinstance(data, torch.Tensor): + alibi = data + alibi = alibi.reshape(batch_size, -1, alibi.shape[1], alibi.shape[2]) + new_data = alibi + return new_data + + def forward(m, hidden_states=None, *positional_inputs, **kwargs): + """Rewrite forward function, process and collect input data. + + Args: + hidden_states (torch.Tensor): The hidden states tensor. + *positional_inputs: Variable number of positional arguments. + **kwargs: Variable number of keyword arguments. + + Returns: + NotImplementedError: Getting the first layer inputs and then raise the error to save runtime. 
+ """ + if name not in self.inputs: + self.inputs[name] = {} + init_cache(positional_inputs, self.inputs[name]) + + if self.batch_dim is None: + self.batch_dim = 0 + if hidden_states is not None and self.batch_size > 1: + if hidden_states.shape[0] > self.batch_size: + self.batch_dim = 1 + if len(hidden_states.shape) > 1 and hidden_states.shape[1] > self.batch_size: + logger.error( + f"this model has not been supported, " + f"please raise an issue in https://github.com/intel/auto-round/issues" + f" or try to set the `batch_size` to 1 and " + f"`gradient_accumulate_steps` to your current batch size.") + exit(-1) + + if hidden_states is not None: + kwargs['hidden_states'] = hidden_states + + for key in kwargs.keys(): + if isinstance(kwargs[key], torch.Tensor) or isinstance(kwargs[key], list) \ + or isinstance(kwargs[key], tuple): + if key not in self.inputs[name].keys(): # initialization + data = to_device(kwargs[key], device=torch.device("cpu")) + if data is None or (self.batch_size > 1 and key in self.shared_cache_keys): + self.inputs[name][key] = data + continue + if self.batch_size <= 1: + self.inputs[name][key] = [data] + else: + data = post_process_cache_data(self.batch_size, data, key) + self.inputs[name][key] = list(torch.split(data, 1, dim=self.batch_dim)) + else: # append cache inputs + new_data = post_process_cache_data(self.batch_size, kwargs[key], key) + if new_data is None: # shareable args or NoneType + continue + new_data = to_device(new_data, device=torch.device("cpu")) + if self.batch_size <= 1: + self.inputs[name][key].append(new_data) + else: + self.inputs[name][key].extend(list(torch.split(new_data, 1, dim=self.batch_dim))) + elif isinstance(kwargs[key], (str, bool, type(None))): + if key not in self.inputs[name].keys(): + self.inputs[name][key] = kwargs[key] + else: + # Parameters not to be cached + if check_skippable_keywords(key): + logger.warning_once(f"Please note that '{key}' key" \ + " is not currently used in quantization fine-tuning.") + reset_params(self.inputs[name]) + if name == self.last_cache_name: + raise NotImplementedError + else: + if hidden_states is not None: + kwargs.pop('hidden_states') + return m.orig_forward(hidden_states, *positional_inputs, **kwargs) + else: + # Currently only for Llama-3.2-Vision-Instruct Series + return m.orig_forward(*positional_inputs, **kwargs) + + return forward + + @torch.no_grad() + def _get_cache_data_hook_for_layer(self, name): + """A forward hook to save input max of a module + :param name: the module name + :return: A hook function.""" + + def cache_input_hook(module, inputs, outputs): + input = inputs + if isinstance(inputs, tuple) or isinstance(input, list): + input = inputs[0] + if name in self.inputs: + self.inputs[name].extend(list(torch.split(input.to("cpu"), 1, dim=0))) + else: + self.inputs[name] = list(torch.split(input.to("cpu"), 1, dim=0)) + + return cache_input_hook + + def _recover_forward(self): + """Recovers the forward function.""" + for n, m in self.model.named_modules(): + if hasattr(m, "orig_forward"): + m.forward = m.orig_forward + delattr(m, "orig_forward") + for hook_handle in self.hook_handles: + hook_handle.remove() + self.hook_handles = [] + + def _replace_forward(self): + """Replaces the forward function.""" + from functools import partial + + for n, m in self.model.named_modules(): + if n in self.to_cached_layers and not isinstance(m, tuple(self.supported_types)): ##block + m.orig_forward = m.forward + m.forward = partial(self.get_block_forward_func(n), m) + elif n in self.to_cached_layers: 
##linear layer or conv1d layer + hook_func = self._get_cache_data_hook_for_layer(n) + hook_handle = m.register_forward_hook(hook_func) + self.hook_handles.append(hook_handle) + + def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cpu")): + """Quantize a specific layer of the model using the provided inputs. + + Args: + layer_name (str): The name of the layer to quantize. + inputs (torch.Tensor): Input data for quantization. + q_inputs (torch.Tensor, optional): Quantized input data. Defaults to None. + device (torch.device, optional): The device to use for quantization. Defaults to torch.device("cpu"). + + Returns: + None + """ + logger.info(f"quantizing layer {layer_name}") + layer = get_module(self.model, layer_name) + if hasattr(layer, "tuning_device"): + device = layer.tuning_device + + layer = layer.to(device) + for i in range(len(inputs)): + inputs[i] = inputs[i].to(layer.weight.dtype) + if q_inputs is not None: + q_inputs[i] = q_inputs[i].to(layer.weight.dtype) + + wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to( + device) + round_params = [] + minmax_params = [] + for key in wrapper_linear.params.keys(): + if "min" in key or "max" in key: + minmax_params.append(wrapper_linear.params[key]) + else: + round_params.append(wrapper_linear.value) + if self.enable_minmax_tuning: + optimizer = self.optimizer( + [{"params": round_params}, {"params": minmax_params, "lr": self.minmax_lr}], lr=self.lr, weight_decay=0 + ) + else: + optimizer = self.optimizer(round_params, lr=self.lr, weight_decay=0) + + if self.lr_scheduler is None: + lr_schedule = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.iters + ) + else: + lr_schedule = copy.deepcopy(self.lr_scheduler) + nsamples = len(inputs) + last_best_iter = 0 + best_loss = torch.finfo(torch.float).max + mse_loss = torch.nn.MSELoss().to(device) + scaler = self.get_scaler() # pylint: disable=assignment-from-none + init_loss = None + # best_v, best_min_scale, best_max_scale = torch.tensor(0), torch.tensor(1.0), torch.tensor(1.0) + gradient_accumulate_steps = self.batch_size ##Force to low gpu + batch_size = 1 ##Force to low gpu + pick_samples = batch_size * gradient_accumulate_steps + pick_samples = min(nsamples, pick_samples) + if self.sampler != "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + total_loss = 0 + num_elm = 1 + mse_reduction = "mean" + if gradient_accumulate_steps != 1: + mse_reduction = "sum" + mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device) + + for i in range(self.iters): + total_loss = 0 + if self.sampler == "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + if gradient_accumulate_steps != 1: + if q_inputs is not None: + current_input = [q_inputs[i] for i in whole_indices] + else: + current_input = [inputs[i] for i in whole_indices] + num_elm = sum(id.numel() for id in current_input) + for tmp_step in range(gradient_accumulate_steps): + indices = whole_indices[tmp_step * batch_size: (tmp_step + 1) * batch_size] + if q_inputs is not None: + current_input = [q_inputs[i] for i in indices] + current_input = torch.cat(current_input, dim=0).to(device) + org_input = [inputs[i] for i in indices] + org_input = torch.cat(org_input, dim=0).to(device) + else: + current_input = [inputs[i] for i in indices] + current_input = torch.cat(current_input, dim=0).to(device) + org_input = current_input + with torch.no_grad(): + current_output = layer(org_input) + + if 
self.amp: + with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): + output_q = wrapper_linear(current_input) # pylint: disable=not-callable + loss = mse_loss(output_q, current_output) # pylint: disable=not-callable + else: + output_q = wrapper_linear(current_input) # pylint: disable=not-callable + loss = mse_loss( # pylint: disable=not-callable + output_q.to(torch.float32), current_output.to(torch.float32) + ) + total_loss += loss.item() / num_elm + + self.scale_loss_and_backward(scaler, loss) + if i == 0: + init_loss = total_loss + + if total_loss < best_loss: + best_loss = total_loss + if not self.not_use_best_mse: + best_params = collect_best_params(wrapper_linear) + last_best_iter = i + if self.not_use_best_mse and i == self.iters - 1: + best_params = collect_best_params(wrapper_linear) + + if not self.not_use_best_mse: + if 0 < self.dynamic_max_gap <= i - last_best_iter: + break + self.step(scaler, optimizer, lr_schedule) + + last_loss = total_loss + best_iter = self.iters + if not self.not_use_best_mse: + last_loss = best_loss + best_iter = last_best_iter + with torch.no_grad(): + unwrapper_layer(self.model, wrapper_linear, layer_name, best_params) + mv_module_from_gpu(layer, self.low_cpu_mem_usage) + dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" + logger.info(dump_info) + + def register_act_max_hook(self, model): + def get_act_max_hook(module, input, output): + if isinstance(input, (tuple, list)): + input = input[0] + if not hasattr(module, "act_max"): + module.act_max = torch.abs(input).max().item() + else: + module.act_max = max(torch.abs(input).max().item(), module.act_max) + + hook_handles = [] + + for n, m in model.named_modules(): + if hasattr(m, "act_dynamic") and m.act_dynamic == False and check_to_quantized(m): + hook = m.register_forward_hook(get_act_max_hook) + hook_handles.append(hook) + return hook_handles + + def quant_block(self, block, input_ids, input_others, q_input=None, device=torch.device("cpu")): + """Quantize the weights of a given block of the model. + + Args: + block: The block of the model to be quantized. + input_ids: The input tensor containing tokenized input ids. + input_others: A dictionary containing additional input data. + q_input: The quantized input tensor. + device: The device for quantization. 
+ + Returns: + Tuple: (q_outputs, output) if self.enable_quanted_input is True, else (None, output) + """ + if self.device_map is not None: + from accelerate import dispatch_model + for n, m in block.named_modules(): + if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): + continue + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + hook = AlignDevicesHook(m.tuning_device, io_same_device=True) + add_hook_to_module(m, hook, True) + + if q_input is None: + hook_handles = self.register_act_max_hook(block) + + output = self.get_block_outputs(block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, + device, + self.cache_device) + + for handle in hook_handles: + handle.remove() + else: + output = self.get_block_outputs(block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, + device, + self.cache_device) + hook_handles = self.register_act_max_hook(block) + self.get_block_outputs(block, q_input, input_others, self.batch_size * self.infer_bs_coeff, + device, self.cache_device, save_output=False) + + for handle in hook_handles: + handle.remove() + + if q_input is not None: + if input_ids is not q_input: + clear_memory(input_ids) + else: + clear_memory() + input_ids = q_input + + quantized_layer_names, unquantized_layer_names = wrapper_block( + block, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device) + + round_params = [] + minmax_params = [] + for n, m in block.named_modules(): + if hasattr(m, "orig_layer"): + for key in m.params.keys(): + if "min" in key or "max" in key: + minmax_params.append(m.params[key]) + else: + round_params.append(m.params[key]) + + if self.enable_minmax_tuning: + optimizer = self.optimizer( + [{"params": round_params}, {"params": minmax_params, "lr": self.minmax_lr}], lr=self.lr, weight_decay=0 + ) + else: + optimizer = self.optimizer(round_params, lr=self.lr, weight_decay=0) + + if len(round_params) + len(minmax_params) <= 0: + dump_info = ( + f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " + f"layers in the block" + ) + logger.info(dump_info) + return output, output + + if self.lr_scheduler is None: + lr_schedule = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.iters + ) + else: + lr_schedule = copy.deepcopy(self.lr_scheduler) + + nsamples = len(input_ids) + pick_samples = self.batch_size * self.gradient_accumulate_steps + pick_samples = min(nsamples, pick_samples) + if self.sampler != "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + last_best_iter = 0 + best_loss = torch.finfo(torch.float).max + num_elm = 1 + mse_reduction = "mean" + if self.gradient_accumulate_steps != 1: + mse_reduction = "sum" + mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device) + scaler = self.get_scaler() # pylint: disable=assignment-from-none + init_loss = None + best_params = {} + total_loss = 0 + + for i in range(self.iters): + total_loss = 0 + if self.sampler == "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + ##we assume the block input and output shape is same + if self.gradient_accumulate_steps != 1: + current_input_ids = [input_ids[i] for i in whole_indices] + num_elm = sum(id.numel() for id in current_input_ids) + for tmp_step in range(self.gradient_accumulate_steps): + indices = whole_indices[tmp_step * self.batch_size: (tmp_step + 1) * self.batch_size] + current_input_ids, current_input_others = AutoRound.sampling_inputs( + input_ids, + 
input_others, + indices, + seqlen=self.seqlen, + batch_dim=self.batch_dim, + share_cache_keys=self.shared_cache_keys + ) + + current_output = [output[x] for x in indices] + current_output = torch.cat(current_output, dim=self.batch_dim) + + current_output = to_device(current_output, device) + + output_q = block_forward( + block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device + ) + if self.amp: + with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): + loss = mse_loss(output_q, current_output) # pylint: disable=not-callable + else: + loss = mse_loss( # pylint: disable=not-callable + output_q.to(torch.float32), current_output.to(torch.float32) + ) + + total_loss += loss.item() / num_elm + self.scale_loss_and_backward(scaler, loss) + + if i == 0: + init_loss = total_loss + + if total_loss < best_loss: + best_loss = total_loss + if not self.not_use_best_mse: + best_params = collect_best_params(block) + # print(f"get better result at iter {i}, the loss is {total_loss}", flush=True) + + last_best_iter = i + if self.not_use_best_mse and i == self.iters - 1: + best_params = collect_best_params(block) + + if not self.not_use_best_mse: + if 0 < self.dynamic_max_gap <= i - last_best_iter: + break + self.step(scaler, optimizer, lr_schedule) + + last_loss = total_loss + best_iter = self.iters + if not self.not_use_best_mse: + last_loss = best_loss + best_iter = last_best_iter + dump_info = ( + f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " + f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" + ) + logger.info(dump_info) + if len(unquantized_layer_names) != 0: + logger.info(f"{unquantized_layer_names} have not been quantized") + with torch.no_grad(): + unwrapper_block(block, best_params) + if self.enable_quanted_input: + if self.low_cpu_mem_usage: + block = block.to(device) + clear_memory() + q_outputs = self.get_block_outputs( + block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, + cache_device=self.cache_device + ) + if self.device_map is not None: + accelerate.hooks.remove_hook_from_submodules( + block) + mv_module_from_gpu(block, self.low_cpu_mem_usage) + clear_memory(input_ids) + + return q_outputs, output + + else: + if self.device_map is not None: + accelerate.hooks.remove_hook_from_submodules( + block) + mv_module_from_gpu(block, self.low_cpu_mem_usage) + clear_memory(input_ids) + return None, output + + def quant_blocks( + self, + model: torch.nn.Module, + inputs, + block_names, + nblocks=1, + device="cpu", + pbar=None + ): + """Quantize and dequantize the weights of the specified blocks in the model. + + Args: + model: The PyTorch model to be quantized. + inputs: The input data for quantization. + block_names: The names of the blocks to be quantized and dequantized. + nblocks: The number of blocks to quantize and dequantize. + device: The device for quantization and dequantization. 
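+
+            Example (illustrative sketch; `autoround` and `cached_inputs` are assumptions, e.g. produced
+                by a prior call to try_cache_inter_data_gpucpu, and are not defined here):
+
+                >>> first_group = get_block_names(model)[0]
+                >>> # cached_inputs must contain "input_ids" plus the other cached forward kwargs
+                >>> autoround.quant_blocks(model, cached_inputs, first_group, nblocks=1, device="cuda:0")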
+ + Returns: + None + """ + q_input = None + clear_memory() + for n, m in model.named_parameters(): + m.requires_grad_(False) + input_ids = inputs["input_ids"] + inputs.pop("input_ids", None) + input_others = inputs + clear_memory() + input_ids = to_device(input_ids, self.cache_device) + input_others = to_device(input_others, self.cache_device) + ## as in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage + tmp_dtype = self.amp_dtype if self.amp else torch.float32 + for i in range(len(input_ids)): + input_ids[i] = input_ids[i].to(tmp_dtype) + + for key in input_others.keys(): + if isinstance(input_others[key], torch.Tensor) and ( + input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16 + ): + input_others[key] = input_others[key].to(tmp_dtype) + elif isinstance(input_others[key], list): + for i in range(len(input_others[key])): + to_dtype(input_others[key][i], tmp_dtype) + if self.enable_torch_compile: + quant_block = compile_func(self.quant_block, device) + else: + quant_block = self.quant_block + + if pbar is None: + pbar = tqdm(range(0, len(block_names), nblocks)) + + for n, m in self.model.named_modules(): + if isinstance(m, tuple(self.supported_types)): + m.name = n + + for i in range(0, len(block_names), nblocks): + if i != 0: + pbar.update(1) + if nblocks == 1: + n = block_names[i] + pbar.set_description(f"Quantizing {n}") + m = get_module(model, n) + else: + names = block_names[i: min(i + nblocks, len(block_names))] + pbar.set_description(f"Quantizing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}") + modules = [get_module(model, n) for n in names] + m = WrapperMultiblock(modules) + + if not self.model.device.type == "meta" or self.low_cpu_mem_usage: + m = m.to(device) + + q_input, input_ids = quant_block( + m, + input_ids, + input_others, + q_input=q_input, + device=device, + ) + if self.is_packing_immediate: + from auto_round.export import PACKING_LAYER_WITH_FORMAT + for _, tmp_m in m.named_modules(): + if hasattr(tmp_m, "bits") and check_to_quantized(tmp_m): + target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0] + PACKING_LAYER_WITH_FORMAT[target_backend](tmp_m.name, self.model, self.formats[0]) + pbar.set_description(f"Quantizing done") + pbar.update(1) + pbar.close() + + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + for n, m in self.model.named_modules(): + if hasattr(m, "name"): + delattr(m, "name") + + del q_input + del input_ids + del input_others + del inputs + + clear_memory() + + def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs): + """Save the quantized model to the specified output directory in the specified format. + + Args: + output_dir (str, optional): The directory to save the quantized model. Defaults to None. + format (str, optional): The format in which to save the model. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place. Defaults to True. + **kwargs: Additional keyword arguments specific to the export format. + + Returns: + object: The compressed model object. + """ + # only support to export afp8 + if self.act_bits <= 8: + if "fp8" not in self.act_data_type or self.act_dynamic: + if format != "fake": + logger.warning( + f"Currently only support to export auto_round format quantized model" + " with fp8 dtype activation for activation quantization." + " Change format to fake and save." 
+ ) + format = "fake" + else: + if format != "auto_round": + logger.warning( + f"Currently only support to export auto_round format for static W{self.bits}AFP8 model," + " change format to auto_round" + ) + format = "auto_round" + + if re.search("q\d_k", format) and not self.data_type.endswith("_dq"): + logger.error( + f"datatype<{self.data_type}> not support to export {format} format." + " Please change export format or data_type." + ) + sys.exit(-1) + + if self.low_cpu_mem_usage: + self.model = self.model.to('cpu') + + if not self.quantized: + logger.warning("please run autoround.quantize first") + return + if format == "fake" or format == "qdq": ##TODO fix act quantizaiton later + self.model = self.model.to("cpu") + self.model.save_pretrained(output_dir) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + processor = kwargs.get("processor", None) + if processor is not None: + processor.save_pretrained(output_dir) + return + if self.act_bits <= 8 and format == "qdq": + logger.warning( + "Support for exporting activation quantization is limited. " + "Please ensure that your configuration is supported.") + if format in ["gguf:q4_0", "gguf:q4_1"]: + if self.group_size != 32: + logger.error(f"{format} need group_size=32, but it is {self.group_size}, cannot export.") + return + if format == "gguf:q4_0" and not self.sym: + logger.warning(f"incorrect format choose, will reset to gguf:q4_1") + if format == "gguf:q4_1" and self.sym: + logger.warning(f"incorrect format choose, will reset to gguf:q4_0") + + from auto_round.export import EXPORT_FORMAT + backend = format + format = format.split(":")[0] + if format not in EXPORT_FORMAT: + logger.error(f"export format only supports {EXPORT_FORMAT.keys()}") + raise ValueError(f"export format only supports {EXPORT_FORMAT.keys()}, but got {format}") + save_quantized_as_format = EXPORT_FORMAT.get(format) + if "gptq" in format and not self.sym: + logger.warning( + "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop," + " particularly for 2-bit quantization and smaller models." + " We recommend exporting to either the AutoAWQ format ( only 4 bits) or " + "the AutoRound format(2/4/8 bits)." 
+ ) + if "awq" in format and not self.bits == 4: + raise ValueError("The AWQ format only supports W4 quantization ") + + if isinstance(self.dataset, str): + self.serialization_keys.append("dataset") + serialization_dict = {} + for key in self.serialization_keys: + serialization_dict[key] = getattr(self, key) + from .version import __version__ + + serialization_dict["autoround_version"] = __version__ + if "scale_dtype" in serialization_dict.keys(): + serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"]) + + compressed_model = save_quantized_as_format( ##TODO refine the code + output_dir, + model=self.model, + layer_config=self.layer_config, + inplace=inplace, + bits=self.bits, + group_size=self.group_size, + sym=self.sym, + iters=self.iters, + lr=self.lr, + minmax_lr=self.minmax_lr, + enable_minmax_tuning=self.enable_minmax_tuning, + enable_quanted_input=self.enable_quanted_input, + scale_dtype=self.scale_dtype, + tokenizer=self.tokenizer, + supported_types=self.supported_types, + data_type=self.data_type, + serialization_dict=serialization_dict, + backend=backend, + to_quant_block_names=self.to_quant_block_names, + quant_block_list=self.quant_block_list, + **kwargs + ) + return compressed_model + + def get_quantized_layer_names_outside_blocks(self): + """Gets the names of quantized layers outside blocks in the model. + + Returns: + list: List of layer names outside blocks. + """ + if self.layer_config is None or len(self.layer_config) == 0: + return [] + + layer_names = [] + all_layers_in_block = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list) + + for key in self.layer_config.keys(): + if key in all_layers_in_block: + continue + layer = get_module(self.model, key) + if layer is None: + logger.error(f"could not find layer {key} in the model, exit...") + exit(-1) + if isinstance(layer, tuple(self.supported_types)) and check_to_quantized(self.layer_config[key]): + layer_names.append(key) + + return layer_names + + def set_amp_dtype(self): + self.amp_dtype = torch.float16 + if self.model.dtype != torch.float32: + self.amp_dtype = self.model.dtype + if self.device == "cpu" or "hpu" in self.device: + self.amp_dtype = torch.bfloat16 + if self.amp: + if self.device == "cpu" and not CpuInfo().bf16: + self.amp = False + self.amp_dtype = torch.float32 + self.model = self.model.to(torch.float32) + logger.warning( + f"amp is set to FALSE as the current {self.device} device does not support the 'bf16' data type." + ) + else: + self.model = self.model.to(self.amp_dtype) + else: + self.amp_dtype = torch.float32 + self.model = self.model.to(torch.float32) + + def get_optimizer(self, optimizer): + """Returns the specified optimizer. In SignRound, we fix the optimizer. + + Args: + optimizer: The optimizer to be used. + + Returns: + The specified optimizer. + """ + from auto_round.sign_sgd import SignSGD + + return SignSGD + + def get_scaler(self): + """Returns scaler, in SignRound, no need to use scaler.""" + return None + + def scale_loss_and_backward(self, scaler, loss): + """Scales the loss and performs backward pass. + + Args: + scaler: The scaler to be used. + loss: The loss to be scaled. + + Returns: + The scaled loss. + """ + scale_loss = loss * 1000 + scale_loss.backward() + if is_optimum_habana_available(): + htcore.mark_step() + return scale_loss + + def step(self, scaler, optimizer, lr_schedule): + """Performs a step in the optimization process. + + Args: + scaler: The scaler to be used. + optimizer: The optimizer for the step. 
+ lr_schedule: The learning rate schedule. + + Returns: + None + """ + optimizer.step() + # for hpu + if is_optimum_habana_available(): + htcore.mark_step() + optimizer.zero_grad() + lr_schedule.step() + + @classmethod + @torch.no_grad() + def sampling_inputs(cls, input_ids, input_others, indices, seqlen, + batch_dim=0, share_cache_keys=()): + """Samples inputs based on the given indices and sequence length. + + Args: + input_ids: The list of input tensor containing input_ids. + input_others: A dictionary containing other input data. + indices: The indices to sample from the input. + seqlen: The sequence length. + + Returns: + current_input_ids: The sampled input IDs. + current_input_others: The sampled other input data. + """ + current_input_ids = [input_ids[i] for i in indices] + + current_input_ids = torch.cat(current_input_ids, dim=batch_dim) + + current_input_others = {"positional_inputs": input_others["positional_inputs"]} + for key in input_others.keys(): + if "positional_inputs" in key: + continue + if (key not in share_cache_keys or len(indices) == 1) \ + and not isinstance(input_others[key], (str, bool, type(None))): + current_input_others[key] = None + if input_others[key] is not None: + current_input_others[key] = [input_others[key][i] for i in indices] + if len(indices) == 1: + current_input_others[key] = current_input_others[key][0] + else: + try: + current_input_others[key] = torch.cat(current_input_others[key], dim=0) + except TypeError as err: + logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.") + else: + current_input_others[key] = input_others[key] + + return current_input_ids, current_input_others + + +class AdaRoundDM(object): + """For more information, please refer to Cheng, Wenhua, et al. "Optimize weight rounding via signed gradient descent + for the quantization of llms." arXiv preprint arXiv:2309.05516 (2023). + + Args: + model: The PyTorch model to be quantized. + tokenizer: An optional tokenizer for processing input data. If none is provided, a dataloader must be supplied. + bits (int): Number of bits for quantization (default is 4). + group_size (int): Size of the quantization group (default is 128). + sym (bool): Whether symmetric quantization is to be used (default is True). + layer_config (dict): Configuration for weight quantization (default is None). + layer_config={ + 'layer1':##layer_name + { + 'data_type': 'int', + 'bits': 4, + 'group_size': 128, + 'sym': True + 'act_data_type': None, + 'act_bits': 16, + 'act_group_size': None, + 'act_sym': None, + + } + ... + } + batch_size (int): Batch size for training (default is 8). + amp (bool): Whether to use automatic mixed precision (default is True). + device: The device to be used for tuning (default is "auto"). + lr_scheduler: The learning rate scheduler to be used. + dataset (str): The default dataset name (default is "partiprompts"). + enable_quanted_input (bool): Whether to use the output of the previous quantized block as + the input for the current block (default is True). + enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True). + lr (float): The learning rate (default is None, will be set to 1.0/iters). + minmax_lr (float): The learning rate for min-max tuning (default is None, it will be set to lr automatically). + low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True). + low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False). + iters (int): Number of iterations (default is 200). 
+ seqlen (int): Data length of the sequence for tuning (default is 2048). + nsamples (int): Number of samples (default is 128). + sampler (str): The sampling method (default is "rand"). + seed (int): The random seed (default is 42). + nblocks (int): Number of blocks (default is 1). + gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1). + not_use_best_mse (bool): Whether to use mean squared error (default is False). + dynamic_max_gap (int): The dynamic maximum gap (default is -1). + data_type (str): The data type to be used (default is "int"). + scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels + have different choices. + act_bits (int): Number of bits for activation quantization. Default is 16. + act_group_size (int): Group size for activation quantization. Default is None. + act_sym (bool): Whether to use symmetric activation quantization. Default is None. + act_data_type (str): Specifies the data type for activations. + Defaults to None, in which case it inherits the weight data type. + act_dynamic (bool): Whether to use dynamic activation quantization. Default is True. + to_quant_block_names (str|list): A string or list whose elements are list of + block's layer names to be quantized. + enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning + enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True. + device_map (str|dict): device map for each block + Returns: + The quantized model. + """ + def __init__( + self, + model: torch.nn.Module, + tokenizer, + bits: int = 4, + group_size: int = 128, + sym: bool = True, + layer_config: dict = None, + batch_size: int = 8, + amp: bool = True, + device: str = None, + lr_scheduler=None, + # dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", + enable_quanted_input: bool = True, + enable_minmax_tuning: bool = True, + lr: float = None, + minmax_lr: float = None, + low_gpu_mem_usage: bool = False, + low_cpu_mem_usage: bool = False, + iters: int = 200, + seqlen: int = 2048, + nsamples: int = 128, + sampler: str = "rand", + seed: int = 42, + gradient_accumulate_steps: int = 1, + not_use_best_mse: bool = False, + dynamic_max_gap: int = -1, + data_type: str = "int", + scale_dtype: str = "fp16", + act_bits: int = 16, + act_group_size: int = None, + act_sym: bool = None, + act_data_type: str = None, + act_dynamic: bool = True, + to_quant_block_names: Union[str, list] = None, + enable_norm_bias_tuning: bool = False, + enable_torch_compile: bool = False, + device_map: Union[str, dict] = None, + super_bits: int = None, + super_group_size: int = None, + model_kwargs: dict = None, + **kwargs, + ): + self.quantized = False + self.model_orig_dtype = model.dtype + self.seed = seed + set_seed(self.seed) + assert not unsupport_meta_device(model), ( + "AutoRound does not support for params on meta device." 
+ " Please use more gpus by setting `--device 0,1,2,3` or just use one gpu") + + ## important tuning hype-parameters + self.amp = amp + self.enable_quanted_input = enable_quanted_input + self.enable_minmax_tuning = enable_minmax_tuning + self.nsamples = nsamples + self.bits = bits + self.enable_norm_bias_tuning = enable_norm_bias_tuning + self.group_size = group_size + self.sym = sym + + self.low_gpu_mem_usage = low_gpu_mem_usage + self.low_cpu_mem_usage = low_cpu_mem_usage + self.layer_config = {} if layer_config is None else layer_config + self.seqlen = seqlen + self.batch_size, self.gradient_accumulate_steps = batch_size, gradient_accumulate_steps + self.nblocks = nblocks + self.dataset = dataset + self.iters = iters + if self.iters < 0: + logger.warning("`iters` must be non-negative, reset it to 200") + self.iters = 200 + if self.iters == 0: + self.lr = 5e-3 + else: + self.lr = lr or (1.0 / self.iters) ##must after iter setting + self.minmax_lr = minmax_lr or self.lr + + self.data_type = data_type + tmp_bits = infer_bits_by_data_type(self.data_type) + if tmp_bits<16 and tmp_bits!=bits: + logger.warning( + f"bits set in 'data_type' do not match the specified 'bits' setting. Resetting 'bits' to {tmp_bits}.") + self.bits = tmp_bits + self.supported_types = supported_layer_types + self.model = model.eval() + self.tokenizer = tokenizer + self.device = detect_device(device) + self.scale_dtype = convert_dtype_str2torch(scale_dtype) + self.set_amp_dtype() + self.to_quant_block_names = to_quant_block_names + if not hasattr(self, 'quant_block_list'): + all_blocks = get_block_names(model) + self.quant_block_list = find_matching_blocks(model, all_blocks, self.to_quant_block_names) + self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device + + ##activation + self.act_group_size = act_group_size if not (act_group_size is None) else self.group_size + self.act_bits = act_bits if not (act_bits is None) else self.bits + self.act_sym = act_sym if not (act_sym is None) else self.sym + self.act_dynamic = act_dynamic + self.act_data_type = act_data_type + if self.act_data_type is None: + if data_type in supported_dtypes and self.act_bits <= 16: + self.act_data_type = data_type + logger.info(f"activation adopts {data_type}") + else: + self.act_data_type = "float" + + tmp_act_bits = infer_bits_by_data_type(self.act_data_type) + if tmp_act_bits < 16: + self.act_bits = tmp_act_bits + + self.sampler = sampler + self.not_use_best_mse = not_use_best_mse + self.dynamic_max_gap = dynamic_max_gap + self.lr_scheduler = lr_scheduler + self.optimizer = self.get_optimizer(None) + self.batch_dim = None + self.infer_bs_coeff = 1 + + self.super_bits = super_bits + self.super_group_size = super_group_size + + torch.set_printoptions(precision=3, sci_mode=True) + self.check_configs() + if self.act_bits <= 8 and self.amp_dtype == torch.float16: + logger.warning("force to use bf16 to for quantization tuning when enabling activation quantization") + self.amp_dtype = torch.bfloat16 + self.model = self.model.to(torch.bfloat16) + else: + logger.info(f"using {self.model.dtype} for quantization tuning") + + self.enable_torch_compile = enable_torch_compile + if not self.enable_torch_compile and TORCH_VERSION_AT_LEAST_2_6 and self.act_bits > 8 and not is_debug_mode() \ + and self.low_cpu_mem_usage != True and "fp8" not in self.data_type and "fp8" not in self.act_data_type: + logger.info("'enable_torch_compile' is set to `False` by default. 
" \ + "Enabling it can reduce tuning cost by 20%, but it might throw an exception.") + + if self.act_bits <= 8 and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as activation quantization is enabled") + + if self.low_cpu_mem_usage == True and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled") + + if is_debug_mode() and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as debug mode is enabled") + + if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile: + self.enable_torch_compile = False + logger.warning("reset enable_torch_compile to `False` as fp8 is enabled") + + if is_optimum_habana_available(): + logger.info("Optimum Habana is available, import htcore explicitly.") + # import habana_frameworks.torch.core as htcore # pylint: disable=E0401 + # import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] + self.device_map = device_map + + self.set_device_map_in_blocks(self.device_map) + + self.is_packing_immediate = False ## whether to pack the layer immediately after tuning + + self.serialization_keys = [ + "bits", + "group_size", + "sym", + "data_type", + "enable_quanted_input", + "enable_minmax_tuning", + "data_type", + "seqlen", + "batch_size", + "scale_dtype", + "lr", + "minmax_lr", + "gradient_accumulate_steps", + "iters", + "amp", + "nsamples", + "low_gpu_mem_usage", + "to_quant_block_names", + "enable_norm_bias_tuning", + "act_bits", + "act_group_size", + "act_sym", + "act_dynamic", + "act_data_type", + "super_bits", + "super_group_size" + ] + + self.has_qlayer_outside_block = self.set_layerwise_config(self.layer_config) ##better place in the end + self.shared_cache_keys = get_shared_keys(self.model) + + def set_device_map_in_blocks(self, device_map): + """Sets the device map for specific blocks in the model. + + Args: + device_map (Union[str, dict]): A mapping of module names to devices. + If provided as a string, it should be in the format + "module_name:device,module_name:device". Devices can be integers + (GPU IDs) or strings (e.g., 'cpu', 'cuda:0'). 
+ """ + if self.device_map is None or len(self.device_map) == 0: + self.device_map = None + if not device_map: + return + if isinstance(device_map, str): + device_map = device_map.replace(" ", "") + infos = device_map.split(",") + device_map_dict = {} + for info in infos: + index = info.find(':') + key = info[:index] + value = info[index + 1:] + device_map_dict[key] = value + device_map = device_map_dict + + names = [n for n, m in self.model.named_modules() if len(list(m.children())) == 0] + + for key, device in device_map.items(): + if isinstance(device, str) and device.isdigit(): + device = int(device) + device = detect_device(device) + try: + module = get_module(self.model, key) + module.tuning_device = device + except: + matching_names = [name for name in names if re.match(key, name)] + if len(matching_names) > 0: + for name in matching_names: + self._set_device_for_matching_module(name, device) + else: + for name in names: + if key in name: + self._set_device_for_matching_module(name, device) + + def _set_device_for_matching_module(self, name, device): + module = get_module(self.model, name) + if hasattr(module, "tuning_device") and module.tuning_device != device: + logger.warning( + f"Multiple devices have been set for layer {name}, keeping original device {module.tuning_device}") + else: + module.tuning_device = device + + def _dq_check(self): + """Reset the default value of super_bits and super_group_size""" + from auto_round.export.export_to_gguf.config import GGUF_CONFIG + if self.data_type.endswith("_dq"): + gguf_config = GGUF_CONFIG[f"gguf:q{self.bits}_k_s"] + self.super_bits = gguf_config["super_bits"] if self.super_bits is None else self.super_bits + self.super_group_size = gguf_config["super_group_size"] \ + if self.super_group_size is None else self.super_group_size + + def check_configs(self): + + """Checks if the configurations are valid. + + Raises: + AssertionError: If any of the configurations are invalid. + """ + assert isinstance(self.model, torch.nn.Module) + assert self.bits > 0, "bits must be positive" + assert self.act_bits > 0, "bits must be positive" + assert self.group_size == -1 or self.group_size >= 1, "only supports positive group_size or -1(per channel)" + assert self.act_group_size == -1 or self.act_group_size >= 1, \ + "only supports positive group_size or -1(per channel)" + assert self.batch_size > 0, "batch size must be positive" + assert self.iters >= 0, "iters must be non-negative" + assert self.seqlen > 0, "seqlen must be positive" + assert self.nblocks > 0, "nblocks must be positive" + assert self.gradient_accumulate_steps > 0, "gradient accumulate step must be positive" + # assert self.tokenizer != None or self.dataloader != None + if self.act_bits <= 8: + logger.warning( + "activation quantization is an experimental feature with limited support and a complex API. 
" + "And please save the quantized model to fake format as real deployment is not supported currently") + + if "mx_fp" in self.data_type: + logger.warning( + "please save the quantized model to fake format " + "as real deployment is not supported for mx_fp datatype currently") + + if "mx_fp" in self.data_type and self.group_size != 32: + logger.warning("mx_fp should only support group_size of 32 in real deployment") + + if self.nsamples < self.gradient_accumulate_steps * self.batch_size: + if self.batch_size > self.nsamples: + logger.warning(f"reset batch_size to {self.nsamples} as nsamples({self.nsamples})" + f" is smaller than batch_size({self.batch_size})") + self.batch_size = self.nsamples + if self.gradient_accumulate_steps > self.nsamples // self.batch_size: + self.gradient_accumulate_steps = self.nsamples // self.batch_size + logger.warning( + f"reset gradient_accumulate_steps to {self.gradient_accumulate_steps}" + f" as nsamples must equal or greater" + f" than gradient_accumulate_steps * batch_size") + self._dq_check() + + # def _check_format_compatibility(self, format): ##TODO + # ##check lm_head, mixed_bits, bits, each layer supporting, etc + # pass + + def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "auto_round", inplace=True, **kwargs): + """Quantizes the model and saves it in the specified format(s). + + This function checks the validity of the requested format(s), quantizes + the model accordingly, and saves it to the specified output directory. + If multiple formats are provided, the model is saved separately for each format. + + Args: + output_dir (str, optional): The directory where the quantized model + will be saved. Defaults to "tmp_autoround". + format (str, optional): The quantization format(s) to use, separated + by commas if multiple. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place if only + one format is used. Defaults to True. + **kwargs: Additional arguments for the quantization and saving process. + + Returns: + model: A qdq model or packed model based on the configurations + folders: The folder paths where the quantized models are saved. + + Raises: + ValueError: If an unsupported format is specified. + """ + # Validate and process the specified formats + formats = format.replace(' ', '').split(',') + from auto_round.utils import supported_formats + for format_ in formats: + if format_ not in supported_formats: + logger.error(f"Unsupported format {format_}, please choose from {supported_formats}") + exit(-1) + + # only support to export afp8 + if self.act_bits <= 8: + if "fp8" not in self.act_data_type: + if len(formats) > 1 or "fake" not in formats: + logger.warning( + f"Currently only support to export auto_round format quantized model" + " with fp8 dtype activation for activation quantization." + " Change format to fake and save." 
+ ) + formats = ["fake"] + else: + if len(formats) > 1 or "auto_round" not in formats: + logger.warning( + f"Currently only support to export auto_round format for W{self.bits}AFP8 model," + " change format to auto_round" + ) + formats = ["auto_round"] + + # If multiple formats are specified, enforce inplace=False + if len(formats) > 1: + inplace = False + inplace = kwargs.get("inplace", inplace) + kwargs.pop("inplace", None) + + # Determine if immediate packing is required + if (len(formats) == 1 and + ("awq" in formats[0] or "gptq" in formats[0] or "auto_round" in formats[0]) and + not self.has_qlayer_outside_block and inplace): # TODO: Support more formats + self.is_packing_immediate = True + + # Adjust format settings based on compatibility + for index in range(len(formats)): + format = formats[index] + if "auto_round" in format: + if (self.sym and ("gptq" not in format and "awq" not in format)) or self.bits == 3: + format = format.replace('auto_round', 'auto_round:auto_gptq') + formats[index] = format + + # Remove duplicates from formats list + def remove_duplicates(lst): + seen = set() + return [x for x in lst if not (x in seen or seen.add(x))] + + formats = remove_duplicates(formats) + self.formats = formats + + # # Check format compatibility + # self._check_format_compatibility(formats) + + # Perform model quantization + model, _ = self.quantize() + + # Save the quantized model in the specified formats + folders = [] + for format in formats: + if "gptq" in format and not self.sym: + logger.warning( + "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop," + " particularly for 2-bit quantization and smaller models." + " We recommend exporting to either the AutoAWQ format ( only 4 bits) or " + "the AutoRound format(2/4/8 bits)." + ) + save_format_ = format.replace(":", "-").replace("_", "-") + save_folder = os.path.join(output_dir, save_format_) if len(formats) > 1 else output_dir + self.save_quantized(save_folder, format=format, inplace=inplace, **kwargs) + + folders.append(save_folder) + + return model, folders + + @torch.inference_mode + def quantize_rtn(self): + if self.amp: + self.model.to(self.amp_dtype) + self.model.to("cpu") + all_to_quantized_module_names = [] + for n, m in self.model.named_modules(): + if check_to_quantized(m): + all_to_quantized_module_names.append(n) + pbar = tqdm(all_to_quantized_module_names) + + for name in pbar: + pbar.set_description(f"Quantizing {name}") + m = get_module(self.model, name) + + m.to(self.device) + m = WrapperLinear(m, enable_minmax_tuning=False, enable_norm_bias_tuning=False, enable_round_tuning=False) + m = m.unwrapper({}) + m.to("cpu") + if self.low_gpu_mem_usage: + clear_memory() + if self.is_packing_immediate: + from auto_round.export import PACKING_LAYER_WITH_FORMAT + if check_to_quantized(m): + target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0] + PACKING_LAYER_WITH_FORMAT[target_backend](name, self.model, self.formats[0]) + if self.low_gpu_mem_usage: + clear_memory() + else: + set_module(self.model, name, m) + + self.quantized = True + return self.model, self.layer_config + + def quantize(self): + """Quantize the model and return the quantized model along with layer configurations. + the entry of AutoRound. + + Returns: + The quantized model and layer configurations. 
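+
+        Example (illustrative end-to-end sketch; the `model` and `tokenizer` objects are assumptions):
+
+            >>> ar = AdaRoundDM(model, tokenizer, bits=4, group_size=128, iters=200)
+            >>> quantized_model, layer_config = ar.quantize()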
+ """ + if self.iters == 0: + return self.quantize_rtn() + + if bool(self.quant_block_list): + all_blocks = self.quant_block_list + else: + all_blocks = get_block_names(self.model) + + if len(all_blocks) == 0: + logger.warning("could not find blocks, exit with original model") + return self.model, self.layer_config + + if self.amp: + self.model = self.model.to(self.amp_dtype) + + layer_names = self.get_quantized_layer_names_outside_blocks() + self.start_time = time.time() + all_first_block_names = [block[0] for block in all_blocks] + logger.info("start to cache block inputs") + all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names) + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: + accelerate.hooks.remove_hook_from_submodules(self.model) ##self.model.hf_device_map has not been changed + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + logger.info("caching done") + pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks)) + + for block_names in all_blocks: + inputs = all_inputs[block_names[0]] + all_inputs.pop(block_names[0]) + keys = inputs.keys() + input_id_str = [key for key in keys if key.startswith('hidden_state')] + if len(input_id_str) != 1: + raise RuntimeError(f"hidden_states arg mismatch error," + "please raise an issue in https://github.com/intel/auto-round/issues") + inputs["input_ids"] = inputs.pop(input_id_str[0], None) + clear_memory(self.inputs) + + if "input_ids" in inputs.keys(): + total_samples = len(inputs["input_ids"]) + if total_samples < self.batch_size: + self.batch_size = total_samples + logger.warning(f"force the train batch size to {total_samples}") + + self.quant_blocks( + self.model, + inputs, + block_names, + nblocks=self.nblocks, + device=self.device, + pbar=pbar + ) + if self.is_packing_immediate: + assert len(self.formats) == 1 + + self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately + + end_time = time.time() + cost_time = end_time - self.start_time + logger.info(f"quantization tuning time {cost_time}") + + ## dump a summary + quantized_layers = [] + unquantized_layers = [] + for n, m in self.model.named_modules(): + if isinstance(m, tuple(self.supported_types)): + if check_to_quantized(m): + quantized_layers.append(n) + else: + unquantized_layers.append(n) + elif hasattr(m, "scales") or hasattr(m, "scale"): ##packing_immediately + quantized_layers.append(n) + summary_info = ( + f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model" + ) + if len(unquantized_layers) > 0: + summary_info += f", {unquantized_layers} have not been quantized" + logger.info(summary_info) + + self.quantized = True + return self.model, self.layer_config + + def quant_layers(self, layer_names, layer_inputs): + """Quantizes specified layers based on inputs and configuration. + + Args: + layer_names (list): List of layer names to quantize. + layer_inputs (dict): Dictionary mapping layer names to input data. 
+ + Returns: + None + """ + ##TODO currently we take all the layers outside blocks as post block layers which is not optimal + if len(layer_names) == 0: + return + q_layer_inputs = None + enable_quanted_input = self.enable_quanted_input + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1 and enable_quanted_input: + from accelerate.big_modeling import dispatch_model + + dispatch_model(self.model, self.model.hf_device_map) + + if enable_quanted_input: + q_layer_inputs = self.try_cache_inter_data_gpucpu([], self.nsamples, layer_names=layer_names) + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: + accelerate.hooks.remove_hook_from_submodules( + self.model) ##self.model.hf_device_map has not been changed + + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + clear_memory() + if self.enable_torch_compile: + quant_layer = compile_func(self.quant_layer, self.device) + else: + quant_layer = self.quant_layer + for layer_name in layer_names: + layer_input = layer_inputs[layer_name] + layer_input = to_device(layer_input, self.cache_device) + q_layer_input = q_layer_inputs[layer_name] if enable_quanted_input else None + q_layer_input = to_device(q_layer_input, self.cache_device) + quant_layer(layer_name, layer_input, q_layer_input, device=self.device) + del layer_input + clear_memory(q_layer_input) + + def set_layerwise_config(self, layer_config): + """ + Sets the layer-wise configuration based on the provided `layer_config`. + By default, only quantize layers in blocks. + + Args: + layer_config (dict): The configuration dictionary for each layer containing various configuration options. + + Returns: + bool: Returns True if there are quantized layers outside the blocks (e.g., lm-head), + otherwise returns False. 
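+
+            Example (illustrative configuration; the layer names/patterns are hypothetical):
+
+                >>> layer_config = {
+                ...     "model.layers.0.self_attn.q_proj": {"bits": 8, "group_size": 128},
+                ...     r".*mlp\.down_proj": {"bits": 4, "sym": True},  # regex keys expand to matching layers
+                ... }
+                >>> has_qlayer_outside_block = ar.set_layerwise_config(layer_config)
+
+            Missing keys in each per-layer dict are filled from the instance-level defaults.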
+ """ + # Get the names of layers in quantization blocks + layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list) + + ##process regex in layer_config + all_supported_layer_names = [] + # List of configuration keys + keys = self.serialization_keys + + for n, m in self.model.named_modules(): + # Delete previous configuration to avoid conflicts with prior tuning + for key in keys: + if hasattr(m, key): + delattr(m, key) + + # Skip unsupported types + if not isinstance(m, tuple(self.supported_types)): + continue + all_supported_layer_names.append(n) + + names_in_layer_config = list(layer_config.keys()) + for name in names_in_layer_config: + if name in all_supported_layer_names: + continue + matched_names = [] + for layer_name in all_supported_layer_names: + if re.search(re.compile(name), layer_name) is not None: + matched_names.append(layer_name) + if len(matched_names) > 0: + val = layer_config[name] + layer_config.pop(name) + for match_name in matched_names: + layer_config[match_name] = val + else: + raise ValueError(f"key {name} in layer_config is invalid, please have a double check") + + has_qlayer_outside_block = False # Flag to track if there are quantized layers outside blocks (e.g., lm-head) + + # Iterate through all modules in the model + for n, m in self.model.named_modules(): + + # Skip unsupported types + if not isinstance(m, tuple(self.supported_types)): + continue + + # If the layer is not in the config and is part of a quantization block, use default configuration + if n not in layer_config.keys() and n in layers_in_blocks: + layer_config[n] = {} + for key in keys: + layer_config[n][key] = getattr(self, key) + # If the layer is partially configured, fill in missing values + elif n in layer_config.keys(): + for key in keys: + if key not in layer_config[n].keys(): + layer_config[n][key] = getattr(self, key) + # If the layer is not in the config and not part of a quantization block, + # use default configuration and set specific values + else: + layer_config[n] = {} + for key in keys: + layer_config[n][key] = getattr(self, key) + layer_config[n]["bits"] = 16 + layer_config[n]["act_bits"] = 16 + + if n in layers_in_blocks: + layer_config[n]["in_blocks"] = True + else: + layer_config[n]["in_blocks"] = False + + # If the layer is outside a block and requires quantization, mark it as a quantized layer outside the block + if n not in layers_in_blocks and check_to_quantized(layer_config[n]): + has_qlayer_outside_block = True + + in_features, out_features = get_layer_features(m) + if in_features <= layer_config[n]["group_size"]: + layer_config[n]["group_size"] = -1 + + # Apply the configuration to the corresponding layer in the model + for key in keys: + setattr(m, key, layer_config[n][key]) + + # Return whether there are quantized layers outside the blocks + return has_qlayer_outside_block + + @torch.no_grad() + def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_device, save_output=True): + """Compute the output of a given block of the model for a given input. + + Args: + block: The block of the model. + input_ids: The input tensor containing tokenized input ids. + input_others: A dictionary containing additional input data. + bs: The batch size for computing the output. + device: The device for computation. + cache_device: The device for storing the output. + batch_dim: The batch dimension of the output tensor. + + Returns: + The output tensor of the block. 
+ """ + + output = [] + nsamples = len(input_ids) + for i in range(0, nsamples, bs): + end_index = min(nsamples, i + bs) + indices = torch.arange(i, end_index).to(torch.long) + tmp_input_ids, tmp_input_others = AutoRound.sampling_inputs( + input_ids, + input_others, + indices, + self.seqlen, + self.batch_dim, + share_cache_keys=self.shared_cache_keys + ) + tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to( + cache_device + ) + if save_output: + if self.batch_size == 1: + output.append(tmp_output) + else: + output.extend(list(torch.split(tmp_output, 1, dim=self.batch_dim))) + if self.low_gpu_mem_usage: + clear_memory() + + return output + + @torch.no_grad() + def calib(self, nsamples, bs): + """Perform calibration for quantization. + + This method calibrates the model for quantization by processing a specified + number of samples from the calibration dataset. It ensures that the data is + properly formatted and feeds it to the model. If the number of samples processed + is less than the specified number, it logs a warning. If no samples are processed, + it logs an error and exits. + Args: + nsamples (int): The number of samples to use for calibration. + bs (int): The number of samples to use for calibration + """ + from .calib_dataset import get_dataloader + if isinstance(self.dataset, str): + dataset = self.dataset.replace(" ", "") ##remove all whitespaces + + # slow here + self.dataloader = get_dataloader( + self.tokenizer, + self.seqlen, + dataset, + self.seed, + bs, + self.nsamples, + ) + else: + self.dataloader = self.dataset + total_cnt = 0 + + # load embed weight if use low_cpu_mem_usage + if self.low_cpu_mem_usage: + embed_layers = get_layers_before_block(self.model) + for n, m in embed_layers: + m = m.to(self.device) + + for data in self.dataloader: + if data is None: + continue + if isinstance(data, torch.Tensor): + input_ids = data.to(self.device) + data_new = input_ids + elif isinstance(data, str): + if self.tokenizer is None: + logger.error("please provide tokenizer for string input") + exit(-1) + data = self.tokenizer(data, truncation=True, max_length=self.seqlen, return_tensors="pt").data + data_new = {} + for key in data.keys(): + data_new[key] = data[key].to(self.device) + input_ids = data_new["input_ids"] + elif isinstance(data, tuple) or isinstance(data, list): + data_new = data + input_ids = data_new[0] + else: + data_new = {} + for key in data.keys(): + data_new[key] = to_device(data[key], self.model.device) + if key == 'images': + data_new[key] = to_dtype(data_new[key], self.model.dtype) + input_ids = data_new["input_ids"] + if input_ids.shape[-1] < self.seqlen: + continue + try: + if isinstance(data_new, torch.Tensor): + self.model(data_new) + elif isinstance(data_new, tuple) or isinstance(data_new, list): + self.model(*data_new) + else: + self.model(**data_new) + except NotImplementedError: + pass + except RuntimeError as error: + logger.warning("When quantization encounters tensor" \ + " shape mismatch error, you can try to avoid it with batch_size=1") + logger.error(error) + pass + except Exception as error: + raise error + total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1 + if total_cnt >= nsamples: + break + if total_cnt == 0: + logger.error( + f"no data has been cached, please provide more data with sequence length >={self.seqlen} in the " + f"dataset or decease the sequence length" + ) + exit(-1) + elif total_cnt < nsamples: + logger.warning( + f"An insufficient number of samples likely 
reduces the accuracy of the quantized model." + f"Target samples count is {nsamples}, while valid samples count is {total_cnt}" + ) + + # clean embed weight to save memory + if self.low_cpu_mem_usage: + for n, m in embed_layers: + m = m.to("meta") + + @torch.no_grad() + def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, last_cache_name=None): + """Attempts to cache intermediate data on GPU, if failed, then using CPU. + + Args: + block_names (list): List of block names to cache data for. + nsamples (int): Number of samples to use for caching. + layer_names (list, optional): List of layer names to cache data for. Defaults to []. + last_cache_name (str, optional): Name of the last cache. Defaults to None. + + Returns: + all_inputs: Cached intermediate data. + + Raises: + Exception: If caching on GPU fails, switches to CPU and caches there. + """ + if layer_names is None: + layer_names = [] + try: + if not self.model.device.type == "meta": + if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: + pass + else: + self.model = self.model.to(self.device) + all_inputs = self.cache_inter_data( + block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name + ) + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + clear_memory() + except RuntimeError as e: + if "CUDA out of memory" in str(e) or "MODULE:PT_DEVMEM" in str(e): + logger.info("switch to cpu to cache block inputs") + if (("lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] < 16) or + self.__class__.__name__ == "AutoRoundMLLM"): + logger.warning(f"we strongly recommend using additional CUDA/HPU devices,e.g. " + f"set `--device '0,1'` in our cmd line usage or " + f"load the model with `device_mapping=auto`," + f" for optimal performance during calibration " + f"Otherwise, the process may be significantly slower.") + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + clear_memory() + all_inputs = self.cache_inter_data( + block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name + ) + else: + raise + return all_inputs + + @torch.no_grad() + def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_name=None): + """Save the inputs of block_name for calibration. + + This method temporarily replaces the forward method of the model to capture + the inputs passing through the specified block. It then calibrates the model + using a specified number of samples. Finally, it restores the original forward + method and returns the inputs for the specified block. + Args: + block_names (list): The names of the blocks for which inputs are to be saved. + layer_names (list):The names of the layers for which inputs are to be saved. + nsamples (int): The number of samples to use for calibration. + last_cache_name (str, optional): The name of the last layer to be cached, + we could break the forward in this layer to save time + + Returns: + dict: A dictionary containing the inputs for the specified block. 
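+
+            Example (illustrative; `ar` denotes an already constructed instance and "lm_head"
+                is a hypothetical layer name):
+
+                >>> first_block_names = [blocks[0] for blocks in get_block_names(ar.model)]
+                >>> cached = ar.cache_inter_data(first_block_names, nsamples=128, layer_names=["lm_head"])
+                >>> list(cached[first_block_names[0]].keys())  # hidden states plus other forward kwargs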
+ """ + if layer_names is None: + layer_names = [] + self.inputs = {} + self.to_cached_layers = block_names + layer_names + tmp_dtype = None + ## have bug if block name is not the first block + if (len(block_names) > 1 or len(layer_names) > 0) and self.low_gpu_mem_usage: + tmp_dtype = self.model.dtype + self.model = self.model.to(torch.bfloat16) if self.amp else self.model.to(torch.float32) ##model on cpu + + self.last_cache_name = last_cache_name + if last_cache_name is None and len(block_names) + len(layer_names) == 1: + self.last_cache_name = block_names[0] if len(block_names) == 1 else layer_names[0] + # do not set last_cache_name for multimodal models + calib_bs = self.batch_size + self.hook_handles = [] + self._replace_forward() + self.calib(nsamples, calib_bs) + self._recover_forward() + res = self.inputs + del self.last_cache_name + del self.to_cached_layers + if tmp_dtype is not None: + self.model = self.model.to(tmp_dtype) + + return res + + @torch.no_grad() + def get_block_forward_func(self, name): + """Gets the forward function. + + Args: + name (str): The name of the function. + Returns: + function: The forward function. + """ + + def post_process_cache_data(batch_size, data, data_name): + """ + Processes store data for batch handling, reshaping if necessary. + + Args: + batch_size (int): The size of the batch. + data: The data value to store, potentially for caching. + data_name (str): Name of the data. + + Returns: + Processed data or None + """ + new_data = data + if batch_size <= 1: + return new_data + if data_name in self.shared_cache_keys: + return None + if "alibi" in data_name: + if isinstance(data, torch.Tensor): + alibi = data + alibi = alibi.reshape(batch_size, -1, alibi.shape[1], alibi.shape[2]) + new_data = alibi + return new_data + + def forward(m, hidden_states=None, *positional_inputs, **kwargs): + """Rewrite forward function, process and collect input data. + + Args: + hidden_states (torch.Tensor): The hidden states tensor. + *positional_inputs: Variable number of positional arguments. + **kwargs: Variable number of keyword arguments. + + Returns: + NotImplementedError: Getting the first layer inputs and then raise the error to save runtime. 
+ """ + if name not in self.inputs: + self.inputs[name] = {} + init_cache(positional_inputs, self.inputs[name]) + + if self.batch_dim is None: + self.batch_dim = 0 + if hidden_states is not None and self.batch_size > 1: + if hidden_states.shape[0] > self.batch_size: + self.batch_dim = 1 + if len(hidden_states.shape) > 1 and hidden_states.shape[1] > self.batch_size: + logger.error( + f"this model has not been supported, " + f"please raise an issue in https://github.com/intel/auto-round/issues" + f" or try to set the `batch_size` to 1 and " + f"`gradient_accumulate_steps` to your current batch size.") + exit(-1) + + if hidden_states is not None: + kwargs['hidden_states'] = hidden_states + + for key in kwargs.keys(): + if isinstance(kwargs[key], torch.Tensor) or isinstance(kwargs[key], list) \ + or isinstance(kwargs[key], tuple): + if key not in self.inputs[name].keys(): # initialization + data = to_device(kwargs[key], device=torch.device("cpu")) + if data is None or (self.batch_size > 1 and key in self.shared_cache_keys): + self.inputs[name][key] = data + continue + if self.batch_size <= 1: + self.inputs[name][key] = [data] + else: + data = post_process_cache_data(self.batch_size, data, key) + self.inputs[name][key] = list(torch.split(data, 1, dim=self.batch_dim)) + else: # append cache inputs + new_data = post_process_cache_data(self.batch_size, kwargs[key], key) + if new_data is None: # shareable args or NoneType + continue + new_data = to_device(new_data, device=torch.device("cpu")) + if self.batch_size <= 1: + self.inputs[name][key].append(new_data) + else: + self.inputs[name][key].extend(list(torch.split(new_data, 1, dim=self.batch_dim))) + elif isinstance(kwargs[key], (str, bool, type(None))): + if key not in self.inputs[name].keys(): + self.inputs[name][key] = kwargs[key] + else: + # Parameters not to be cached + if check_skippable_keywords(key): + logger.warning_once(f"Please note that '{key}' key" \ + " is not currently used in quantization fine-tuning.") + reset_params(self.inputs[name]) + if name == self.last_cache_name: + raise NotImplementedError + else: + if hidden_states is not None: + kwargs.pop('hidden_states') + return m.orig_forward(hidden_states, *positional_inputs, **kwargs) + else: + # Currently only for Llama-3.2-Vision-Instruct Series + return m.orig_forward(*positional_inputs, **kwargs) + + return forward + + @torch.no_grad() + def _get_cache_data_hook_for_layer(self, name): + """A forward hook to save input max of a module + :param name: the module name + :return: A hook function.""" + + def cache_input_hook(module, inputs, outputs): + input = inputs + if isinstance(inputs, tuple) or isinstance(input, list): + input = inputs[0] + if name in self.inputs: + self.inputs[name].extend(list(torch.split(input.to("cpu"), 1, dim=0))) + else: + self.inputs[name] = list(torch.split(input.to("cpu"), 1, dim=0)) + + return cache_input_hook + + def _recover_forward(self): + """Recovers the forward function.""" + for n, m in self.model.named_modules(): + if hasattr(m, "orig_forward"): + m.forward = m.orig_forward + delattr(m, "orig_forward") + for hook_handle in self.hook_handles: + hook_handle.remove() + self.hook_handles = [] + + def _replace_forward(self): + """Replaces the forward function.""" + from functools import partial + + for n, m in self.model.named_modules(): + if n in self.to_cached_layers and not isinstance(m, tuple(self.supported_types)): ##block + m.orig_forward = m.forward + m.forward = partial(self.get_block_forward_func(n), m) + elif n in self.to_cached_layers: 
##linear layer or conv1d layer + hook_func = self._get_cache_data_hook_for_layer(n) + hook_handle = m.register_forward_hook(hook_func) + self.hook_handles.append(hook_handle) + + def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cpu")): + """Quantize a specific layer of the model using the provided inputs. + + Args: + layer_name (str): The name of the layer to quantize. + inputs (torch.Tensor): Input data for quantization. + q_inputs (torch.Tensor, optional): Quantized input data. Defaults to None. + device (torch.device, optional): The device to use for quantization. Defaults to torch.device("cpu"). + + Returns: + None + """ + logger.info(f"quantizing layer {layer_name}") + layer = get_module(self.model, layer_name) + if hasattr(layer, "tuning_device"): + device = layer.tuning_device + + layer = layer.to(device) + for i in range(len(inputs)): + inputs[i] = inputs[i].to(layer.weight.dtype) + if q_inputs is not None: + q_inputs[i] = q_inputs[i].to(layer.weight.dtype) + + wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to( + device) + round_params = [] + minmax_params = [] + for key in wrapper_linear.params.keys(): + if "min" in key or "max" in key: + minmax_params.append(wrapper_linear.params[key]) + else: + round_params.append(wrapper_linear.value) + if self.enable_minmax_tuning: + optimizer = self.optimizer( + [{"params": round_params}, {"params": minmax_params, "lr": self.minmax_lr}], lr=self.lr, weight_decay=0 + ) + else: + optimizer = self.optimizer(round_params, lr=self.lr, weight_decay=0) + + if self.lr_scheduler is None: + lr_schedule = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.iters + ) + else: + lr_schedule = copy.deepcopy(self.lr_scheduler) + nsamples = len(inputs) + last_best_iter = 0 + best_loss = torch.finfo(torch.float).max + mse_loss = torch.nn.MSELoss().to(device) + scaler = self.get_scaler() # pylint: disable=assignment-from-none + init_loss = None + # best_v, best_min_scale, best_max_scale = torch.tensor(0), torch.tensor(1.0), torch.tensor(1.0) + gradient_accumulate_steps = self.batch_size ##Force to low gpu + batch_size = 1 ##Force to low gpu + pick_samples = batch_size * gradient_accumulate_steps + pick_samples = min(nsamples, pick_samples) + if self.sampler != "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + total_loss = 0 + num_elm = 1 + mse_reduction = "mean" + if gradient_accumulate_steps != 1: + mse_reduction = "sum" + mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device) + + for i in range(self.iters): + total_loss = 0 + if self.sampler == "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + if gradient_accumulate_steps != 1: + if q_inputs is not None: + current_input = [q_inputs[i] for i in whole_indices] + else: + current_input = [inputs[i] for i in whole_indices] + num_elm = sum(id.numel() for id in current_input) + for tmp_step in range(gradient_accumulate_steps): + indices = whole_indices[tmp_step * batch_size: (tmp_step + 1) * batch_size] + if q_inputs is not None: + current_input = [q_inputs[i] for i in indices] + current_input = torch.cat(current_input, dim=0).to(device) + org_input = [inputs[i] for i in indices] + org_input = torch.cat(org_input, dim=0).to(device) + else: + current_input = [inputs[i] for i in indices] + current_input = torch.cat(current_input, dim=0).to(device) + org_input = current_input + with torch.no_grad(): + current_output = layer(org_input) + + if 
self.amp: + with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): + output_q = wrapper_linear(current_input) # pylint: disable=not-callable + loss = mse_loss(output_q, current_output) # pylint: disable=not-callable + else: + output_q = wrapper_linear(current_input) # pylint: disable=not-callable + loss = mse_loss( # pylint: disable=not-callable + output_q.to(torch.float32), current_output.to(torch.float32) + ) + total_loss += loss.item() / num_elm + + self.scale_loss_and_backward(scaler, loss) + if i == 0: + init_loss = total_loss + + if total_loss < best_loss: + best_loss = total_loss + if not self.not_use_best_mse: + best_params = collect_best_params(wrapper_linear) + last_best_iter = i + if self.not_use_best_mse and i == self.iters - 1: + best_params = collect_best_params(wrapper_linear) + + if not self.not_use_best_mse: + if 0 < self.dynamic_max_gap <= i - last_best_iter: + break + self.step(scaler, optimizer, lr_schedule) + + last_loss = total_loss + best_iter = self.iters + if not self.not_use_best_mse: + last_loss = best_loss + best_iter = last_best_iter + with torch.no_grad(): + unwrapper_layer(self.model, wrapper_linear, layer_name, best_params) + mv_module_from_gpu(layer, self.low_cpu_mem_usage) + dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" + logger.info(dump_info) + + def register_act_max_hook(self, model): + def get_act_max_hook(module, input, output): + if isinstance(input, (tuple, list)): + input = input[0] + if not hasattr(module, "act_max"): + module.act_max = torch.abs(input).max().item() + else: + module.act_max = max(torch.abs(input).max().item(), module.act_max) + + hook_handles = [] + + for n, m in model.named_modules(): + if hasattr(m, "act_dynamic") and m.act_dynamic == False and check_to_quantized(m): + hook = m.register_forward_hook(get_act_max_hook) + hook_handles.append(hook) + return hook_handles + + def quant_block(self, block, input_ids, input_others, q_input=None, device=torch.device("cpu")): + """Quantize the weights of a given block of the model. + + Args: + block: The block of the model to be quantized. + input_ids: The input tensor containing tokenized input ids. + input_others: A dictionary containing additional input data. + q_input: The quantized input tensor. + device: The device for quantization. 
+ + Returns: + Tuple: (q_outputs, output) if self.enable_quanted_input is True, else (None, output) + """ + if self.device_map is not None: + from accelerate import dispatch_model + for n, m in block.named_modules(): + if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): + continue + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + hook = AlignDevicesHook(m.tuning_device, io_same_device=True) + add_hook_to_module(m, hook, True) + + if q_input is None: + hook_handles = self.register_act_max_hook(block) + + output = self.get_block_outputs(block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, + device, + self.cache_device) + + for handle in hook_handles: + handle.remove() + else: + output = self.get_block_outputs(block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, + device, + self.cache_device) + hook_handles = self.register_act_max_hook(block) + self.get_block_outputs(block, q_input, input_others, self.batch_size * self.infer_bs_coeff, + device, self.cache_device, save_output=False) + + for handle in hook_handles: + handle.remove() + + if q_input is not None: + if input_ids is not q_input: + clear_memory(input_ids) + else: + clear_memory() + input_ids = q_input + + quantized_layer_names, unquantized_layer_names = wrapper_block( + block, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device) + + round_params = [] + minmax_params = [] + for n, m in block.named_modules(): + if hasattr(m, "orig_layer"): + for key in m.params.keys(): + if "min" in key or "max" in key: + minmax_params.append(m.params[key]) + else: + round_params.append(m.params[key]) + + if self.enable_minmax_tuning: + optimizer = self.optimizer( + [{"params": round_params}, {"params": minmax_params, "lr": self.minmax_lr}], lr=self.lr, weight_decay=0 + ) + else: + optimizer = self.optimizer(round_params, lr=self.lr, weight_decay=0) + + if len(round_params) + len(minmax_params) <= 0: + dump_info = ( + f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " + f"layers in the block" + ) + logger.info(dump_info) + return output, output + + if self.lr_scheduler is None: + lr_schedule = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=1.0, end_factor=0.0, total_iters=self.iters + ) + else: + lr_schedule = copy.deepcopy(self.lr_scheduler) + + nsamples = len(input_ids) + pick_samples = self.batch_size * self.gradient_accumulate_steps + pick_samples = min(nsamples, pick_samples) + if self.sampler != "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + last_best_iter = 0 + best_loss = torch.finfo(torch.float).max + num_elm = 1 + mse_reduction = "mean" + if self.gradient_accumulate_steps != 1: + mse_reduction = "sum" + mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device) + scaler = self.get_scaler() # pylint: disable=assignment-from-none + init_loss = None + best_params = {} + total_loss = 0 + + for i in range(self.iters): + total_loss = 0 + if self.sampler == "rand": + whole_indices = torch.randperm(nsamples)[:pick_samples] + ##we assume the block input and output shape is same + if self.gradient_accumulate_steps != 1: + current_input_ids = [input_ids[i] for i in whole_indices] + num_elm = sum(id.numel() for id in current_input_ids) + for tmp_step in range(self.gradient_accumulate_steps): + indices = whole_indices[tmp_step * self.batch_size: (tmp_step + 1) * self.batch_size] + current_input_ids, current_input_others = AutoRound.sampling_inputs( + input_ids, + 
input_others, + indices, + seqlen=self.seqlen, + batch_dim=self.batch_dim, + share_cache_keys=self.shared_cache_keys + ) + + current_output = [output[x] for x in indices] + current_output = torch.cat(current_output, dim=self.batch_dim) + + current_output = to_device(current_output, device) + + output_q = block_forward( + block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device + ) + if self.amp: + with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): + loss = mse_loss(output_q, current_output) # pylint: disable=not-callable + else: + loss = mse_loss( # pylint: disable=not-callable + output_q.to(torch.float32), current_output.to(torch.float32) + ) + + total_loss += loss.item() / num_elm + self.scale_loss_and_backward(scaler, loss) + + if i == 0: + init_loss = total_loss + + if total_loss < best_loss: + best_loss = total_loss + if not self.not_use_best_mse: + best_params = collect_best_params(block) + # print(f"get better result at iter {i}, the loss is {total_loss}", flush=True) + + last_best_iter = i + if self.not_use_best_mse and i == self.iters - 1: + best_params = collect_best_params(block) + + if not self.not_use_best_mse: + if 0 < self.dynamic_max_gap <= i - last_best_iter: + break + self.step(scaler, optimizer, lr_schedule) + + last_loss = total_loss + best_iter = self.iters + if not self.not_use_best_mse: + last_loss = best_loss + best_iter = last_best_iter + dump_info = ( + f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} " + f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}" + ) + logger.info(dump_info) + if len(unquantized_layer_names) != 0: + logger.info(f"{unquantized_layer_names} have not been quantized") + with torch.no_grad(): + unwrapper_block(block, best_params) + if self.enable_quanted_input: + if self.low_cpu_mem_usage: + block = block.to(device) + clear_memory() + q_outputs = self.get_block_outputs( + block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, + cache_device=self.cache_device + ) + if self.device_map is not None: + accelerate.hooks.remove_hook_from_submodules( + block) + mv_module_from_gpu(block, self.low_cpu_mem_usage) + clear_memory(input_ids) + + return q_outputs, output + + else: + if self.device_map is not None: + accelerate.hooks.remove_hook_from_submodules( + block) + mv_module_from_gpu(block, self.low_cpu_mem_usage) + clear_memory(input_ids) + return None, output + + def quant_blocks( + self, + model: torch.nn.Module, + inputs, + block_names, + nblocks=1, + device="cpu", + pbar=None + ): + """Quantize and dequantize the weights of the specified blocks in the model. + + Args: + model: The PyTorch model to be quantized. + inputs: The input data for quantization. + block_names: The names of the blocks to be quantized and dequantized. + nblocks: The number of blocks to quantize and dequantize. + device: The device for quantization and dequantization. 
+ + Returns: + None + """ + q_input = None + clear_memory() + for n, m in model.named_parameters(): + m.requires_grad_(False) + input_ids = inputs["input_ids"] + inputs.pop("input_ids", None) + input_others = inputs + clear_memory() + input_ids = to_device(input_ids, self.cache_device) + input_others = to_device(input_others, self.cache_device) + ## as in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage + tmp_dtype = self.amp_dtype if self.amp else torch.float32 + for i in range(len(input_ids)): + input_ids[i] = input_ids[i].to(tmp_dtype) + + for key in input_others.keys(): + if isinstance(input_others[key], torch.Tensor) and ( + input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16 + ): + input_others[key] = input_others[key].to(tmp_dtype) + elif isinstance(input_others[key], list): + for i in range(len(input_others[key])): + to_dtype(input_others[key][i], tmp_dtype) + if self.enable_torch_compile: + quant_block = compile_func(self.quant_block, device) + else: + quant_block = self.quant_block + + if pbar is None: + pbar = tqdm(range(0, len(block_names), nblocks)) + + for n, m in self.model.named_modules(): + if isinstance(m, tuple(self.supported_types)): + m.name = n + + for i in range(0, len(block_names), nblocks): + if i != 0: + pbar.update(1) + if nblocks == 1: + n = block_names[i] + pbar.set_description(f"Quantizing {n}") + m = get_module(model, n) + else: + names = block_names[i: min(i + nblocks, len(block_names))] + pbar.set_description(f"Quantizing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}") + modules = [get_module(model, n) for n in names] + m = WrapperMultiblock(modules) + + if not self.model.device.type == "meta" or self.low_cpu_mem_usage: + m = m.to(device) + + q_input, input_ids = quant_block( + m, + input_ids, + input_others, + q_input=q_input, + device=device, + ) + if self.is_packing_immediate: + from auto_round.export import PACKING_LAYER_WITH_FORMAT + for _, tmp_m in m.named_modules(): + if hasattr(tmp_m, "bits") and check_to_quantized(tmp_m): + target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0] + PACKING_LAYER_WITH_FORMAT[target_backend](tmp_m.name, self.model, self.formats[0]) + pbar.set_description(f"Quantizing done") + pbar.update(1) + pbar.close() + + self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) + for n, m in self.model.named_modules(): + if hasattr(m, "name"): + delattr(m, "name") + + del q_input + del input_ids + del input_others + del inputs + + clear_memory() + + def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs): + """Save the quantized model to the specified output directory in the specified format. + + Args: + output_dir (str, optional): The directory to save the quantized model. Defaults to None. + format (str, optional): The format in which to save the model. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place. Defaults to True. + **kwargs: Additional keyword arguments specific to the export format. + + Returns: + object: The compressed model object. + """ + # only support to export afp8 + if self.act_bits <= 8: + if "fp8" not in self.act_data_type or self.act_dynamic: + if format != "fake": + logger.warning( + f"Currently only support to export auto_round format quantized model" + " with fp8 dtype activation for activation quantization." + " Change format to fake and save." 
+ ) + format = "fake" + else: + if format != "auto_round": + logger.warning( + f"Currently only support to export auto_round format for static W{self.bits}AFP8 model," + " change format to auto_round" + ) + format = "auto_round" + + if re.search("q\d_k", format) and not self.data_type.endswith("_dq"): + logger.error( + f"datatype<{self.data_type}> not support to export {format} format." + " Please change export format or data_type." + ) + sys.exit(-1) + + if self.low_cpu_mem_usage: + self.model = self.model.to('cpu') + + if not self.quantized: + logger.warning("please run autoround.quantize first") + return + if format == "fake" or format == "qdq": ##TODO fix act quantizaiton later + self.model = self.model.to("cpu") + self.model.save_pretrained(output_dir) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + processor = kwargs.get("processor", None) + if processor is not None: + processor.save_pretrained(output_dir) + return + if self.act_bits <= 8 and format == "qdq": + logger.warning( + "Support for exporting activation quantization is limited. " + "Please ensure that your configuration is supported.") + if format in ["gguf:q4_0", "gguf:q4_1"]: + if self.group_size != 32: + logger.error(f"{format} need group_size=32, but it is {self.group_size}, cannot export.") + return + if format == "gguf:q4_0" and not self.sym: + logger.warning(f"incorrect format choose, will reset to gguf:q4_1") + if format == "gguf:q4_1" and self.sym: + logger.warning(f"incorrect format choose, will reset to gguf:q4_0") + + from auto_round.export import EXPORT_FORMAT + backend = format + format = format.split(":")[0] + if format not in EXPORT_FORMAT: + logger.error(f"export format only supports {EXPORT_FORMAT.keys()}") + raise ValueError(f"export format only supports {EXPORT_FORMAT.keys()}, but got {format}") + save_quantized_as_format = EXPORT_FORMAT.get(format) + if "gptq" in format and not self.sym: + logger.warning( + "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop," + " particularly for 2-bit quantization and smaller models." + " We recommend exporting to either the AutoAWQ format ( only 4 bits) or " + "the AutoRound format(2/4/8 bits)." 
+ ) + if "awq" in format and not self.bits == 4: + raise ValueError("The AWQ format only supports W4 quantization ") + + if isinstance(self.dataset, str): + self.serialization_keys.append("dataset") + serialization_dict = {} + for key in self.serialization_keys: + serialization_dict[key] = getattr(self, key) + from .version import __version__ + + serialization_dict["autoround_version"] = __version__ + if "scale_dtype" in serialization_dict.keys(): + serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"]) + + compressed_model = save_quantized_as_format( ##TODO refine the code + output_dir, + model=self.model, + layer_config=self.layer_config, + inplace=inplace, + bits=self.bits, + group_size=self.group_size, + sym=self.sym, + iters=self.iters, + lr=self.lr, + minmax_lr=self.minmax_lr, + enable_minmax_tuning=self.enable_minmax_tuning, + enable_quanted_input=self.enable_quanted_input, + scale_dtype=self.scale_dtype, + tokenizer=self.tokenizer, + supported_types=self.supported_types, + data_type=self.data_type, + serialization_dict=serialization_dict, + backend=backend, + to_quant_block_names=self.to_quant_block_names, + quant_block_list=self.quant_block_list, + **kwargs + ) + return compressed_model + + def get_quantized_layer_names_outside_blocks(self): + """Gets the names of quantized layers outside blocks in the model. + + Returns: + list: List of layer names outside blocks. + """ + if self.layer_config is None or len(self.layer_config) == 0: + return [] + + layer_names = [] + all_layers_in_block = get_layer_names_in_block(self.model, self.supported_types, self.quant_block_list) + + for key in self.layer_config.keys(): + if key in all_layers_in_block: + continue + layer = get_module(self.model, key) + if layer is None: + logger.error(f"could not find layer {key} in the model, exit...") + exit(-1) + if isinstance(layer, tuple(self.supported_types)) and check_to_quantized(self.layer_config[key]): + layer_names.append(key) + + return layer_names + + def set_amp_dtype(self): + self.amp_dtype = torch.float16 + if self.model.dtype != torch.float32: + self.amp_dtype = self.model.dtype + if self.device == "cpu" or "hpu" in self.device: + self.amp_dtype = torch.bfloat16 + if self.amp: + if self.device == "cpu" and not CpuInfo().bf16: + self.amp = False + self.amp_dtype = torch.float32 + self.model = self.model.to(torch.float32) + logger.warning( + f"amp is set to FALSE as the current {self.device} device does not support the 'bf16' data type." + ) + else: + self.model = self.model.to(self.amp_dtype) + else: + self.amp_dtype = torch.float32 + self.model = self.model.to(torch.float32) + + def get_optimizer(self, optimizer): + """Returns the specified optimizer. In SignRound, we fix the optimizer. + + Args: + optimizer: The optimizer to be used. + + Returns: + The specified optimizer. + """ + from auto_round.sign_sgd import SignSGD + + return SignSGD + + def get_scaler(self): + """Returns scaler, in SignRound, no need to use scaler.""" + return None + + def scale_loss_and_backward(self, scaler, loss): + """Scales the loss and performs backward pass. + + Args: + scaler: The scaler to be used. + loss: The loss to be scaled. + + Returns: + The scaled loss. + """ + scale_loss = loss * 1000 + scale_loss.backward() + if is_optimum_habana_available(): + htcore.mark_step() + return scale_loss + + def step(self, scaler, optimizer, lr_schedule): + """Performs a step in the optimization process. + + Args: + scaler: The scaler to be used. + optimizer: The optimizer for the step. 
+ lr_schedule: The learning rate schedule. + + Returns: + None + """ + optimizer.step() + # for hpu + if is_optimum_habana_available(): + htcore.mark_step() + optimizer.zero_grad() + lr_schedule.step() + + @classmethod + @torch.no_grad() + def sampling_inputs(cls, input_ids, input_others, indices, seqlen, + batch_dim=0, share_cache_keys=()): + """Samples inputs based on the given indices and sequence length. + + Args: + input_ids: The list of input tensor containing input_ids. + input_others: A dictionary containing other input data. + indices: The indices to sample from the input. + seqlen: The sequence length. + + Returns: + current_input_ids: The sampled input IDs. + current_input_others: The sampled other input data. + """ + current_input_ids = [input_ids[i] for i in indices] + + current_input_ids = torch.cat(current_input_ids, dim=batch_dim) + + current_input_others = {"positional_inputs": input_others["positional_inputs"]} + for key in input_others.keys(): + if "positional_inputs" in key: + continue + if (key not in share_cache_keys or len(indices) == 1) \ + and not isinstance(input_others[key], (str, bool, type(None))): + current_input_others[key] = None + if input_others[key] is not None: + current_input_others[key] = [input_others[key][i] for i in indices] + if len(indices) == 1: + current_input_others[key] = current_input_others[key][0] + else: + try: + current_input_others[key] = torch.cat(current_input_others[key], dim=0) + except TypeError as err: + logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.") + else: + current_input_others[key] = input_others[key] + + return current_input_ids, current_input_others + + +class AutoRoundOPT(AutoRound): + """Class for automatic rounding-based quantization with optimizers like adamw of a PyTorch model. + + Args: + model: The PyTorch model to be quantized. + tokenizer: An optional tokenizer for processing input data. + bits (int): Number of bits for quantization (default is 4). + group_size (int): Size of the quantization group (default is 128). + sym (bool): Whether sym to be used (default is True). + layer_config (dict): Configuration for weight quantization (default is None). + batch_size (int): Batch size for training (default is 8). + amp (bool): Whether to use automatic mixed precision (default is True). + device: The device to be used for training (default is "auto"). + lr_scheduler: The learning rate scheduler to be used. + dataset: The default dataset name (default is "NeelNanda/pile-10k"). + enable_quanted_input (bool): Whether to use quantized input data (default is True). + enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True). + lr (float): The learning rate (default is 0.005). + minmax_lr (float): The learning rate for min-max tuning (default is None). + low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False). + low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False). + iters (int): Number of iterations (default is 200). + seqlen (int): Length of the sequence. + nsamples (int): Number of samples (default is 128). + sampler (str): The sampling method (default is "rand"). + seed (int): The random seed (default is 42). + nblocks (int): Number of blocks (default is 1). + gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1). + not_use_best_mse (bool): Whether to use mean squared error (default is False). + dynamic_max_gap (int): The dynamic maximum gap (default is -1). 
+ data_type (str): The data type to be used (default is "int"). + scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels + have different choices. + act_bits (int): Number of bits for activation quantization. Default is 16. + act_group_size (int): Group size for activation quantization. Default is None. + act_sym (bool): Whether to use symmetric activation quantization. Default is None. + act_data_type (str): Specifies the data type for activations. + Defaults to None, in which case it inherits the weight data type. + act_dynamic (bool): Whether to use dynamic activation quantization. Default is True. + to_quant_block_names (str|list): A string or list whose elements are list of + block's layer names to be quantized. + enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning + enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer function + **kwargs: Additional keyword arguments. + + Returns: + The quantized model. + """ + + def __init__( + self, + model, + tokenizer=None, + bits: int = 4, + group_size: int = 128, + sym: bool = True, + layer_config=None, + batch_size: int = 8, + amp: bool = True, + device=None, + lr_scheduler=None, + dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", + enable_quanted_input: bool = True, + enable_minmax_tuning: bool = True, + lr: float = None, + minmax_lr: float = None, + low_gpu_mem_usage: bool = False, + low_cpu_mem_usage: bool = False, + iters: int = 200, + seqlen: int = 2048, + nsamples: int = 128, + sampler: str = "rand", + seed: int = 42, + nblocks: int = 1, + gradient_accumulate_steps: int = 1, + not_use_best_mse: bool = False, + dynamic_max_gap: int = -1, + data_type: str = "int", + scale_dtype: str = "fp16", + act_bits: int = 16, + act_group_size: int = None, + act_sym: bool = None, + act_data_type: str = None, + act_dynamic: bool = True, + to_quant_block_names: Union[str, list] = None, + enable_norm_bias_tuning: bool = False, + enable_torch_compile: bool = False, + device_map: Union[str, dict] = None, + optimizer="AdamW", + super_bits: int = None, + super_group_size: int = None, + **kwargs, + ): + super(AutoRoundOPT, self).__init__( + model=model, + tokenizer=tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + layer_config=layer_config, + batch_size=batch_size, + amp=amp, + device=device, + lr_scheduler=lr_scheduler, + dataset=dataset, + enable_quanted_input=enable_quanted_input, + enable_minmax_tuning=enable_minmax_tuning, + lr=lr, + minmax_lr=minmax_lr, + low_gpu_mem_usage=low_gpu_mem_usage, + low_cpu_mem_usage=low_cpu_mem_usage, + iters=iters, + seqlen=seqlen, + nsamples=nsamples, + sampler=sampler, + seed=seed, + nblocks=nblocks, + gradient_accumulate_steps=gradient_accumulate_steps, + not_use_best_mse=not_use_best_mse, + dynamic_max_gap=dynamic_max_gap, + data_type=data_type, + scale_dtype=scale_dtype, + act_bits=act_bits, + act_group_size=act_group_size, + act_sym=act_sym, + act_data_type=act_data_type, + act_dynamic=act_dynamic, + to_quant_block_names=to_quant_block_names, + enable_norm_bias_tuning=enable_norm_bias_tuning, + enable_torch_compile=enable_torch_compile, + device_map=device_map, + super_bits=super_bits, + super_group_size=super_group_size, + **kwargs, + ) + + self.optimizer = self.get_optimizer(optimizer) + + def get_optimizer(self, optimizer): + if optimizer is None: + optimizer = torch.optim.AdamW + elif isinstance(optimizer, str): + optimizer = getattr(torch.optim, 
optimizer) + else: + optimizer = optimizer + return optimizer + + def get_scaler(self): + scaler = None + if self.amp and not check_is_cpu(self.device): + from torch.cuda.amp import GradScaler + + scaler = GradScaler(init_scale=1024, growth_interval=100000) + return scaler + + def scale_loss_and_backward(self, scaler, loss): + if scaler is not None: + loss = scaler.scale(loss) + + loss.backward() + if is_optimum_habana_available(): + htcore.mark_step() + return loss + + def step(self, scaler, optimizer, lr_schedule): + if scaler is not None: + scaler.step(optimizer) + optimizer.zero_grad() + lr_schedule.step() + scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + lr_schedule.step() + if is_optimum_habana_available(): + htcore.mark_step() + + +class AutoRoundAdam(AutoRoundOPT): + """Class for automatic rounding-based quantization with optimizers like adamw of a PyTorch model. + The default lr has been changed. + + Args: + model: The PyTorch model to be quantized. + tokenizer: An optional tokenizer for processing input data. + bits (int): Number of bits for quantization (default is 4). + group_size (int): Size of the quantization group (default is 128). + sym (str): Whether symmetric quantization to be used (default is True). + layer_config (dict): Configuration for weight quantization (default is None). + batch_size (int): Batch size for training (default is 8). + amp (bool): Whether to use automatic mixed precision (default is True). + device: The device to be used for training (default is "auto"). + lr_scheduler: The learning rate scheduler to be used. + dataset (Union[str, list, tuple, torch.utils.data.DataLoader]): + The default dataset name (default is "NeelNanda/pile-10k"). + enable_quanted_input (bool): Whether to use quantized input data (default is True). + enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True). + lr (float): The learning rate (default is 0.005). + minmax_lr (float): The learning rate for min-max tuning (default is None). + low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False). + low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False). + iters (int): Number of iterations (default is 200). + seqlen (int): Length of the sequence. + nsamples (int): Number of samples (default is 128). + sampler (str): The sampling method (default is "rand"). + seed (int): The random seed (default is 42). + nblocks (int): Number of blocks (default is 1). + gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1). + not_use_best_mse (bool): Whether to use mean squared error (default is False). + dynamic_max_gap (int): The dynamic maximum gap (default is -1). + data_type (str): The data type to be used (default is "int"). + optimizer: string or object + scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels + have different choices. + act_bits (int): Number of bits for activation quantization. Default is 16. + act_group_size (int): Group size for activation quantization. Default is None. + act_sym (bool): Whether to use symmetric activation quantization. Default is None. + act_data_type (str): Specifies the data type for activations. + Defaults to None, in which case it inherits the weight data type. + act_dynamic (bool): Whether to use dynamic activation quantization. Default is True. + to_quant_block_names (str|list): A list whose elements are list of block's layer names to be quantized. 
+ enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning + enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer function + Returns: + The quantized model. + """ + + def __init__( + self, + model, + tokenizer=None, + bits: int = 4, + group_size: int = 128, + sym: bool = True, + layer_config=None, + batch_size: int = 8, + amp: bool = True, + device=None, + lr_scheduler=None, + dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", + enable_quanted_input: bool = True, + enable_minmax_tuning: bool = True, + lr: float = None, + minmax_lr: float = None, + low_gpu_mem_usage: bool = False, + low_cpu_mem_usage: bool = False, + iters: int = 200, + seqlen: int = 2048, + nsamples: int = 128, + sampler: str = "rand", + seed: int = 42, + nblocks: int = 1, + gradient_accumulate_steps: int = 1, + not_use_best_mse: bool = False, + dynamic_max_gap: int = -1, + data_type: str = "int", + scale_dtype: str = "fp16", + act_bits: int = 16, + act_group_size: int = None, + act_sym: bool = None, + act_data_type: str = None, + act_dynamic: bool = True, + to_quant_block_names: Union[str, list] = None, + enable_norm_bias_tuning: bool = False, + enable_torch_compile: bool = False, + device_map: Union[str, dict] = None, + optimizer="AdamW", + super_bits: int = None, + super_group_size: int = None, + **kwargs, + ): + super(AutoRoundAdam, self).__init__( + model=model, + tokenizer=tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + layer_config=layer_config, + batch_size=batch_size, + amp=amp, + device=device, + lr_scheduler=lr_scheduler, + dataset=dataset, + enable_quanted_input=enable_quanted_input, + enable_minmax_tuning=enable_minmax_tuning, + lr=lr, + minmax_lr=minmax_lr, + low_gpu_mem_usage=low_gpu_mem_usage, + low_cpu_mem_usage=low_cpu_mem_usage, + iters=iters, + seqlen=seqlen, + nsamples=nsamples, + sampler=sampler, + seed=seed, + nblocks=nblocks, + gradient_accumulate_steps=gradient_accumulate_steps, + not_use_best_mse=not_use_best_mse, + dynamic_max_gap=dynamic_max_gap, + data_type=data_type, + scale_dtype=scale_dtype, + act_bits=act_bits, + act_group_size=act_group_size, + act_sym=act_sym, + act_data_type=act_data_type, + act_dynamic=act_dynamic, + to_quant_block_names=to_quant_block_names, + enable_norm_bias_tuning=enable_norm_bias_tuning, + enable_torch_compile=enable_torch_compile, + device_map=device_map, + optimizer=optimizer, + super_bits=super_bits, + super_group_size=super_group_size, + **kwargs, + ) diff --git a/auto_round_diff/calib_dataset.py b/auto_round_diff/calib_dataset.py new file mode 100644 index 00000000..33d16c99 --- /dev/null +++ b/auto_round_diff/calib_dataset.py @@ -0,0 +1,698 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
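+#
+# Editorial note (illustrative sketch, not part of this patch): the quantization entry
+# points defined above -- quant_layer, quant_block/quant_blocks, save_quantized and the
+# AutoRoundOPT/AutoRoundAdam variants -- are normally driven through the AutoRound API
+# rather than called directly. A minimal, hedged usage sketch, assuming `model` and
+# `tokenizer` are an already-loaded causal LM and its tokenizer and that the class is
+# importable from this package:
+#
+#     from auto_round_diff import AutoRound
+#
+#     autoround = AutoRound(model, tokenizer, bits=4, group_size=128, sym=True,
+#                           iters=200, nsamples=128, seqlen=2048)
+#     autoround.quantize()                      # runs quant_blocks()/quant_block() internally
+#     autoround.save_quantized("./quantized_model", format="auto_round")  # output dir is illustrative
+#
+# The calib_dataset module below supplies the calibration dataloaders consumed during
+# such a tuning run.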
+ +import json +import random + +import torch +from datasets import Dataset, IterableDataset +from datasets import Features, Sequence, Value +from torch.utils.data import DataLoader +import sys +from .utils import is_local_path, logger + +CALIB_DATASETS = {} + + +def register_dataset(name): + """Class decorator to register a DATASET subclass to the registry. + + Decorator function used before a Pattern subclass. + + Args: + name: A string. Define the dataset type. + + Returns: + cls: The class of register. + """ + + def register(dataset): + CALIB_DATASETS[name] = dataset + return dataset + + return register + + +def apply_chat_template_to_samples(samples, tokenizer, seqlen, system_prompt=None): + rendered_messages = [] + if system_prompt is None: + system_prompt = "You are a helpful assistant." + for text in samples: + if system_prompt == "": + message = [{"role": "user", "content": text}] + else: + message = [{"role": "system", "content": system_prompt}, + {"role": "user", "content": text}] + try: + chat_templated = tokenizer.apply_chat_template( + message, + tokenize=False, + add_generation_prompt=True, + + ) + except: + logger.warning( + "Failed to apply chat template. removing the system role in chat history." + ) + message_modified = [msg for msg in message if msg["role"] != "system"] + chat_templated = tokenizer.apply_chat_template( + message_modified, + tokenize=False, + add_generation_prompt=True, + + ) + + rendered_messages.append(chat_templated) + example = tokenizer(rendered_messages, truncation=True, max_length=seqlen) + return example + + +def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=False, system_prompt=None): + """Returns a default tokenizer function. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length of + seqlen to the "text" field of examples. + """ + + def default_tokenizer_function(examples): + if not apply_chat_template: + example = tokenizer(examples["text"], truncation=True, max_length=seqlen) + else: + example = apply_chat_template_to_samples(examples["text"], tokenizer, seqlen, system_prompt) + return example + + return default_tokenizer_function + + +@register_dataset("NeelNanda/pile-10k") +def get_pile_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", split=None, seed=42, + apply_chat_template=False, system_prompt=None): + """Returns a dataloader for the specified dataset and split. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name of the dataset. + split: The data split to be used (e.g., "train", "test"). + seed: The random seed for shuffling the dataset. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. 
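+
+    Example (illustrative only; `tokenizer` is assumed to be an already-loaded
+    Hugging Face tokenizer):
+        >>> ds = get_pile_dataset(tokenizer, seqlen=2048, seed=42)
+        >>> ds[0]["input_ids"][:8]   # token ids added by the batched map() call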
+ """ + from datasets import load_dataset + + split = "train" + + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template, + system_prompt=system_prompt) + try: + calib_dataset = load_dataset(dataset_name, split=split) + except Exception as e: + import ssl + error_message = str(e) + # Check for proxy or SSL error + if "proxy" in error_message.lower() or isinstance(e, ssl.SSLError) or "SSL" in error_message.upper(): + logger.error(f"Network error detected, please checking proxy settings." \ + "Error: {error_message}. Or consider using a backup dataset by `pip install modelscope`" \ + " and set '--dataset swift/pile-val-backup' in AutoRound API.") + else: + logger.error(f"Failed to load the dataset: {error_message}") + sys.exit(1) + calib_dataset = calib_dataset.shuffle(seed=seed) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + + return calib_dataset + + +@register_dataset("swift/pile-val-backup") +def get_pile_val_dataset(tokenizer, seqlen, dataset_name="swift/pile-val-backup", split=None, seed=42, + apply_chat_template=False, system_prompt=None): + """Returns a dataloader for the specified dataset and split. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name of the dataset. + split: The data split to be used (e.g., "train", "test", "validation"). + seed: The random seed for shuffling the dataset. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. + """ + + split = "validation" + + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template, + system_prompt=system_prompt) + from transformers.utils.versions import require_version + require_version("modelscope", "Loading 'swift/pile-val-backup' dataset requires modelscope to be installed, " \ + "`pip install modelscope`") + from modelscope import MsDataset # pylint: disable=E0401 + calib_dataset = MsDataset.load('swift/pile-val-backup', + 'default', split=split).to_iterable_dataset() # , use_streaming=True + calib_dataset = calib_dataset.take(10000) + calib_dataset = calib_dataset.shuffle(seed=seed) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + + return calib_dataset + + +@register_dataset("BAAI/CCI3-HQ") +def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False, + system_prompt=None): + """Returns a dataloader for the specified dataset and split. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name of the dataset. + split: The data split to be used (e.g., "train", "test"). + seed: The random seed for shuffling the dataset. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. 
+ """ + from datasets import load_dataset + + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template, + system_prompt=system_prompt) + + calib_dataset = load_dataset(dataset_name, split='train', streaming=True) + calib_dataset = calib_dataset.take(10000) + calib_dataset = calib_dataset.shuffle(seed=seed) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + + return calib_dataset + + +@register_dataset("codeparrot/github-code-clean") +def get_github_code_clean_dataset(tokenizer, seqlen, dataset_name="codeparrot/github-code-clean", split=None, seed=42, + apply_chat_template=False, system_prompt=None): + """Returns a dataloader for the specified dataset and split. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name of the dataset. + split: The data split to be used (e.g., "train", "test"). + seed: The random seed for shuffling the dataset. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. + """ + + def get_default_tokenizer_function(): + """Returns a default tokenizer function. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length + of seqlen to the "code" field of examples. + """ + + def default_tokenizer_function(examples): + if not apply_chat_template: + example = tokenizer(examples["code"], truncation=True, max_length=seqlen) + else: + example = apply_chat_template_to_samples(examples["code"], tokenizer, seqlen, + system_prompt=system_prompt) + return example + + return default_tokenizer_function + + from datasets import load_dataset + + tokenizer_function = get_default_tokenizer_function() + + calib_dataset = load_dataset(dataset_name, split='train', streaming=True) + calib_dataset = calib_dataset.take(10000) + calib_dataset = calib_dataset.shuffle(seed=seed) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + + return calib_dataset + + +@register_dataset("madao33/new-title-chinese") +def get_new_chinese_title_dataset( + tokenizer, + seqlen, + dataset_name="madao33/new-title-chinese", + split=None, + seed=42, + apply_chat_template=False, + system_prompt=None +): + """ + Returns a tokenized dataset for the specified parameters. + + Args: + tokenizer: The tokenizer to use. + seqlen: Maximum sequence length. + dataset_name: Name of the dataset to load. + split: Which split of the dataset to use. + seed: Random seed for shuffling. + apply_template: Whether to apply a template to the data. + + Returns: + A tokenized and shuffled dataset. + """ + + def get_tokenizer_function(): + """Returns a default tokenizer function. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: A default tokenizer function that applies the provided tokenizer with truncation and a maximum length + of seqlen to the "text" field of examples. 
+ """ + + def default_tokenizer_function(examples): + if not apply_chat_template: + example = tokenizer(examples["content"], truncation=True, max_length=seqlen) + else: + example = apply_chat_template_to_samples(examples["content"], tokenizer, seqlen, + system_prompt=system_prompt) + return example + + return default_tokenizer_function + + split = "train" + from datasets import load_dataset + + tokenizer_function = get_tokenizer_function() + + calib_dataset = load_dataset(dataset_name, split=split) + calib_dataset = calib_dataset.shuffle(seed=seed) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + + return calib_dataset + + +@register_dataset("mbpp") +def get_mbpp_dataset(tokenizer, seqlen, dataset_name="mbpp", split=None, seed=42, apply_chat_template=False, + system_prompt=None): + """Returns a dataloader for the specified dataset and split. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name of the dataset. + split: The data split to be used (e.g., "train", "test"). + seed: The random seed for shuffling the dataset. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + A dataloader for the specified dataset and split, using the provided tokenizer and sequence length. + """ + from datasets import load_dataset + + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template, + system_prompt=system_prompt) + + samples = [] + splits = split + if splits is None: + splits = ["train", "validation", "test"] + if isinstance(splits, str): + splits = splits.split("+") + + for split in splits: + dataset = load_dataset(dataset_name, split=split) + for data in dataset: + samples.append({"text": data["text"] + data["code"]}) + random.Random(seed).shuffle(samples) + import datasets + + calib_dataset = datasets.Dataset.from_list(samples) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + + return calib_dataset + + +@register_dataset("local") +def get_local_dataset(tokenizer, seqlen, dataset_name="./tmp.json", split=None, seed=42, apply_chat_template=False, + system_prompt=None): + """Returns a dataloader for a custom dataset and split. + We allow the input of a json or text file containing a processed text sample each line. + + Args: + tokenizer: The tokenizer to be used for tokenization. + seqlen: The maximum sequence length. + data_name: The name or path of the dataset, which is a json or jsonl file. + split: The data split to be used (e.g., "train", "test"). + seed: The random seed for shuffling the dataset. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + A dataloader for a custom dataset and split, using the provided tokenizer and sequence length. 
+ """ + tokenizer_function = get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template, + system_prompt=system_prompt) + + def load_local_data(data_path): + if data_path.endswith(".json"): + with open(data_path, "r") as f: + data = json.load(f) + return data + elif data_path.endswith(".jsonl"): + data = [] + with open(data_path) as f: + for line in f: + sample = json.loads(line) + data.append(sample) + return data + else: + logger.error("invalid local file type, for now only support json/jsonl format data file.") + + samples = [] + dataset = load_local_data(dataset_name) + if isinstance(dataset, dict): + new_dataset = [] + for key in dataset.keys(): + new_dataset.append(dataset[key]) + dataset = new_dataset + for data in dataset: + text = data + if isinstance(text, str): + pass + elif isinstance(data, dict) and len(data.keys()) == 1: + for item in data.items(): + text = item[1] + elif isinstance(data, dict) and "text" in data.keys(): + text = data["text"] + elif isinstance(data, dict) and "input_ids" in data.keys(): + text = data["input_ids"] + assert isinstance(text, str), "data must be string" + text = text.rstrip() + text = text.rstrip("\n") + samples.append({"text": text}) + random.Random(seed).shuffle(samples) + import datasets + + calib_dataset = datasets.Dataset.from_list(samples) + calib_dataset = calib_dataset.map(tokenizer_function, batched=True) + return calib_dataset + + +def get_dataset_len(dataset): + """Calculates the length of a dataset. + + Args: + dataset: The dataset object, which can be any iterable or collection. + + Returns: + int: The length of the dataset. + + Raises: + If the dataset does not support `len()`, iterates through it to count the number of elements. + """ + try: + dataset_len = len(dataset) + return dataset_len + except: + cnt = 0 + for _ in dataset: + cnt += 1 + return cnt + + +def select(dataset, indices): + """Selects specific elements from a dataset based on given indices. + + Args: + dataset: The dataset object to iterate over. + indices: An iterable of integers specifying the indices to select. + + Yields: + Elements of the dataset corresponding to the specified indices. + + Notes: + Stops iterating once the highest index in `indices` has been processed + to optimize performance. + """ + indices = set(indices) + for idx, sample in enumerate(dataset): + if idx in indices: + yield sample + if idx > max(indices): + break + + +def select_dataset(dataset, indices): + """Selects elements from a dataset using its native `select` method, if available. + + Args: + dataset: The dataset object, which may have a `select` method. + indices: An iterable of integers specifying the indices to select. + + Returns: + A subset of the dataset, either using the dataset's `select` method or the + `select` function defined above as a fallback. + """ + try: + return dataset.select(indices) + except: + list_data = list(select(dataset, indices)) + import pandas as pd + df = pd.DataFrame(list_data) + dataset = Dataset.from_pandas(df) + return dataset + + +def get_dataloader( + tokenizer, + seqlen, + dataset_name="NeelNanda/pile-10k", + seed=42, + bs=8, + nsamples=512, +): + """Generate a DataLoader for calibration using specified parameters. + + Args: + tokenizer (Tokenizer): The tokenizer to use for tokenization. + seqlen (int): The exact sequence length. samples < seqlen will be dropped, + samples longer than seqlen will be truncated + dataset_name (str, optional): The name of the dataset or datasets separated by commas. 
+ Defaults to "NeelNanda/pile-10k". + split (str, optional): The data split to use. Defaults to None. + seed (int, optional): The random seed for reproducibility. Defaults to 42. + bs (int, optional): The batch size. Defaults to 4. + nsamples (int, optional): The total number of samples to include. Defaults to 512. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + DataLoader: The DataLoader for the calibrated dataset. + """ + + dataset_names = dataset_name.split(",") + + def filter_func(example): + if isinstance(example["input_ids"], list): + example["input_ids"] = torch.tensor(example["input_ids"]) + if example["input_ids"].shape[-1] < seqlen: + return False + input_ids = example["input_ids"][:seqlen] + input_ids_list = input_ids.tolist() + if len(input_ids_list) > 1 and seqlen > 2 and input_ids_list.count(input_ids_list[-1]) > seqlen // 2: + return False + return True + + def concat_dataset_element(dataset): + input_ids, concat_input_ids = [eg['input_ids'] for eg in dataset], [] + attention_mask_list, attention_mask = [], torch.ones([1, seqlen]).to(torch.int64) + buffer_input_id = torch.Tensor().to(torch.int64) + bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id + os_cnt, have_bos, have_eos = 0, False, False + + for input_id in input_ids: + if input_id[0] == bos_token_id: + input_id = input_id[1:] + os_cnt, have_bos = os_cnt + 1, True + if input_id[-1] == eos_token_id: + input_id = input_id[:-1] + os_cnt, have_eos = os_cnt + 1, True + + if buffer_input_id.shape[-1] + input_id.shape[-1] + os_cnt > seqlen: + idx_keep = seqlen - buffer_input_id.shape[-1] - os_cnt + input_id_to_append = [buffer_input_id, input_id[:idx_keep]] + if have_bos: + input_id_to_append = [torch.tensor([bos_token_id])] + input_id_to_append + if have_eos: + input_id_to_append.append(torch.tensor([eos_token_id])) + + concat_input_ids.append(torch.cat(input_id_to_append).to(torch.int64)) + attention_mask_list.append(attention_mask) + buffer_input_id = input_id[idx_keep:] + else: + buffer_input_id = torch.cat([buffer_input_id, input_id]) + + if buffer_input_id.shape[-1] + os_cnt == seqlen: + input_id_to_append = [buffer_input_id] + if have_bos: + input_id_to_append = [torch.tensor([bos_token_id])] + input_id_to_append + if have_eos: + input_id_to_append.append(torch.tensor([eos_token_id])) + concat_input_ids.append(torch.cat(input_id_to_append).to(torch.int64)) + attention_mask_list.append(attention_mask) + buffer_input_id = torch.Tensor().to(torch.int64) + data = [{'input_ids': a, 'attention_mask': b} for a, b in zip(concat_input_ids, attention_mask_list)] + import datasets + dataset_new = datasets.Dataset.from_list(data) + return dataset_new + + datasets, data_lens = [], {} + system_prompt = "You are a helpful assistant." 
+ for name in dataset_names: + split = None + do_concat = False + apply_chat_template = False + + if ":" in name: + name, split_list = name.split(":")[0], name.split(":")[1:] + for ele in split_list: + key, values = ele.split('=')[0], ele.split('=')[1:] + if key == "split": + split = values[0].split('+') + if key == "num": + data_lens[name] = int(values[0]) + if key == "concat": + do_concat = False if (len(values) > 0 and values[0].lower() == 'false') else True + if key == "apply_chat_template": + apply_chat_template = False if (len(values) > 0 and values[0].lower() == 'false') else True + if key == "system_prompt": + system_prompt = values[0] + apply_chat_template = True + if is_local_path(name): + get_dataset = CALIB_DATASETS.get("local") + else: + calib_name = name + if name not in CALIB_DATASETS.keys(): + calib_name = name.split('/')[-1] + for key in CALIB_DATASETS.keys(): + if calib_name in key: + calib_name = key + break + get_dataset = CALIB_DATASETS.get(calib_name) + dataset = get_dataset( + tokenizer, + seqlen, + seed=seed, + split=split, + dataset_name=name, + apply_chat_template=apply_chat_template, + system_prompt=system_prompt + ) + if do_concat: + dataset = concat_dataset_element(dataset) + dataset = dataset.filter(filter_func) + if name in data_lens: + dataset = select_dataset(dataset, range(data_lens[name])) + if isinstance(dataset, IterableDataset): + dataset = Dataset.from_list(list(dataset)) + dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) + new_features = {} + for k, v in dataset.features.items(): + if k == "input_ids": + new_features[k] = Sequence(Value('int64')) + elif k == "attention_mask": + new_features[k] = Sequence(Value('int8')) + else: + new_features[k] = v + + dataset = dataset.cast(Features(new_features)) + datasets.append(dataset) + + if len(datasets) == 1: + dataset_final = datasets[0] + else: + indices = range(len(datasets)) + lens = [] + for i in range(len(datasets)): + cnt = get_dataset_len(datasets[i]) + lens.append(cnt) + res = sorted(zip(indices, lens), key=lambda x: x[1]) + + # res = sorted(zip(indices, datasets), key=lambda x: len(x[1])) + indices = [item[0] for item in res] + datasets = [datasets[item[0]] for item in res] + dataset_names = [dataset_names[index] for index in indices] + cnt = 0 if not data_lens else sum(data_lens.values()) + dataset_cnt_info = {} + if cnt > nsamples: + cnt = 0 + + for i in range(len(datasets)): + name = dataset_names[i].split(':')[0] + if name not in data_lens: + target_cnt = (nsamples - cnt) // (len(datasets) - len(data_lens)) if data_lens \ + else (nsamples - cnt) // (len(datasets) - i) + target_cnt = min(target_cnt, lens[i]) + cnt += target_cnt + else: + target_cnt = data_lens[name] + datasets[i] = select_dataset(datasets[i], range(target_cnt)) + dataset_cnt_info[name] = target_cnt + if len(datasets) > 1: + from datasets import concatenate_datasets + + dataset_final = concatenate_datasets(datasets) + dataset_final = dataset_final.shuffle(seed=seed) + logger.info(dataset_cnt_info) + else: + dataset_final = datasets[0] + + # dataset_final = datasets[0] + + @torch.no_grad() + def collate_batch(batch): + input_ids_new = [] + attention_mask_new = [] + for text in batch: + input_ids, attention_mask = text["input_ids"], text["attention_mask"] + if isinstance(input_ids, list): + input_ids = torch.tensor(input_ids) + if isinstance(attention_mask, list): + attention_mask = torch.tensor(attention_mask) + input_ids = input_ids[:seqlen] + input_ids_list = input_ids.tolist() + if 
input_ids_list.count(input_ids_list[-1]) > seqlen // 2: + continue + attention_mask = attention_mask[:seqlen] + attention_mask_new.append(attention_mask) + input_ids_new.append(input_ids) + if len(input_ids_new) == 0: + return None + input_ids_new = torch.vstack(input_ids_new) + attention_mask_new = torch.vstack(attention_mask_new) + res = {"input_ids": input_ids_new, "attention_mask": attention_mask_new} + return res + + calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch) + return calib_dataloader diff --git a/auto_round_diff/collect_calibrations.py b/auto_round_diff/collect_calibrations.py new file mode 100644 index 00000000..6d38555e --- /dev/null +++ b/auto_round_diff/collect_calibrations.py @@ -0,0 +1,467 @@ +## To collect input data, remember to uncomment line 987-988 in ldm/models/diffusion/ddpm.py and comment them after finish collecting. +import argparse, os, datetime, gc, yaml +import logging +import cv2 +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from imwatermark import WatermarkEncoder +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +import torch +import torch.nn as nn +from torch import autocast +from contextlib import nullcontext + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from qdiff import ( + QuantModel, QuantModule, BaseQuantBlock, + block_reconstruction, layer_reconstruction, +) +from qdiff.adaptive_rounding import AdaRoundQuantizer +from qdiff.quant_layer import UniformAffineQuantizer +from qdiff.utils import resume_cali_model, get_train_samples +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor + +os.environ['CUDA_VISIBLE_DEVICES']='1' + +logger = logging.getLogger(__name__) + +# load safety model +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) +safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] 
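+ # scale float images from [0, 1] to uint8 before converting to PIL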
+ images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def load_model_from_config(config, ckpt, verbose=False): + logging.info(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + logging.info(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + logging.info("missing keys:") + logging.info(m) + if len(u) > 0 and verbose: + logging.info("unexpected keys:") + logging.info(u) + + model.cuda() + model.eval() + return model + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def load_replacement(x): + try: + hwc = x.shape + y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) + y = (np.array(y)/255.0).astype(x.dtype) + assert y.shape == x.shape + return y + except Exception: + return x + + +def check_safety(x_image): + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + assert x_checked_image.shape[0] == len(has_nsfw_concept) + for i in range(len(has_nsfw_concept)): + if has_nsfw_concept[i]: + x_checked_image[i] = load_replacement(x_checked_image[i]) + return x_checked_image, has_nsfw_concept + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save individual samples. For speed measurements.", + ) + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--laion400m", + action='store_true', + help="uses the LAION400M model", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor", + ) + parser.add_argument( + "--n_samples", + type=int, + default=4, + help="how many samples to produce for each given prompt. A.k.a. 
batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v1-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="models/ldm/stable-diffusion-v1/model.ckpt", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + # linear quantization configs + parser.add_argument( + "--ptq", action="store_true", help="apply post-training quantization" + ) + parser.add_argument( + "--quant_act", action="store_true", + help="if to quantize activations when ptq==True" + ) + parser.add_argument( + "--weight_bit", + type=int, + default=8, + help="int bit for weight quantization", + ) + parser.add_argument( + "--act_bit", + type=int, + default=8, + help="int bit for activation quantization", + ) + parser.add_argument( + "--quant_mode", type=str, default="symmetric", + choices=["linear", "squant", "qdiff"], + help="quantization mode to use" + ) + + # qdiff specific configs + parser.add_argument( + "--cali_st", type=int, default=1, + help="number of timesteps used for calibration" + ) + parser.add_argument( + "--cali_batch_size", type=int, default=32, + help="batch size for qdiff reconstruction" + ) + parser.add_argument( + "--cali_n", type=int, default=1024, + help="number of samples for each timestep for qdiff reconstruction" + ) + parser.add_argument( + "--cali_iters", type=int, default=20000, + help="number of iterations for each qdiff reconstruction" + ) + parser.add_argument('--cali_iters_a', default=5000, type=int, + help='number of iteration for LSQ') + parser.add_argument('--cali_lr', default=4e-4, type=float, + help='learning rate for LSQ') + parser.add_argument('--cali_p', default=2.4, type=float, + help='L_p norm minimization for LSQ') + parser.add_argument( + "--cali_ckpt", type=str, + help="path for calibrated model ckpt" + ) + parser.add_argument( + "--cali_data_path", type=str, default="sd_coco_sample1024_allst.pt", + help="calibration dataset name" + ) + parser.add_argument( + "--resume", action="store_true", + help="resume the calibrated qdiff model" + ) + parser.add_argument( + "--resume_w", action="store_true", + help="resume the calibrated qdiff model weights only" + ) + parser.add_argument( + "--cond", action="store_true", + help="whether to use conditional guidance" + ) + parser.add_argument( + "--no_grad_ckpt", action="store_true", + help="disable gradient checkpointing" + ) + parser.add_argument( + "--split", action="store_true", + help="use split strategy in skip connection" + ) + parser.add_argument( + "--running_stat", action="store_true", + help="use running statistics for act quantizers" + ) + parser.add_argument( + "--rs_sm_only", action="store_true", + help="use running statistics only for softmax act quantizers" + ) + parser.add_argument( + "--sm_abit",type=int, default=8, + help="attn softmax activation bit" + ) + parser.add_argument( + 
"--verbose", action="store_true", + help="print out info like quantized model arch" + ) + opt = parser.parse_args() + + if opt.laion400m: + print("Falling back to LAION 400M model...") + opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" + opt.ckpt = "models/ldm/text2img-large/model.ckpt" + opt.outdir = "outputs/txt2img-samples-laion400m" + + seed_everything(opt.seed) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = os.path.join(opt.outdir, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) + os.makedirs(outpath) + + log_path = os.path.join(outpath, "run.log") + logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO, + handlers=[ + logging.FileHandler(log_path), + logging.StreamHandler() + ] + ) + logger = logging.getLogger(__name__) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + logging.info("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "StableDiffusionV1" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + logging.info(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision=="autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image = x_samples_ddim + # x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + if not opt.skip_save: + for x_sample in x_checked_image_torch: + x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + + if not opt.skip_grid: + all_samples.append(x_checked_image_torch) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + toc = time.time() + + import ldm.globalvar as globalvar + input_list = globalvar.getInputList() + torch.save(input_list, "/home/qianzeng/q-diffusion/reproduce/sd-v1-4/w4a8/cali_data/parti_prompts/imagenet_input_{}steps.pth".format(opt.ddim_steps)) + + logging.info(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/auto_round_diff/data_type/__init__.py b/auto_round_diff/data_type/__init__.py new file mode 100644 index 00000000..d6e55ac6 --- /dev/null +++ b/auto_round_diff/data_type/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import auto_round_diff.data_type.int +import auto_round_diff.data_type.mxfp +import auto_round_diff.data_type.fp8 +from auto_round_diff.data_type.register import QUANT_FUNC_WITH_DTYPE, QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE +import auto_round_diff.data_type.w4fp8 +from auto_round_diff.data_type.utils import get_quant_func, get_static_quant_func +import auto_round_diff.data_type.nvfp + diff --git a/auto_round_diff/data_type/fp8.py b/auto_round_diff/data_type/fp8.py new file mode 100644 index 00000000..7f7bb93a --- /dev/null +++ b/auto_round_diff/data_type/fp8.py @@ -0,0 +1,192 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from auto_round_diff.data_type.utils import get_gaudi_fp8_ste_func, float8_e4m3fn_ste +from auto_round_diff.data_type.register import register_dtype, register_dtype_static + +@register_dtype("fp8_dynamic_per_token_sym") +def fp8_dynamic_per_token_sym(tensor, max_scale=1.0, **kwargs): + """Dynamic per-token symmetric quantization using float8. 
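+
+ Concretely (see the body below): the per-token scale is
+ s = max(|t|) * max_scale / float8_e4m3_max, clamped to a small positive floor,
+ and the returned tensor is the float8 rounding of clip(t / s) multiplied back by s.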
+ + This function dynamically calculates a per-token scaling factor for each group of tokens + and applies symmetric quantization using float8 format. + + Args: + tensor (torch.Tensor): Input tensor to quantize. + max_scale (float, optional): Maximum scaling factor. Defaults to 1.0. + **kwargs: Additional arguments for compatibility. + + Returns: + tuple: + - Quantized and dequantized tensor (torch.Tensor). + - Scale tensor used for quantization (torch.Tensor). + - Placeholder for zp (None). + """ + orig_shape = tensor.shape + info = torch.finfo(torch.float8_e4m3fn) + orig_dtype = tensor.dtype + + tensor = tensor.reshape(-1, orig_shape[-1]) + max_tensor = torch.max(torch.abs(tensor), dim=-1)[ + 0] * max_scale + + scale = max_tensor.to(torch.float32) / info.max + min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm + scale = torch.clip(scale, min=min_scaling_factor) + if tensor.dtype == torch.float16: ## Avoid NaN gradients with float16 + tensor = tensor.to(torch.bfloat16) + scale = scale.unsqueeze(dim=-1) + fp8_res = (tensor / scale) + fp8_res = torch.clip(fp8_res, info.min, info.max) + fp8_res = float8_e4m3fn_ste(fp8_res) + qdq_res = fp8_res * scale + qdq_res = qdq_res.to(orig_dtype).reshape(orig_shape) + return qdq_res, scale, None + + +@register_dtype("fp8_sym") +def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, **kwargs): + """Symmetric quantization using float8 format. + + Allows both dynamic per-token scaling and tensor-wide quantization depending on input. + + Args: + tensor (torch.Tensor): Input tensor to quantize. + max_scale (float, optional): Maximum scaling factor. Defaults to 1.0. + tensor_max (float, optional): Maximum tensor value for precomputed scale. Defaults to None. + **kwargs: Additional arguments for compatibility. + + Returns: + tuple: + - Quantized and dequantized tensor (torch.Tensor). + - Scale tensor used for quantization (torch.Tensor). + - Placeholder for zp (None). + """ + orig_shape = tensor.shape + info = torch.finfo(torch.float8_e4m3fn) + orig_dtype = tensor.dtype + + if tensor_max is None: ##dynamic per-token + tensor = tensor.reshape(-1, orig_shape[-1]) + max_tensor = torch.max(torch.abs(tensor), dim=-1)[ + 0] * max_scale + elif isinstance(tensor_max,torch.Tensor): + max_tensor = tensor_max.clone().detach().to(tensor.device) * max_scale + else: + max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale + scale = max_tensor.to(torch.float32) / info.max + min_scaling_factor = float(1.0 / (info.max * 512.0)) ##copy from vllm + scale = torch.clip(scale, min=min_scaling_factor) + if tensor.dtype == torch.float16: ## Avoid NaN gradients with float16 + tensor = tensor.to(torch.bfloat16) + scale = scale.unsqueeze(dim=-1) + fp8_res = (tensor / scale) + fp8_res = torch.clip(fp8_res, info.min, info.max) + fp8_res = float8_e4m3fn_ste(fp8_res) + qdq_res = fp8_res * scale + qdq_res = qdq_res.to(orig_dtype).reshape(orig_shape) + return qdq_res, scale, None + +@register_dtype_static("fp8_sym") +def quant_fp8_sym_static(tensor, bits=4, inited=False, quant_granularity='channel_wise', scale=None, zp=None, + leaf_param=False, running_stat=None, **kwargs): + """Symmetric quantization using float8 format. + + Allows both dynamic per-token scaling and tensor-wide quantization depending on input. + + Args: + tensor (torch.Tensor): Input tensor to quantize. + max_scale (float, optional): Maximum scaling factor. Defaults to 1.0. + tensor_max (float, optional): Maximum tensor value for precomputed scale. Defaults to None. 
+ **kwargs: Additional arguments for compatibility. + + Returns: + tuple: + - Quantized and dequantized tensor (torch.Tensor). + - Scale tensor used for quantization (torch.Tensor). + - Placeholder for zp (None). + """ + orig_shape = tensor.shape + info = torch.finfo(torch.float8_e4m3fn) + orig_dtype = tensor.dtype + + if quant_granularity == 'tensor_wise': + tensor = tensor.reshape(1, -1) + elif quant_granularity == 'channel_wise': + tensor = tensor.reshape(orig_shape[0], -1) + if not inited: + tensor_max = tensor.max(-1)[0] + tensor_min = tensor.min(-1)[0] + max_tensor = torch.max(tensor_min.abs(), tensor_max.abs()) + scale = max_tensor.to(torch.float32) / info.max + min_scaling_factor = float(1.0 / (info.max * 512.0)) + scale = torch.clip(scale, min=min_scaling_factor) + scale = scale.unsqueeze(dim=-1) + inited = True + + if tensor.dtype == torch.float16: ## Avoid NaN gradients with float16 + tensor = tensor.to(torch.bfloat16) + + fp8_res = (tensor / scale) + fp8_res = torch.clip(fp8_res, info.min, info.max) + fp8_res = float8_e4m3fn_ste(fp8_res) + qdq_res = fp8_res * scale + qdq_res = qdq_res.to(orig_dtype).reshape(orig_shape) + return qdq_res, scale, zp, inited + + +@register_dtype("fp8_gaudi3_sym") +def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs): + """Symmetric quantization using float8 format. + + Allows both dynamic per-token scaling and tensor-wide quantization depending on input. + + Args: + tensor (torch.Tensor): Input tensor to quantize. + max_scale (float, optional): Maximum scaling factor. Defaults to 1.0. + tensor_max (float, optional): Maximum tensor value for precomputed scale. Defaults to None. + **kwargs: Additional arguments for compatibility. + + Returns: + tuple: + - Quantized and dequantized tensor (torch.Tensor). + - Scale tensor used for quantization (torch.Tensor). + - Placeholder for zp (None). + """ + orig_shape = tensor.shape + fp8_max = torch.finfo(torch.float8_e4m3fn).max + orig_dtype = tensor.dtype + + if tensor_max is None: ##dynamic per-te + tensor = tensor.reshape(-1, orig_shape[-1]) + max_tensor = torch.max(torch.abs(tensor), dim=-1)[ + 0] * max_scale + elif isinstance(tensor_max, torch.Tensor): + max_tensor = tensor_max.clone().detach().to(tensor.device) * max_scale + else: + max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale + scale = max_tensor.to(torch.float32) / fp8_max + min_scaling_factor = float(1.0 / (fp8_max * 512.0)) ##copy from vllm + scale = torch.clip(scale, min=min_scaling_factor) + if tensor.dtype == torch.float16: ## Avoid NaN gradients with float16 + tensor = tensor.to(torch.bfloat16) + scale = scale.unsqueeze(dim=-1) + fp8_res = (tensor / scale) + fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max) + float8_e4m3fn_ste_gaudi = get_gaudi_fp8_ste_func() + fp8_res = float8_e4m3fn_ste_gaudi(fp8_res) + qdq_res = fp8_res * scale + qdq_res = qdq_res.to(orig_dtype).reshape(orig_shape) + return qdq_res, scale, None diff --git a/auto_round_diff/data_type/int.py b/auto_round_diff/data_type/int.py new file mode 100644 index 00000000..d1040ad9 --- /dev/null +++ b/auto_round_diff/data_type/int.py @@ -0,0 +1,438 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import logging +from .utils import round_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad +from auto_round_diff.data_type.register import register_dtype, register_dtype_static +logger = logging.getLogger("autoround") + +def lp_loss(pred, tgt, p=2.0): + """ + loss function measured in L_p Norm + """ + return (pred-tgt).abs().pow(p).mean(1).unsqueeze(dim=-1) + +def quantize(tensor: torch.Tensor, bits: int, sym: bool, tensor_min: torch.Tensor, tensor_max: torch.Tensor, q_scale_thresh: float=1e-5, always_zero: bool=False, qdq: bool=False): + qdq_result = None + if sym: + maxq = 2**(bits - 1) + wmin_abs = tensor_min.abs() # pylint: disable=E1130 + wmax_abs = tensor_max.abs() + max_v = (2 * (wmax_abs < wmin_abs).int() - 1) * torch.max(wmax_abs, wmin_abs) + scale = max_v / maxq + scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) + zp = torch.full_like(scale, maxq) + if qdq: + int_w = round_ste(tensor / scale) + q = torch.clamp(int_w + zp, 0, 2 ** bits - 1) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + else: + maxq = 2**bits - 1 + scale = ((tensor_max - tensor_min) / maxq) + scale = torch.clamp(scale, min=q_scale_thresh) + zp = round_ste(-tensor_min / scale) if not always_zero else torch.zeros_like(scale) # pylint: disable=E1130 + if qdq: + int_w = round_ste(tensor / scale) + q = torch.clamp(int_w + zp, 0, maxq) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + return scale, zp, qdq_result + +def search_quant_params(tensor: torch.Tensor, sym: bool, bits: int, scale_method: str, leaf_param: bool, always_zero: bool, q_scale_thresh: float=1e-5): + scale, zero_point = None, None + + if leaf_param: + pass + # self.x_min = x.data.min() + # self.x_max = x.data.max() + + if 'max' in scale_method: + tensor_min = torch.clamp(tensor.min(-1)[0], max=0).unsqueeze(dim=-1) + tensor_max = torch.clamp(tensor.max(-1)[0], min=0).unsqueeze(dim=-1) + if 'scale' in scale_method: + tensor_min = tensor_min * (bits + 2) / 8 + tensor_max = tensor_max * (bits + 2) / 8 + + scale, zero_point, _ = quantize(tensor, bits, sym, tensor_min, tensor_max, q_scale_thresh, always_zero, qdq=False) + + elif scale_method == 'mse': + tensor_min = tensor.min(-1)[0].unsqueeze(dim=-1) + tensor_max = tensor.max(-1)[0].unsqueeze(dim=-1) + best_score = torch.full_like(tensor_max, 1e+10) + scale, zero_point = torch.zeros_like(tensor_max), torch.zeros_like(tensor_max) + for i in range(80): + new_min = tensor_min * (1.0 - (i * 0.01)) + new_max = tensor_max * (1.0 - (i * 0.01)) + scale_, zero_point_, tensor_q = quantize(tensor, bits, sym, new_min, new_max, always_zero, qdq=True) + # L_p norm minimization as described in LAPQ + # https://arxiv.org/abs/1911.07190 + score = lp_loss(tensor, tensor_q, p=2.4) + mask = score < best_score + best_score[mask] = score[mask] + scale[mask] = scale_[mask] + zero_point[mask] = zero_point_[mask] + else: + raise NotImplementedError + + return scale, zero_point + + + +@register_dtype("int_sym") +def quant_tensor_sym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, + 
tensor_min=None, + tensor_max=None, q_scale_thresh=1e-5, **kwargs): + """Quantize and de-quantize tensor asymmetrically. full range, credict goes to llamacpp community + + Args: + tensor: Tensor containing the tensor to be quantized + bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization + v: Rounding value perturbation + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. + scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import + q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability + + Returns: + Quantized and de-quantized tensor, scale, zero-point + """ + + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + maxq = 2 ** (bits - 1) + if tensor_min is None or tensor_max is None: + wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) + wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) + else: + wmin_tmp = tensor_min + wmax_tmp = tensor_max + + wmin_abs = -(wmin_tmp * min_scale) # pylint: disable=E1130 + wmax_abs = wmax_tmp * max_scale + max_v = (2 * (wmax_abs < wmin_abs).int() - 1) * torch.max(wmax_abs, wmin_abs) + scale = (max_v / maxq).to(scale_dtype) + scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) + zp = torch.full_like(scale, maxq) # pylint: disable=E1130 + scale = scale.unsqueeze(dim=-1) + zp = zp.unsqueeze(dim=-1) + int_w = round_ste(tensor / scale + v) + q = torch.clamp(int_w + zp, 0, 2 ** bits - 1) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) + return qdq_result, scale, zp + +@register_dtype_static("int_sym") +def quant_tensor_sym_static(tensor, bits=4, inited=False, quant_granularity='channel_wise', group_size=-1, scale_method='max', + scale=None, zp=None, leaf_param=False, always_zero=False, running_stat=None, q_scale_thresh=1e-5, **kwargs): + """Quantize and de-quantize tensor asymmetrically. full range, credict goes to llamacpp community + + Args: + tensor: Tensor containing the tensor to be quantized + bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization + v: Rounding value perturbation + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. 
+ scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import + q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability + + Returns: + Quantized and de-quantized tensor, scale, zero-point + """ + orig_shape = tensor.shape + if quant_granularity == 'tensor_wise': + tensor = tensor.reshape(1, -1) + elif quant_granularity == 'channel_wise': + tensor = tensor.reshape(orig_shape[0], -1) + elif quant_granularity == 'group_wise': + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + + if not inited: + # init scale and zero point + scale, zp = search_quant_params(tensor, bits=bits, scale_method=scale_method, leaf_param=leaf_param, always_zero=always_zero, sym=True) + if leaf_param: + scale = torch.nn.Parameter(scale) + inited = True + + int_w = round_ste(tensor / scale) + q = torch.clamp(int_w + zp, 0, 2 ** bits - 1) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + + if quant_granularity in ('tensor_wise', 'channel_wise'): + qdq_result = qdq_result.reshape(orig_shape) + else: + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) + + return qdq_result, scale, zp, inited + +## the values should be positive +def double_quant_tensor(tensor, bits, q_scale_thresh): + maxq = 2 ** bits - 1 + wmax = torch.clamp(tensor.max(-1)[0], min=0) + scale = torch.clamp(wmax / maxq, q_scale_thresh) + scale = scale.view(-1, 1) + qdq_tensor = torch.clamp(round_ste(tensor / scale), max=maxq) * scale + return qdq_tensor, scale + + +@register_dtype("int_asym_dq") +def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, + tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, super_group_size=8, super_bits=6, + **kwargs): + """Quantize and de-quantize tensor asymmetrically. + + Args: + tensor: Tensor containing the tensor to be quantized + bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization + v: Rounding value perturbation + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. 
+ scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import + q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability + + Returns: + Quantized and de-quantized tensor, scale, zero-point + """ + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + + maxq = 2 ** bits - 1 + if tensor_min is None or tensor_max is None: + wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) + wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) + else: + wmin_tmp = tensor_min + wmax_tmp = tensor_max + if isinstance(min_scale, torch.Tensor): + wmin = wmin_tmp * min_scale + wmax = wmax_tmp * max_scale + else: + wmin = wmin_tmp + wmax = wmax_tmp + scale = ((wmax - wmin) / maxq).to(scale_dtype) + scale = torch.clamp(scale, min=q_scale_thresh) + scale = scale.view(-1, super_group_size) + wmin_m = -wmin # pylint: disable=E1130 + wmin_m = wmin_m.view(-1, super_group_size) + + ##conduct double quant + scale, d_scale = double_quant_tensor(scale, super_bits, q_scale_thresh) + wmin_m, d_wmin_m = double_quant_tensor(wmin_m, super_bits, q_scale_thresh) + + scale = scale.view(-1, 1) + scale = torch.clamp(scale, q_scale_thresh) + wmin_m = wmin_m.view(-1, 1) + + int_w = round_ste((tensor + wmin_m) / scale + v) + q = torch.clamp(int_w, 0, maxq) + qdq_result = (scale * q - wmin_m).to(tensor.dtype) + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) + # zp = round_ste(wmin_m / scale) # remove this later + return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin_m": wmin_m, "d_wmin_m": d_wmin_m} + + +@register_dtype_static("int_asym") +def quant_tensor_asym_static(tensor, bits=4, inited=False, quant_granularity='channel_wise', group_size=-1, scale_method='max', + scale=None, zp=None, leaf_param=False, always_zero=False, running_stat=None, q_scale_thresh=1e-5, **kwargs): + """Quantize and de-quantize tensor asymmetrically. 
+ + Args: + tensor: Tensor containing the tensor to be quantized + bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization + inited: scale and zero_point is inited or not + quant_granularity: granularity of quantization (channel_wise or group_wise) + scale_method: method for searching initial scale and zero_point + running_stat: using momentum update for activation quantization or not + q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability + + Returns: + Quantized and de-quantized tensor, scale, zero-point + """ + orig_shape = tensor.shape + if quant_granularity == 'tensor_wise': + tensor = tensor.reshape(1, -1) + elif quant_granularity == 'channel_wise': + tensor = tensor.reshape(orig_shape[0], -1) + elif quant_granularity == 'group_wise': + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + + if not inited: + # init scale and zero point + scale, zp = search_quant_params(tensor, bits=bits, scale_method=scale_method, leaf_param=leaf_param, always_zero=always_zero, sym=False) + if leaf_param: + scale = torch.nn.Parameter(scale) + inited = True + + int_w = round_ste(tensor / scale) + q = torch.clamp(int_w + zp, 0, 2 ** bits - 1) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + + if quant_granularity in ('tensor_wise', 'channel_wise'): + qdq_result = qdq_result.reshape(orig_shape) + else: + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) + + return qdq_result, scale, zp, inited + +@register_dtype("int_asym") +def quant_tensor_asym(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, + tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, **kwargs): + """Quantize and de-quantize tensor asymmetrically. + + Args: + tensor: Tensor containing the tensor to be quantized + bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization + v: Rounding value perturbation + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. 
+ scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import + q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability + + Returns: + Quantized and de-quantized tensor, scale, zero-point + """ + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + maxq = 2 ** bits - 1 + if tensor_min is None or tensor_max is None: + wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) + wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) + else: + wmin_tmp = tensor_min + wmax_tmp = tensor_max + if isinstance(min_scale, torch.Tensor): + wmin = wmin_tmp * min_scale + wmax = wmax_tmp * max_scale + else: + wmin = wmin_tmp + wmax = wmax_tmp + scale = ((wmax - wmin) / maxq).to(scale_dtype) + scale = torch.clamp(scale, min=q_scale_thresh) + zp = round_ste(-wmin / scale) # pylint: disable=E1130 + scale = scale.unsqueeze(dim=-1) + zp = zp.unsqueeze(dim=-1) + int_w = round_ste(tensor / scale + v) + q = torch.clamp(int_w + zp, 0, maxq) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) + return qdq_result, scale, zp + +@register_dtype("int_sym_gptq") +def quant_tensor_sym_gptq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, + tensor_min=None, + tensor_max=None, q_scale_thresh=1e-5, **kwargs): + """Quantize and de-quantize tensor asymmetrically. + + Args: + tensor: Tensor containing the tensor to be quantized + bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization + v: Rounding value perturbation + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. 
+ scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import + q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability + + Returns: + Quantized and de-quantized tensor, scale, zero-point + """ + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + maxq = 2 ** bits - 1 + if tensor_min is None or tensor_max is None: + wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) + wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) + else: + wmin_tmp = tensor_min + wmax_tmp = tensor_max + if isinstance(min_scale, torch.Tensor): + wmin = wmin_tmp * min_scale + wmax = wmax_tmp * max_scale + else: + wmin = wmin_tmp + wmax = wmax_tmp + + wmax_new = torch.max(wmin.abs(), wmax) + tmp = wmin < 0 + wmin_new = wmin.clone() ##must clone, otherwise inplace backward will occur + if torch.any(tmp): + wmin_new[tmp] = -wmax_new[tmp] + + scale = ((wmax_new - wmin_new) / maxq).to(scale_dtype) + scale = torch.clamp(scale, min=q_scale_thresh) + scale = scale.unsqueeze(dim=-1) + zp = torch.full_like(scale, (maxq + 1) / 2) + + int_w = round_ste(tensor / scale + v) + q = torch.clamp(int_w + zp, 0, maxq) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) + return qdq_result, scale, zp + + +def quant_tensor_asym_wo_round(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, + scale_dtype=torch.float16, + tensor_min=None, tensor_max=None, q_scale_thresh=1e-5, **kwargs): + """Quantize and de-quantize tensor asymmetrically without rounding, this is mainly for tuning bias, norm. + + Args: + tensor: Tensor containing the tensor to be quantized + bits: Number of bits for quantization (e.g., 2, 3, 4, 8) + group_size: Number of elements to share scale for quantization + v: Rounding value perturbation + min_scale: Minimum scale coefficient for tensor + max_scale: Maximum scale coefficient for tensor + tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. + tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. 
+ scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import + q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability + + Returns: + Quantized and de-quantize tensor, scale, zero-point + """ + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + maxq = 2 ** bits - 1 + if tensor_min is None or tensor_max is None: + wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) + wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) + else: + wmin_tmp = tensor_min + wmax_tmp = tensor_max + if isinstance(min_scale, torch.Tensor): + wmin = wmin_tmp * min_scale + wmax = wmax_tmp * max_scale + else: + wmin = wmin_tmp + wmax = wmax_tmp + + scale = ((wmax - wmin) / maxq).to(scale_dtype) + scale = torch.clamp(scale, min=q_scale_thresh) + zp = -wmin / scale # pylint: disable=E1130 + scale = scale.unsqueeze(dim=-1) + zp = zp.unsqueeze(dim=-1) + int_w = tensor / scale + v + q = torch.clamp(int_w + zp, 0, maxq) + qdq_result = (scale * (q - zp)).to(tensor.dtype) + qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) + return qdq_result, scale, zp diff --git a/auto_round_diff/data_type/mxfp.py b/auto_round_diff/data_type/mxfp.py new file mode 100644 index 00000000..2c51f16e --- /dev/null +++ b/auto_round_diff/data_type/mxfp.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
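+
+# MX ("microscaling") formats quantize each group of values with a shared
+# power-of-two exponent plus a low-bit element format. MXFP_FORMAT_CACHE below
+# maps each format name to (ebits, mbits, emax, max_norm, min_norm), which
+# quant_mx/quant_element use to round the mantissa and clamp to max_norm.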
+ +import torch +from auto_round.data_type.utils import floor_ste, round_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad +from auto_round.data_type.register import register_dtype, QUANT_FUNC_WITH_DTYPE + +MXFP_FORMAT_CACHE = { + # data type: ebits, mbits, emax, max_norm, min_norm + "mx_int8": (0, 8, 0, 1.984375, 0), + "mx_int4": (0, 4, 0, 1.75, 0), + "mx_int2": (0, 2, 0, 1.0, 0), + "mx_fp8e5m2": (5, 4, 15, 57344.0, 6.103515625e-05), + "mx_fp8": (4, 5, 8, 448.0, 0.015625), + "mx_fp8e4m3": (4, 5, 8, 448.0, 0.015625), + "mx_fp6e3m2": (3, 4, 4, 28.0, 0.25), + "mx_fp6": (2, 5, 2, 7.5, 1.0), + "mx_fp6e2m3": (2, 5, 2, 7.5, 1.0), + "mx_fp4": (2, 3, 2, 6.0, 1.0), + "mx_fp4e2m1": (2, 3, 2, 6.0, 1.0), + "mx_float16": (5, 12, 15, 65504.0, 6.103515625e-05), + "mx_fp16": (5, 12, 15, 65504.0, 6.103515625e-05), + "mx_bfloat16": (8, 9, 127, 3.3895313892515355e+38, 1.1754943508222875e-38), + "mx_bf16": (8, 9, 127, 3.3895313892515355e+38, 1.1754943508222875e-38), +} + +FP32_EXPONENT_BIAS = 127 +FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1) + + +def quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding="even"): + if ebits != 0: + private_exp = floor_ste(torch.log2(torch.abs(tensor) + (tensor == 0).type(tensor.dtype))) + + # The minimum representable exponent for 8 exp bits is -126 + min_exp = -(2 ** (ebits - 1)) + 2 + private_exp = private_exp.clip(min=min_exp) + else: + private_exp = None + + # Scale up so appropriate number of mbits are in the integer portion of the number + tensor = tensor * (2 ** (mbits - 2)) if private_exp is None else tensor / (2 ** private_exp) * (2 ** (mbits - 2)) + + if mantissa_rounding == "even": + abs_tensor = torch.abs(tensor) + mask_tensor = ((abs_tensor - 0.5) % 2 == torch.zeros_like(abs_tensor)).type(tensor.dtype) + tensor = torch.sign(tensor) * (floor_ste(abs_tensor + 0.5) - mask_tensor) + elif mantissa_rounding == "nearest": + tensor = round_ste(tensor) + elif mantissa_rounding == "floor": + tensor = floor_ste(tensor) + else: + raise ValueError("mantissa_rounding only supports even, nearest or floor.") + + ##the clamp is False in official code + # max_mantissa = 2 ** (mbits - 1) - 1 ## this is incorrect + # tensor = torch.clamp(tensor, -max_mantissa, max_mantissa) + + # Undo scaling + tensor = tensor / (2 ** (mbits - 2)) if private_exp is None else tensor / (2 ** (mbits - 2)) * (2 ** private_exp) + + tensor = torch.clamp(tensor, min=-max_norm, max=max_norm) + return tensor + + +def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, + mantissa_rounding="even", data_type="mx_fp", **kwargs): + """Quantize the given tensor using the specified parameters. + + This function performs quantization on the `tensor` tensor according to the + given bit width (`bits`), data type (`data_type`), and additional parameters. + The quantization process involves scaling the tensor values and adjusting + the exponent and mantissa to fit within the specified format. + + Args: + tensor (torch.Tensor): The tensor containing the tensors to be quantized. + bits (int): The bit width to be used for quantization. + group_size (int): The group size of sharing scale and exponent. + data_type (str): The data type for quantization (e.g., 'mx_fp4'). + v (float): A value used for adjusting the tensors. + max_scale (float or torch.Tensor): The maximum scale to be applied to the tensors. 
+ mantissa_rounding (str): rounding method for mantissa,currently support even,nearest,floor + + Returns: + tuple: A tuple containing the quantized tensors, shared exponent, and None (reserved for future use). + + Raises: + KeyError: If `data_type` is not found in `MXFP_FORMAT_CACHE`. + """ + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + ebits, mbits, emax, max_norm, min_norm = MXFP_FORMAT_CACHE[data_type] + orig_dtype = tensor.dtype + shared_exp, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True) + if isinstance(max_scale, torch.Tensor): + shared_exp *= (max_scale.unsqueeze(dim=-1)).to(tensor.device) + else: + shared_exp *= max_scale + + shared_exp = torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype)) + shared_exp = floor_ste(shared_exp) + # if flush_fp32_subnorms: + # tensor = tensor * (shared_exp > -FP32_EXPONENT_BIAS).type(tensor.dtype) + + scale_emax = 2 ** (8 - 1) - 1 + shared_exp[shared_exp == torch.inf] = scale_emax + emax + shared_exp[shared_exp == -torch.inf] = -scale_emax + emax + shared_exp = (shared_exp - emax) + + shared_exp[shared_exp > scale_emax] = scale_emax ##changed Nan + shared_exp[shared_exp < -scale_emax] = -scale_emax + if (shared_exp.dtype == torch.float16 and (torch.any(shared_exp > 15) or torch.any(shared_exp < -24))) or ( + shared_exp.dtype == torch.bfloat16 and torch.any((shared_exp < -126))): + tensor = tensor.to(torch.float32) + shared_exp = shared_exp.to(torch.float32) + tensor = tensor / (2 ** shared_exp) + tensor = tensor + v + tensor = quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding) + + tensor = tensor * (2 ** shared_exp) + tensor = revert_tensor_by_pad(tensor, orig_shape=orig_shape, pad_len=pad_len) + return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None + + +for key in MXFP_FORMAT_CACHE.keys(): + QUANT_FUNC_WITH_DTYPE[key] = quant_mx + +if __name__ == "__main__": + data = torch.tensor([0.0, 0.25, 0.4,0.75, 1.25,1.4, 1.75, 2.5, 2.9,3.5, 5.0, 5.1]) + data1 = quant_element(data, 2, 3, 6.0) + gt = torch.tensor([0.0,0.0,0.5,1.0,1.0,1.5,2.0,2.0,3.0,4.0,4.0,6.0]) + assert(torch.sum(torch.abs(data1-gt))<1e-6) + + data_neg = data*-1 + data2 = quant_element(data_neg, 2, 3, 6.0) + assert(torch.sum(torch.abs(data2-gt*-1))<1e-6) + + diff --git a/auto_round_diff/data_type/nvfp.py b/auto_round_diff/data_type/nvfp.py new file mode 100644 index 00000000..a60541c9 --- /dev/null +++ b/auto_round_diff/data_type/nvfp.py @@ -0,0 +1,172 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
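+
+# NVFP4 helpers: elements are cast to the FP4 E2M1 grid
+# (0, 0.5, 1, 1.5, 2, 3, 4, 6 with sign), each block of `group_size` values
+# carries an FP8 E4M3 scale, and a global FP32 scale normalizes the whole
+# tensor (see ref_nvfp4_quant / full_quant below).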
+ +import torch + +from auto_round.data_type.fp8 import float8_e4m3fn_ste +from auto_round.data_type.register import register_dtype +from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad + + +# taken from +# https://github.com/vllm-project/vllm/blob/ebb554cdb7cd9cc54b2feec20c45ab9cd9067d52/tests/kernels/test_nvfp4_quant.py +def cast_to_fp4(x): + sign = torch.sign(x) + x = torch.abs(x) + x[(x >= 0.0) & (x <= 0.25)] = 0.0 + x[(x > 0.25) & (x < 0.75)] = 0.5 + x[(x >= 0.75) & (x <= 1.25)] = 1.0 + x[(x > 1.25) & (x < 1.75)] = 1.5 + x[(x >= 1.75) & (x <= 2.5)] = 2.0 + x[(x > 2.5) & (x < 3.5)] = 3.0 + x[(x >= 3.5) & (x <= 5.0)] = 4.0 + x[x > 5.0] = 6.0 + return x * sign + + +def cast_to_fp4_ste(x): + fp4 = (cast_to_fp4(x).to(x.dtype) - x).detach() + x + + return fp4 + + +def get_reciprocal(x): + if isinstance(x, torch.Tensor): + return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) + elif isinstance(x, (float, int)): + return 0.0 if x == 0 else 1.0 / x + else: + raise TypeError("Input must be a float, int, or a torch.Tensor.") + + +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def ref_nvfp4_quant(x, global_scale, block_size=16, v=0): + assert global_scale.dtype == torch.float32 + assert x.ndim == 2 + m, n = x.shape + vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32) + scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) + scale = float8_e4m3fn_ste(scale).to(torch.float32) + output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) + + scaled_x = x.to(torch.float32) * output_scale + v + clipped_x = torch.clamp(scaled_x, -6.0, 6.0) + return (cast_to_fp4_ste(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), output_scale + + +@register_dtype("nv_fp4") +def full_quant(tensor, bits=4, group_size=16, v=0, **kwargs): + orig_dtype = tensor.dtype + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + tensor_amax = tensor.abs().max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + qdq_res, output_scale = ref_nvfp4_quant(tensor, global_scale, group_size, v) + qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len) + return qdq_res.to(orig_dtype), output_scale, None + + +FLOAT8_UE5M3_MAX = 114688 + + +def float_to_e5m3_frexp(x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, min=0.0) + e5m3 = torch.zeros_like(x, dtype=torch.uint8) + + mask = x > 0 + x_masked = x[mask] + + # 正常数:x >= 2^-14 + normal_mask = x_masked >= 2 ** -14 + x_normal = x_masked[normal_mask] + mantissa, exponent = torch.frexp(x_normal) + + m3 = torch.clamp(torch.round((mantissa - 0.5) * 16), 0, 7).to(torch.uint8) + e5 = torch.clamp(exponent + 14, 0, 31).to(torch.uint8) # 0 reserved for subnormal, 31 reserved for NaN + + e5m3_vals = ((e5 << 3) | m3).to(torch.uint8) + + # sumnorm:0 < x < 2^-14 + subnormal_mask = ~normal_mask + x_subnormal = x_masked[subnormal_mask] + m_sub = torch.clamp(torch.round(x_subnormal / (2 ** -14) * 8), 1, 7).to(torch.uint8) # exponent = 0 + e5m3_sub = m_sub # top 5 bits = 0 + + out_vals = torch.zeros_like(x_masked, dtype=torch.uint8) + out_vals[normal_mask] = e5m3_vals + out_vals[subnormal_mask] = e5m3_sub + + e5m3[mask] = out_vals + return e5m3 + + +def e5m3_to_float_tensor(e5m3: torch.Tensor) -> torch.Tensor: + assert e5m3.dtype == torch.uint8 + + x = torch.zeros_like(e5m3, dtype=torch.float32) + mask_nonzero = e5m3 != 0 + e = ((e5m3[mask_nonzero] >> 3) & 
0x1F).to(torch.int32) + m = (e5m3[mask_nonzero] & 0x07).to(torch.int32) + + is_nan = (e == 31) & (m == 7) + is_subnormal = (e == 0) + is_normal = (e > 0) & (~is_nan) + + out = torch.zeros_like(e, dtype=torch.float32) + + # subnormal: exponent = -14, no implicit leading 1 + out[is_subnormal] = (m[is_subnormal].float() / 8.0) * (2 ** -14) + + # normal: exponent = e - 15, implicit leading 1 + mant = 1.0 + m[is_normal].float() / 8.0 + exp = e[is_normal] - 15 + out[is_normal] = torch.ldexp(mant, exp) + + out[is_nan] = float('nan') + x[mask_nonzero] = out + return x + + +def cast_to_ue5m3(tensor): + orig_dtype = tensor.dtype + encoded = float_to_e5m3_frexp(tensor) + res = e5m3_to_float_tensor(encoded) + res = res.to(orig_dtype) + return res + + +def cast_to_ue5m3_ste(x): + fp4 = (cast_to_ue5m3(x).to(x.dtype) - x).detach() + x + + return fp4 + + +if __name__ == "__main__": + test = torch.tensor( + [0.0, 2 ** (-17), (2 ** -14) * 0.875, 2 ** -14, 2 ** -13, 2 ** -6, + 1e-6, 2.7657e-05, 0.1, 1.0, 3.14, 1000.0, + 114688, + 1e10], + dtype=torch.float32) + encoded = float_to_e5m3_frexp(test) + decoded = e5m3_to_float_tensor(encoded) + decoded_bf16 = decoded.to(torch.bfloat16) + print(decoded_bf16) + + for i in range(len(test)): + print( + f"{test[i].item():.6g} -> {encoded[i].item():3d} -> {decoded[i].item():.6g} " + f"(error={abs(test[i] - decoded[i]).item():.3g})") diff --git a/auto_round_diff/data_type/register.py b/auto_round_diff/data_type/register.py new file mode 100644 index 00000000..93276490 --- /dev/null +++ b/auto_round_diff/data_type/register.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +QUANT_FUNC_WITH_DTYPE = {} + + +def register_dtype(name): + """Class decorator to register a EXPORT subclass to the registry. + + Decorator function used before a Pattern subclass. + + Args: + cls (class): The subclass of register. + name: A string. Define the export type. + + Returns: + cls: The class of register. + """ + + def register(dtype): + QUANT_FUNC_WITH_DTYPE[name] = dtype + return dtype + + return register + + +QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE = {} +def register_dtype_static(name): + def register(dtype): + QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE[name] = dtype + return dtype + + return register diff --git a/auto_round_diff/data_type/utils.py b/auto_round_diff/data_type/utils.py new file mode 100644 index 00000000..d1d7c3fc --- /dev/null +++ b/auto_round_diff/data_type/utils.py @@ -0,0 +1,259 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from auto_round_diff.data_type.register import QUANT_FUNC_WITH_DTYPE, QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE +from functools import lru_cache +import logging +logger = logging.getLogger("autoround") + +def reshape_pad_tensor_by_group_size(data: torch.Tensor, group_size: int): + """Reshapes and pads the tensor to ensure that it can be quantized in groups of `group_size`. + + This function adjusts t + he input tensor's shape so that its last dimension is a multiple + of the specified `group_size`. If padding is required, it adds padding to the tensor + to achieve this. If the tensor's last dimension is already divisible by `group_size`, + no padding is applied. + + Args: + data (torch.Tensor): The input tensor to be reshaped and padded. + group_size (int): The size of the groups that the tensor should be reshaped into. + + Returns: + torch.Tensor: The reshaped and padded tensor, if necessary. + tuple: The original shape of the input tensor. + int: The padding length applied to the tensor. Returns 0 if no padding is applied. + """ + orig_shape = data.shape + pad_len = 0 + if len(data.shape) > 2: + data = data.reshape(-1, orig_shape[-1]) + if group_size == -1 or data.shape[1] < group_size: + return data, orig_shape, pad_len + elif data.shape[1] % group_size == 0: + data = data.reshape(-1, group_size) + return data, orig_shape, pad_len + else: + pad_len = (data.shape[1] + group_size - 1) // group_size * group_size - data.shape[1] + data_new = torch.nn.functional.pad(data, (0, pad_len)) + data_new = data_new.reshape(-1, group_size) + return data_new, orig_shape, pad_len + +def reshape_tensor_by_quant_level(data: torch.Tensor, quant_granularity: str, group_size: int): + data_new, orig_shape, pad_len = None, data.shape, None + if quant_granularity == 'tensor_wise': + data_new = data.reshape(1, -1) + elif quant_granularity == 'channel_wise': + data_new = data.reshape(orig_shape[0], -1) + else: + data_new, orig_shape, pad_len = reshape_pad_tensor_by_group_size(data, group_size) + + return data_new, orig_shape, pad_len + +def revert_tensor_by_pad(data: torch.Tensor, orig_shape: tuple, pad_len: int): + """Reverts the tensor to its original shape by removing padding. + + This function removes the padding added during reshaping and returns the tensor to + its original shape. + + Args: + data (torch.Tensor): The reshaped and possibly padded tensor. + orig_shape (tuple): The original shape of the tensor before reshaping. + pad_len (int): The length of the padding to be removed. + + Returns: + torch.Tensor: The tensor restored to its original shape. + """ + if pad_len == 0: + return data.reshape(orig_shape) + else: + if len(orig_shape) > 2: + tmp_shape = torch.prod(torch.tensor(orig_shape[:-1])).item() + else: + tmp_shape = orig_shape[0] + data_new = data.reshape(tmp_shape, -1) + data_new = data_new[:, :-pad_len] + data_new = data_new.reshape(orig_shape) + return data_new + + +def get_quant_func(dtype, bits, sym): + """Retrieve the quantization function based on data type, bit width, and symmetry. + + This function returns the appropriate quantization function from the QUANT_FUNC_WITH_DTYPE + dictionary based on the provided data type (`dtype`), bit width (`bits`), and whether + the quantization is symmetric (`sym`). If the function does not exist, it asserts False. + + Args: + dtype (str): The data type for the quantization (e.g., 'int', 'mxfp4'). 
+ bits (int): The bit width for the quantization (e.g., 2,4,8). + sym (bool): A flag indicating whether the quantization is symmetric (True) or asymmetric (False). + + Returns: + function: The quantization function corresponding to the specified parameters. + """ + key = dtype + if key in QUANT_FUNC_WITH_DTYPE.keys(): + return QUANT_FUNC_WITH_DTYPE[key], key + + if sym: + key = dtype + "_sym" + else: + key = dtype + "_asym" + + if key in QUANT_FUNC_WITH_DTYPE.keys(): + return QUANT_FUNC_WITH_DTYPE[key], key + + ##need to add bits + if sym: + key = dtype + str(bits) + "_sym" + else: + key = dtype + str(bits) + "_asym" + + if key in QUANT_FUNC_WITH_DTYPE.keys(): + return QUANT_FUNC_WITH_DTYPE[key], key + + if sym: + key = dtype + "_sym" + else: + key = dtype + "_asym" + + if key in QUANT_FUNC_WITH_DTYPE.keys(): + return QUANT_FUNC_WITH_DTYPE[key], key + + if sym: + key = dtype + str(bits) + else: + key = dtype + str(bits) + + if key in QUANT_FUNC_WITH_DTYPE.keys(): + return QUANT_FUNC_WITH_DTYPE[key], key + + assert False, f"{dtype} is not supported" + +def get_static_quant_func(dtype, bits, sym): + + key = dtype + if key in QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE.keys(): + return QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE[key], key + + if sym: + key = dtype + "_sym" + else: + key = dtype + "_asym" + + if key in QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE.keys(): + return QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE[key], key + + ##need to add bits + if sym: + key = dtype + str(bits) + "_sym" + else: + key = dtype + str(bits) + "_asym" + + if key in QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE.keys(): + return QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE[key], key + + if sym: + key = dtype + "_sym" + else: + key = dtype + "_asym" + + if key in QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE.keys(): + return QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE[key], key + + if sym: + key = dtype + str(bits) + else: + key = dtype + str(bits) + + if key in QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE.keys(): + return QUANT_FUNC_WITH_DTYPE_AND_STATIC_SCALE[key], key + + assert False, f"{dtype} is not supported" + + +def round_ste(x: torch.Tensor): + """Straight-Through Estimator for rounding. + + Args: + x: torch.Tensor + + Returns: + torch.Tensor + """ + return (x.round() - x).detach() + x + + +def floor_ste(x: torch.Tensor): + """Straight-Through Estimator for floor. + + Args: + x: torch.Tensor + + Returns: + torch.Tensor + """ + return (x.floor() - x).detach() + x + +def float8_e4m3fn_ste(x: torch.Tensor): + """Straight-Through Estimator (STE) for float8. + + Applies a quantization and dequantization step with float8 precision while maintaining + gradient flow using a straight-through estimator. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Quantized and dequantized tensor using float8 format. + """ + fp8 = (x.to(x.dtype) - x).detach() + x + + return fp8 + + +def float8_e4m3fn_hpu_ste(x: torch.Tensor): + """Straight-Through Estimator (STE) for float8. + + Applies a quantization and dequantization step with float8 precision while maintaining + gradient flow using a straight-through estimator. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Quantized and dequantized tensor using float8 format. 
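+
+    Note: this variant relies on the Habana (HPU) extension op
+    ``torch.ops.hpu.cast_to_fp8_v2`` used in the body below; ``get_gaudi_fp8_ste_func``
+    returns this function only when an HPU is detected and otherwise falls back to
+    ``float8_e4m3fn_ste``.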
+ """ + fp8 = ((torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, torch.float8_e4m3fn)[0]).to(x.dtype) - x).detach() + x + + return fp8 + + +@lru_cache(None) +def get_gaudi_fp8_ste_func(): + from auto_round.utils import is_hpu_supported + + if is_hpu_supported(): + fn = float8_e4m3fn_hpu_ste + logger.warning_once("Using HPU STE for FP8") + else: + fn = float8_e4m3fn_ste + logger.warning_once("Using CUDA/CPU STE for FP8") + return fn + + + + diff --git a/auto_round_diff/data_type/w4fp8.py b/auto_round_diff/data_type/w4fp8.py new file mode 100644 index 00000000..9b512375 --- /dev/null +++ b/auto_round_diff/data_type/w4fp8.py @@ -0,0 +1,244 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from auto_round.data_type.register import register_dtype +from auto_round.data_type.utils import get_gaudi_fp8_ste_func, float8_e4m3fn_ste + + +# @register_dtype("fp8_gaudi3_to_int_sym") +# def progressive_quant_fp8_int4_gaudi3( +# tensor, +# bits=4, +# group_size=-1, +# v=0, +# min_scale=1.0, +# max_scale=1.0, +# q_scale_thresh=1e-5, +# weight_fp8_max_scale=1.0, +# **kwargs +# ): +# """Two-stage quantization: quantize tensor to fp8 by per tensor, then quantize fp8 to w4g128 +# +# This method first quantizes the input tensor into float8 format and then performs +# a secondary quantization to int4 with grouping. +# +# Args: +# tensor (torch.Tensor): Input tensor to quantize. +# bits (int, optional): Bit precision for secondary quantization. Defaults to 4. +# group_size (int, optional): Group size for int4 quantization. Defaults to -1 (no grouping). +# v (float, optional): Optional parameter for variance tuning. Defaults to 0. +# min_scale (float, optional): Minimum scaling factor for int4 quantization. Defaults to 1.0. +# max_scale (float, optional): Maximum scaling factor for int4 quantization. Defaults to 1.0. +# q_scale_thresh (float, optional): Threshold for scaling. Defaults to 1e-5. +# weight_fp8_max_scale (float, optional): Maximum scaling factor for float8 quantization. Defaults to 1.0. +# **kwargs: Additional arguments for compatibility. +# +# Returns: +# tuple: +# - Quantized and dequantized tensor (torch.Tensor). +# - Combined scaling factor (torch.Tensor). +# - Placeholder for zp (None). 
+# """ +# fp8_max = torch.finfo(torch.float8_e4m3fn).max +# tensor_max = ( +# torch.max(torch.abs(tensor)).to(torch.float32) * weight_fp8_max_scale +# ) ## better train a ratio +# scale = tensor_max.to(torch.float32) / fp8_max +# min_scaling_factor = 1.0 / (fp8_max * 512.0) ##copy from vllm +# scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor) +# fp8_res = tensor / scale_bf16_to_fp8 +# fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max) +# float8_e4m3fn_ste_gaudi = get_gaudi_fp8_ste_func() +# fp8_res = float8_e4m3fn_ste_gaudi(fp8_res) +# +# # convert to bf16 +# fp8_res_using_16bit = fp8_res.to(tensor.dtype) +# # convert to int4 +# from auto_round.data_type.int import quant_tensor_sym +# +# qdq_int4_tensor, scale_fp8_to_int4, zp_fp8_to_int4 = quant_tensor_sym( +# fp8_res_using_16bit, +# bits=bits, +# group_size=group_size, +# v=v, +# min_scale=min_scale, +# max_scale=max_scale, +# scale_dtype=torch.bfloat16, +# q_scale_thresh=q_scale_thresh, +# ) +# qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8 +# scale_bf16_to_int4 = scale_fp8_to_int4 * scale_bf16_to_fp8 +# return qdq_tensor, (scale_bf16_to_int4, scale_bf16_to_fp8), zp_fp8_to_int4 + + +# @register_dtype("fp8_gaudi3_to_int_sym_pc") +# def progressive_quant_fp8_int4_per_channel( +# tensor, +# bits=4, +# group_size=-1, +# v=0, +# min_scale=1.0, +# max_scale=1.0, +# q_scale_thresh=1e-5, +# weight_fp8_max_scale=1.0, +# **kwargs +# ): +# """The per-channel version of progressive quantization from float8 to int4.""" +# # tensor: [out_feats, in_feats] +# # scale_bf16_to_fp8: [out_feats, 1] +# out_feats, in_feats = tensor.shape +# fp8_max = torch.finfo(torch.float8_e4m3fn).max +# dim = 1 +# tensor_max = ( +# torch.max(torch.abs(tensor), dim=dim, keepdim=True)[0].to(torch.float32) +# * weight_fp8_max_scale +# ) ## better train a ratio +# scale = tensor_max.to(torch.float32) / fp8_max +# min_scaling_factor = 1.0 / (fp8_max * 512.0) ##copy from vllm +# scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor) +# fp8_res = tensor / scale_bf16_to_fp8 +# fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max) +# float8_e4m3fn_ste_gaudi = get_gaudi_fp8_ste_func() +# fp8_res = float8_e4m3fn_ste_gaudi(fp8_res) +# +# ##convert to bf16 +# fp8_res_using_16bit = fp8_res.to(tensor.dtype) +# ##convert to int4 +# from auto_round.data_type.int import quant_tensor_sym +# +# qdq_int4_tensor, scale_fp8_to_int4, zp_fp8_to_int4 = quant_tensor_sym( +# fp8_res_using_16bit, +# bits=bits, +# group_size=group_size, +# v=v, +# min_scale=min_scale, +# max_scale=max_scale, +# scale_dtype=torch.bfloat16, +# q_scale_thresh=q_scale_thresh, +# ) +# qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8 +# scale_fp8_to_int4_with_group = scale_fp8_to_int4 +# scale_fp8_to_int4_with_group_reshape_back = scale_fp8_to_int4_with_group.reshape( +# out_feats, -1 +# ) +# scale_bf16_to_int4 = scale_fp8_to_int4_with_group_reshape_back * scale_bf16_to_fp8 +# scale_bf16_to_int4_with_group = scale_bf16_to_int4.reshape(-1, 1) +# return ( +# qdq_tensor, +# (scale_bf16_to_int4_with_group, scale_bf16_to_fp8), +# zp_fp8_to_int4, +# ) + + +# @register_dtype("fp8_gaudi3_to_int_sym_v2") +# def progressive_quant_fp8_int4_v2( +# tensor, +# bits=4, +# group_size=-1, +# v=0, +# min_scale=1.0, +# max_scale=1.0, +# q_scale_thresh=1e-5, +# weight_fp8_max_scale=1.0, +# **kwargs +# ): +# """The variant of progressive quantization from float8 to int4. +# +# The variant quantizes the tensor to int4 first and then quantizes the qdq tensor to fp8. 
+# """ +# # convert to int4 first +# from auto_round.data_type.int import quant_tensor_sym +# +# qdq_int4_tensor, scale_bf16_to_int4, zp_fp8_to_int4 = quant_tensor_sym( +# tensor, +# bits=bits, +# group_size=group_size, +# v=v, +# min_scale=min_scale, +# max_scale=max_scale, +# scale_dtype=torch.bfloat16, +# q_scale_thresh=q_scale_thresh, +# ) +# # FIXME(Yi): some fuse error here +# torch._dynamo.graph_break() +# fp8_max = torch.finfo(torch.float8_e4m3fn).max +# tensor_max = ( +# torch.max(torch.abs(qdq_int4_tensor)).to(torch.float32) * weight_fp8_max_scale +# ) ## better train a ratio +# scale = tensor_max.to(torch.float32) / fp8_max +# min_scaling_factor = 1.0 / (fp8_max * 512.0) ##copy from vllm +# scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor) +# fp8_res = qdq_int4_tensor / scale_bf16_to_fp8 +# fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max) +# float8_e4m3fn_ste_gaudi = get_gaudi_fp8_ste_func() +# fp8_res = float8_e4m3fn_ste_gaudi(fp8_res) +# +# # convert to bf16 +# fp8_res_using_16bit = fp8_res.to(tensor.dtype) +# +# qdq_tensor = fp8_res_using_16bit * scale_bf16_to_fp8 +# +# return qdq_tensor, (scale_bf16_to_int4, scale_bf16_to_fp8), zp_fp8_to_int4 + + +@register_dtype("fp8_to_int_sym") +def progressive_quant_fp8_int4(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, + q_scale_thresh=1e-5, **kwargs): + """Two-stage quantization: quantize tensor to fp8 by per tensor, then quantize fp8 to w4g128 + + This method first quantizes the input tensor into float8 format and then performs + a secondary quantization to int4 with grouping. + + Args: + tensor (torch.Tensor): Input tensor to quantize. + bits (int, optional): Bit precision for secondary quantization. Defaults to 4. + group_size (int, optional): Group size for int4 quantization. Defaults to -1 (no grouping). + v (float, optional): Optional parameter for variance tuning. Defaults to 0. + min_scale (float, optional): Minimum scaling factor for int4 quantization. Defaults to 1.0. + max_scale (float, optional): Maximum scaling factor for int4 quantization. Defaults to 1.0. + q_scale_thresh (float, optional): Threshold for scaling. Defaults to 1e-5. + **kwargs: Additional arguments for compatibility. + + Returns: + tuple: + - Quantized and dequantized tensor (torch.Tensor). + - Combined scaling factor (torch.Tensor). + - Placeholder for zp (None). 
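+
+    Example (illustrative sketch; ``weight`` below is a placeholder bf16 tensor,
+    and the function is registered above under the name ``fp8_to_int_sym``):
+
+        import torch
+        from auto_round_diff.data_type.w4fp8 import progressive_quant_fp8_int4
+
+        weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
+        qdq, scales, zp = progressive_quant_fp8_int4(weight, bits=4, group_size=128)
+        # scales["scale"] combines the int4 group scale with the per-tensor
+        # bf16->fp8 scale (abs-max / 448 for float8_e4m3fn, clipped to a floor);
+        # qdq is the fake-quantized weight with the same shape as `weight`.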
+ """ + + info = torch.finfo(torch.float8_e4m3fn) + tensor_max = torch.max(torch.abs(tensor)).to(torch.float32) + scale = tensor_max.to(torch.float32) / info.max + min_scaling_factor = 1.0 / (info.max * 512.0) ##copy from vllm + bf16_to_fp8_scale = torch.clip(scale, min=min_scaling_factor) + fp8_res = tensor / bf16_to_fp8_scale + fp8_res = torch.clip(fp8_res, info.min, info.max) + fp8_res = float8_e4m3fn_ste(fp8_res) + + ##convert to bf16 + fp8_res_using_16bit = fp8_res.to(tensor.dtype) + ##convert to int4 + from auto_round.data_type.int import quant_tensor_sym + qdq_int4_tensor, scale_fp8_to_int4, zp_fp8_to_int4 = quant_tensor_sym(fp8_res_using_16bit, bits=bits, + group_size=group_size, v=v, + min_scale=min_scale, + max_scale=max_scale, + scale_dtype=torch.bfloat16, + q_scale_thresh=q_scale_thresh) + qdq_tensor = qdq_int4_tensor * bf16_to_fp8_scale + + bf16_to_int4_scale = scale_fp8_to_int4 * bf16_to_fp8_scale + return qdq_tensor, {"scale": bf16_to_int4_scale, "bf16_to_fp8_scale": bf16_to_fp8_scale}, zp_fp8_to_int4 \ No newline at end of file diff --git a/auto_round_diff/diffusion/__init__.py b/auto_round_diff/diffusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/auto_round_diff/diffusion/autoround_diffusion.py b/auto_round_diff/diffusion/autoround_diffusion.py new file mode 100644 index 00000000..28e15e00 --- /dev/null +++ b/auto_round_diff/diffusion/autoround_diffusion.py @@ -0,0 +1,967 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union +from tqdm import tqdm +from copy import deepcopy +from transformers import set_seed +import torch +from tqdm import tqdm, trange + +from ..utils import ( + detect_device, + to_device, + to_dtype, + clear_memory, + supported_layer_types +) + +from wrapper_block import WapperBasicTransformerBlock, WapperResBlock, WapperQKMatMul, WapperSMVMatMul, WapperBasicTransformerBlock, WapperAttnBlock, get_specials +from ldm.modules.diffusionmodules.openaimodel import ResBlock +from ldm.modules.attention import BasicTransformerBlock +from ..wrapper_layer import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer +from ..autoround import AutoRoundDM, AdaRoundDM + +from torch.utils.data.dataset import Dataset +from torch.utils.data import DataLoader +import gc +import numpy as np +import logging +logger = logging.getLogger("autoround") + +# class AutoRoundDiffusion(AutoRoundDM): +# """Class for automatic rounding-based quantization with MLLMs. + +# Args: +# model: The PyTorch model to be quantized. +# tokenizer: An optional tokenizer for processing input data. +# processor: Any multi-modal model will require an object to encode or +# decode the data that groups several modalities (among text, vision and audio). +# image_processor: Image processor for special model like llava. +# bits (int): Number of bits for quantization (default is 4). +# group_size (int): Size of the quantization group (default is 128). 
+# sym (bool): Whether sym to be used (default is True). +# layer_config (dict): Configuration for weight quantization (default is None). +# batch_size (int): Batch size for training (default is 8). +# amp (bool): Whether to use automatic mixed precision (default is True). +# device: The device to be used for training (default is "auto"). +# lr_scheduler: The learning rate scheduler to be used. +# dataset: The path or name of the calib dataset. +# extra_data_dir: The path of extra data such as images, audio and videos. +# template: The path or name of template used to specify process for different MLLMs. +# quant_nontext_module: Whether to quantize nontext module. +# enable_quanted_input (bool): Whether to use quantized input data (default is True). +# enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True). +# lr (float): The learning rate (default is 0.005). +# minmax_lr (float): The learning rate for min-max tuning (default is None). +# low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False). +# low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False). +# iters (int): Number of iterations (default is 200). +# seqlen (int): Length of the sequence. +# nsamples (int): Number of samples (default is 128). +# sampler (str): The sampling method (default is "rand"). +# seed (int): The random seed (default is 42).s +# nblocks (int): Number of blocks (default is 1). +# gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1). +# not_use_best_mse (bool): Whether to use mean squared error (default is False). +# dynamic_max_gap (int): The dynamic maximum gap (default is -1). +# data_type (str): The data type to be used (default is "int"). +# scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels +# have different choices. +# act_bits (int): Number of bits for activation quantization. Default is 32. +# act_group_size (int): Group size for activation quantization. Default is None. +# act_sym (bool): Whether to use symmetric activation quantization. Default is None. +# act_dynamic (bool): Whether to use dynamic activation quantization. Default is True. +# to_quant_block_names (str|list): A string or list whose elements are list of +# block's layer names to be quantized. +# enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer +# **kwargs: Additional keyword arguments. 
+ + +# """ + +# def __init__( +# self, +# model: torch.nn.Module, +# # tokenizer, +# # processor = None, +# # image_processor = None, +# bits: int = 4, +# level: str = 'group_size', +# group_size: int = 128, +# sym: bool = True, +# layer_config: dict = None, +# batch_size: int = 8, +# amp: bool = True, +# device: str = None, +# lr_scheduler=None, +# dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = None, +# extra_data_dir: str = None, +# quant_nontext_module: bool = False, +# enable_quanted_input: bool = True, +# enable_minmax_tuning: bool = True, +# lr: float = None, +# minmax_lr: float = None, +# low_gpu_mem_usage: bool = False, +# low_cpu_mem_usage: bool = False, +# iters: int = 200, +# seqlen: int = None, +# nsamples: int = 128, +# sampler: str = "rand", +# seed: int = 42, +# nblocks: int = 1, +# gradient_accumulate_steps: int = 1, +# not_use_best_mse: bool = False, +# dynamic_max_gap: int = -1, +# data_type: str = "int", +# scale_dtype: str = "fp16", +# act_bits: int = 32, +# act_group_size: int = None, +# act_sym: bool = None, +# act_dynamic: bool = True, +# to_quant_block_names: Union[str, list] = None, +# enable_norm_bias_tuning: bool = False, +# truncation: bool = None, +# enable_torch_compile: bool = False, +# model_kwargs: dict = None, +# **kwargs, +# ): +# all_blocks = get_block_names(model, quant_nontext_module) +# self.quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names) +# if to_quant_block_names is None: +# to_quant_block_names = extract_block_names_to_str(self.quant_block_list) +# self.to_quant_block_names = to_quant_block_names +# self.extra_data_dir = extra_data_dir +# self.quant_nontext_module = quant_nontext_module +# self.processor = processor +# self.image_processor = image_processor +# self.template = template if template is not None else model.config.model_type +# if not isinstance(dataset, torch.utils.data.DataLoader): +# self.template = get_template( +# self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor) +# dataset = self.template.default_dataset if dataset is None else dataset + +# model = _handle_special_model(model) + +# from ..calib_dataset import CALIB_DATASETS +# from .mllm_dataset import MLLM_DATASET +# if isinstance(dataset, str): +# if quant_nontext_module or \ +# (dataset in CALIB_DATASETS.keys() and not \ +# _only_text_test(model, tokenizer, device, self.template.model_type)): +# if quant_nontext_module: +# logger.warning(f"Text only dataset cannot be used for calibrating non-text modules," +# "switching to liuhaotian/llava_conv_58k") +# else: +# logger.warning(f"{model.config.model_type} not support for {dataset}," +# " will use liuhaotian/llava_conv_58k with default config as an alternative.") +# dataset = "liuhaotian/llava_conv_58k" + +# if dataset in MLLM_DATASET.keys(): +# truncation = False +# seqlen = 512 if seqlen is None else seqlen +# if batch_size != 1: +# logger.warning( +# f"reset batch_size({batch_size}) to 1 and " +# f"gradient_accumulate_steps({gradient_accumulate_steps}) " +# f"to {batch_size * gradient_accumulate_steps}, " +# f"because batch_size={batch_size} cannot be used for {dataset}") +# gradient_accumulate_steps = batch_size * gradient_accumulate_steps +# batch_size = 1 +# if quant_nontext_module and batch_size != 1: +# logger.warning( +# f"reset batch_size({batch_size}) to 1 and " +# f"gradient_accumulate_steps({gradient_accumulate_steps}) " +# f"to {batch_size * gradient_accumulate_steps}, " +# f"because batch_size={batch_size} cannot 
be used for calibrating non-text modules.") +# gradient_accumulate_steps = batch_size * gradient_accumulate_steps +# batch_size = 1 +# seqlen = 2048 if seqlen is None else seqlen +# truncation = True if truncation is None else truncation +# self.truncation = truncation + +# if nsamples % batch_size != 0: +# nsamples = (nsamples // batch_size + 1) * batch_size +# logger.warning(f"'nsamples' is not divisible by 'batch_size', will adjusted to {nsamples}") + +# super(AutoRoundMLLM, self).__init__( +# model=model, +# tokenizer=tokenizer, +# bits=bits, +# group_size=group_size, +# sym=sym, +# layer_config=layer_config, +# batch_size=batch_size, +# amp=amp, +# device=device, +# lr_scheduler=lr_scheduler, +# dataset=dataset, +# enable_quanted_input=enable_quanted_input, +# enable_minmax_tuning=enable_minmax_tuning, +# lr=lr, +# minmax_lr=minmax_lr, +# low_gpu_mem_usage=low_gpu_mem_usage, +# low_cpu_mem_usage=low_cpu_mem_usage, +# iters=iters, +# seqlen=seqlen, +# nsamples=nsamples, +# sampler=sampler, +# seed=seed, +# nblocks=nblocks, +# gradient_accumulate_steps=gradient_accumulate_steps, +# not_use_best_mse=not_use_best_mse, +# dynamic_max_gap=dynamic_max_gap, +# data_type=data_type, +# scale_dtype=scale_dtype, +# act_bits=act_bits, +# act_group_size=act_group_size, +# act_sym=act_sym, +# act_dynamic=act_dynamic, +# to_quant_block_names=self.to_quant_block_names, +# enable_norm_bias_tuning=enable_norm_bias_tuning, +# enable_torch_compile=enable_torch_compile, +# **kwargs, +# ) + +# def calib(self, nsamples, bs): +# """Perform calibration for quantization. + +# This method calibrates the model for quantization by processing a specified +# number of samples from the calibration dataset. It ensures that the data is +# properly formatted and feeds it to the model. If the number of samples processed +# is less than the specified number, it logs a warning. If no samples are processed, +# it logs an error and exits. +# Args: +# nsamples (int): The number of samples to use for calibration. 
+# bs (int): The number of samples to use for calibration +# """ +# if isinstance(self.dataset, str): +# dataset = self.dataset.replace(" ", "") +# self.dataloader, self.batch_size, self.gradient_accumulate_steps = get_mllm_dataloader( +# template=self.template, +# model=self.model, +# tokenizer=self.tokenizer, +# processor=self.processor, +# image_processor=self.image_processor, +# dataset=dataset, +# extra_data_dir=self.extra_data_dir, +# seqlen=self.seqlen, +# bs=self.batch_size, +# seed=self.seed, +# truncation=self.truncation, +# nsamples=self.nsamples, +# gradient_accumulate_steps=self.gradient_accumulate_steps, +# quant_nontext_module=self.quant_nontext_module +# ) +# else: +# self.dataloader = self.dataset +# total_cnt = 0 + +# if self.low_cpu_mem_usage: +# embed_layers = get_layers_before_block(self.model) +# for n, m in embed_layers: +# m = m.to(self.device) + +# total = nsamples if not hasattr(self.dataloader, "len") else min(nsamples, len(self.dataloader)) +# with tqdm(range(1, total + 1), desc="cache block inputs") as pbar: +# for data in self.dataloader: +# if data is None: +# pbar.update(1) +# continue +# if isinstance(data, torch.Tensor): +# input_ids = data.to(self.device) +# data_new = input_ids +# elif isinstance(data, str): +# if self.tokenizer is None: +# logger.error("please provide tokenizer for string input") +# exit() +# # data = self.template._encode(data) +# data = self.template.processor.get_input( +# text=data, +# images=None, +# max_length=self.seqlen, +# squeeze=False, +# ) +# data_new = {} +# for key in data.keys(): +# data_new[key] = data[key].to(self.device) +# input_ids = data_new["input_ids"] +# elif isinstance(data, dict) and "text" in data.keys(): +# text = data['text'] +# if isinstance(text, dict): +# text = [text] +# input_text = self.template._encode(text) +# data = self.template.processor.get_input( +# text=input_text, +# images=data["image"], +# max_length=self.seqlen, +# squeeze=False, +# ) +# data_new = {} +# for key in data.keys(): +# data_new[key] = torch.tensor(data[key]) +# data_new[key] = to_device(data_new[key], self.model.device) +# if key == 'images': +# data_new[key] = to_dtype(data_new[key], self.model.dtype) +# input_ids = data_new["input_ids"] +# elif isinstance(data, tuple) or isinstance(data, list): +# data_new = data +# input_ids = data_new[0] +# else: +# data_new = {} +# for key in data.keys(): +# data_new[key] = to_device(data[key], self.model.device) +# if key in ['images', 'pixel_values']: +# data_new[key] = to_dtype(data_new[key], self.model.dtype) +# if "input_ids" in data_new: +# input_ids = data_new["input_ids"] +# else: +# input_ids = data_new["inputs_embeds"] + +# if input_ids.shape[-1] < self.seqlen: +# pbar.update(1) +# continue +# try: +# if isinstance(data_new, torch.Tensor): +# self.model(data_new) +# elif isinstance(data_new, tuple) or isinstance(data_new, list): +# self.model(*data_new) +# else: +# self.model(**data_new) +# except NotImplementedError: +# pass +# except Exception as error: +# raise error +# step = input_ids.shape[0] if len(input_ids.shape) > 1 else 1 +# total_cnt += step +# pbar.update(step) +# if total_cnt >= nsamples: +# break +# if total_cnt == 0: +# logger.error( +# f"no data has been cached, please provide more data with sequence length >={self.seqlen} in the " +# f"dataset or decease the sequence length" +# ) +# exit(-1) +# elif total_cnt < nsamples: +# logger.warning( +# f"Insufficient number of samples collected may affect the quantization. 
" +# f"target samples count is {nsamples}, while valid samples count is {total_cnt}" +# ) +# if total_cnt < self.batch_size: +# raise ValueError(f"valid samples is less than batch_size({self.batch_size})," +# " please adjust self.batch_size or seqlen.") +# max_len = (total_cnt // self.batch_size) * self.batch_size +# for k, v in self.inputs.items(): +# for key in v: +# if isinstance(v[key], list) and len(v[key]) == total_cnt: +# self.inputs[k][key] = v[key][:max_len] + +# # clean embed weight to save memory +# if self.low_cpu_mem_usage: +# for n, m in embed_layers: +# m = m.to("meta") +# # torch.cuda.empty_cache() + +# def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs): +# """Save the quantized model to the specified output directory in the specified format. + +# Args: +# output_dir (str, optional): The directory to save the quantized model. Defaults to None. +# format (str, optional): The format in which to save the model. Defaults to "auto_round". +# inplace (bool, optional): Whether to modify the model in place. Defaults to True. +# **kwargs: Additional keyword arguments specific to the export format. + +# Returns: +# object: The compressed model object. +# """ +# if self.processor is not None and not hasattr(self.processor, "chat_template"): +# self.processor.chat_template = None +# compressed_model = super().save_quantized( +# output_dir=output_dir, format=format, inplace=inplace, processor=self.processor, **kwargs) +# return compressed_model + +class DiffusionInputDataset(Dataset): + + def __init__(self, data_path): + data_list = torch.load(data_path, map_location='cpu') ## its a list of tuples of tensors + self.xt_list = [] + self.t_list = [] + self.y_list = [] + ## datalist[i][0].shape (B,4,32,32), flat B dimension + for i in range(len(data_list)): + for b in range(len(data_list[i][0])): + self.xt_list.append(data_list[i][0][b]) + self.t_list.append(data_list[i][1][b]) + self.y_list.append(data_list[i][2][b]) + + def __len__(self): + return len(self.xt_list) + + def __getitem__(self, idx): + return self.xt_list[idx], self.t_list[idx], self.y_list[idx] + +class AdaRoundUnetDiffusion(object): + """Class for adaptive rounding-based quantization with Diffusion Models. + + Args: + model: The PyTorch model to be quantized. + tokenizer: An optional tokenizer for processing input data. + processor: Any multi-modal model will require an object to encode or + decode the data that groups several modalities (among text, vision and audio). + image_processor: Image processor for special model like llava. + bits (int): Number of bits for quantization (default is 4). + group_size (int): Size of the quantization group (default is 128). + sym (bool): Whether sym to be used (default is True). + layer_config (dict): Configuration for weight quantization (default is None). + batch_size (int): Batch size for training (default is 8). + amp (bool): Whether to use automatic mixed precision (default is True). + device: The device to be used for training (default is "auto"). + lr_scheduler: The learning rate scheduler to be used. + dataset: The path or name of the calib dataset. + extra_data_dir: The path of extra data such as images, audio and videos. + template: The path or name of template used to specify process for different MLLMs. + quant_nontext_module: Whether to quantize nontext module. + enable_quanted_input (bool): Whether to use quantized input data (default is True). + enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True). 
+ lr (float): The learning rate (default is 0.005). + minmax_lr (float): The learning rate for min-max tuning (default is None). + low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False). + low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False). + iters (int): Number of iterations (default is 200). + seqlen (int): Length of the sequence. + nsamples (int): Number of samples (default is 128). + sampler (str): The sampling method (default is "rand"). + seed (int): The random seed (default is 42).s + nblocks (int): Number of blocks (default is 1). + gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1). + not_use_best_mse (bool): Whether to use mean squared error (default is False). + dynamic_max_gap (int): The dynamic maximum gap (default is -1). + data_type (str): The data type to be used (default is "int"). + scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels + have different choices. + act_bits (int): Number of bits for activation quantization. Default is 32. + act_group_size (int): Group size for activation quantization. Default is None. + act_sym (bool): Whether to use symmetric activation quantization. Default is None. + act_dynamic (bool): Whether to use dynamic activation quantization. Default is True. + to_quant_block_names (str|list): A string or list whose elements are list of + block's layer names to be quantized. + enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer + **kwargs: Additional keyword arguments. + + + """ + + def __init__( + self, + model: torch.nn.Module, + prompts_path: str = None, # cali prompts + weight_bits: int = 4, + sym_w: bool = False, + w_quant_granularity: str = 'channel_wise', + batch_size: int = 8, + w_group_size: int = 128, + data_type_w: str = 'int', + w_scale_method: str = 'max', + cali_iters_w: int = 20000, + quant_act: bool = False, + act_bits: int = 8, + act_quant_granularity: str = 'channel_wise', + act_group_size: int = None, + sym_act: bool = None, + act_dynamic: bool = True, + data_type_act: str = 'int', + act_scale_method: str = 'max', + running_stat: bool = False, + rs_sm_only: bool = False, + sm_abit: int = 8, + cali_iters_a: int = 5000, + device: str = None, + lr_scheduler = None, + enable_quanted_input: bool = True, + lr_a: float = None, + lr_w: float = None, + seed: int = 42, + tune: bool = False, + cali_data_path: str = None, + resume_w: bool = False, + split: bool = True, + **kwargs, + ): + self.quantized = False + self.model_orig_dtype = model.dtype + self.prompts_path = prompts_path + self.supported_types = supported_layer_types + self.cali_data_path = cali_data_path + self.resume_w = resume_w + self.tune = tune + self.seed = seed + set_seed(self.seed) + + # weight quant params + self.weight_bits = weight_bits + self.sym_w = sym_w + self.w_quant_granularity = w_quant_granularity + self.w_group_size = w_group_size if self.w_quant_granularity == 'group_wise' else -1 + self.data_type_w = data_type_w + self.cali_iters_w = cali_iters_w + self.w_scale_method = w_scale_method + self.lr_w = lr_w + self.enable_quanted_input = enable_quanted_input + self.optimizer = torch.optim.Adam + self.lr_w = lr_w + self.lr_scheduler = lr_scheduler + self.quant_act = quant_act + + ## activation + if self.quant_act: + self.act_quant_granularity = act_quant_granularity + self.act_group_size = act_group_size if self.act_quant_granularity == 'group_wise' else -1 + self.act_bits = act_bits if not 
(act_bits is None) else self.bits + self.act_sym = sym_act + self.act_dynamic = act_dynamic + self.act_data_type = data_type_act + self.act_scale_method = act_scale_method + self.lr_a = lr_a + self.running_stat_a = running_stat + self.rs_sm_only_a = rs_sm_only + self.sm_abit = sm_abit + self.cali_iters_a = cali_iters_a + + self.layer_config = {} + self.batch_size = batch_size + self.ldm = model.eval() + self.model = self.ldm.model.diffusion_model + self.device = detect_device(device) + setattr(self.model, "split", True) + + torch.set_printoptions(precision=3, sci_mode=True) + self.check_configs() + + self.serialization_keys = [ + "weight_bits", + "sym_w", + "w_quant_granularity", + "batch_size", + "w_group_size", + "data_type_w", + "w_scale_method", + "lr_w", + "cali_iters_w", + "quant_act", + "act_bits", + "act_quant_granularity", + "act_group_size", + "sym_act", + "act_dynamic", + "data_type_act", + "act_scale_method", + "cali_iters_a", + "lr_a", + "running_stat_a", + "sm_abit", + "tune", + "enable_quanted_input" + ] + + self.set_layerwise_config(self.layer_config) ##better place in the end + # self.shared_cache_keys = get_shared_keys(self.model) + + def check_configs(self): + + """Checks if the configurations are valid. + + Raises: + AssertionError: If any of the configurations are invalid. + """ + assert isinstance(self.model, torch.nn.Module) + assert self.weight_bits > 0, "bits must be positive" + assert self.w_group_size == -1 or self.group_size >= 1, "only supports positive group_size or -1(per channel)" + assert self.batch_size > 0, "batch size must be positive" + if self.quant_act: + self.act_bits > 0, "bits must be positive" + assert self.act_group_size == -1 or self.act_group_size >= 1, \ + "only supports positive group_size or -1(per channel)" + + def quantize_and_save(self, output_dir: str = "tmp_autoround", format: str = "auto_round", inplace=True, **kwargs): + """Quantizes the model and saves it in the specified format(s). + + This function checks the validity of the requested format(s), quantizes + the model accordingly, and saves it to the specified output directory. + If multiple formats are provided, the model is saved separately for each format. + + Args: + output_dir (str, optional): The directory where the quantized model + will be saved. Defaults to "tmp_autoround". + format (str, optional): The quantization format(s) to use, separated + by commas if multiple. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place if only + one format is used. Defaults to True. + **kwargs: Additional arguments for the quantization and saving process. + + Returns: + model: A qdq model or packed model based on the configurations + folders: The folder paths where the quantized models are saved. + + Raises: + ValueError: If an unsupported format is specified. + """ + # Validate and process the specified formats + formats = format.replace(' ', '').split(',') + from auto_round.utils import supported_formats + for format_ in formats: + if format_ not in supported_formats: + logger.error(f"Unsupported format {format_}, please choose from {supported_formats}") + exit(-1) + + # only support to export afp8 + if self.act_bits <= 8: + if "fp8" not in self.act_data_type: + if len(formats) > 1 or "fake" not in formats: + logger.warning( + f"Currently only support to export auto_round format quantized model" + " with fp8 dtype activation for activation quantization." + " Change format to fake and save." 
+ ) + formats = ["fake"] + else: + if len(formats) > 1 or "auto_round" not in formats: + logger.warning( + f"Currently only support to export auto_round format for W{self.bits}AFP8 model," + " change format to auto_round" + ) + formats = ["auto_round"] + + # If multiple formats are specified, enforce inplace=False + if len(formats) > 1: + inplace = False + inplace = kwargs.get("inplace", inplace) + kwargs.pop("inplace", None) + + # Determine if immediate packing is required + if (len(formats) == 1 and + ("awq" in formats[0] or "gptq" in formats[0] or "auto_round" in formats[0]) and + not self.has_qlayer_outside_block and inplace): # TODO: Support more formats + self.is_packing_immediate = True + + # Adjust format settings based on compatibility + for index in range(len(formats)): + format = formats[index] + if "auto_round" in format: + if (self.sym and ("gptq" not in format and "awq" not in format)) or self.bits == 3: + format = format.replace('auto_round', 'auto_round:auto_gptq') + formats[index] = format + + # Remove duplicates from formats list + def remove_duplicates(lst): + seen = set() + return [x for x in lst if not (x in seen or seen.add(x))] + + formats = remove_duplicates(formats) + self.formats = formats + + # # Check format compatibility + # self._check_format_compatibility(formats) + + # Perform model quantization + model, _ = self.quantize() + + # Save the quantized model in the specified formats + folders = [] + for format in formats: + if "gptq" in format and not self.sym: + logger.warning( + "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop," + " particularly for 2-bit quantization and smaller models." + " We recommend exporting to either the AutoAWQ format ( only 4 bits) or " + "the AutoRound format(2/4/8 bits)." + ) + save_format_ = format.replace(":", "-").replace("_", "-") + save_folder = os.path.join(output_dir, save_format_) if len(formats) > 1 else output_dir + self.save_quantized(save_folder, format=format, inplace=inplace, **kwargs) + + folders.append(save_folder) + + return model, folders + + def dntc_sample(self, data_path): + ddim_step = 51 + t_mean = 0.4 + t_std = 0.4 + num_samples = 128 + t_i = np.random.normal(t_mean, t_std, num_samples) * (ddim_step-1) + t_i = np.clip(np.round(t_i), 0, ddim_step-1) + + dataset = DiffusionInputDataset(data_path) + x = dataset.xt_list + t = dataset.t_list + y = dataset.y_list + + st = np.zeros((250, 8, ddim_step)) + + calib_xt, calib_y, calib_t = [], [], [] + + for i in range(t_i.shape[0]): + ct = int(t_i[i]) + + while True: + c = np.random.randint(0, 250) + idx = np.random.randint(0, 8) + + if st[c][idx][ct] == 0: + st[c][idx][ct] = 1 + break + + j = ddim_step * 8 * c + (ddim_step-1-ct) * 8 + idx + calib_xt.append(x[j].unsqueeze(0)) + calib_y.append(y[j].unsqueeze(0)) + calib_t.append(t[j].unsqueeze(0)) + + cali_xt, cali_t, cali_y = torch.cat(calib_xt, dim=0), torch.cat(calib_t, dim=0), torch.cat(calib_y, dim=0) + + del(dataset) + del(x) + del(t) + del(y) + del(st) + gc.collect() + torch.cuda.empty_cache() + + return cali_xt, cali_t, cali_y + + def quant_module_refactor(self, module: torch.nn.Module): + """ + Recursively replace the normal layers (conv2D, conv1D, Linear etc.) 
to QuantModule + :param module: nn.Module with nn.Conv2d, nn.Conv1d, or nn.Linear in its children + :param weight_quant_params: quantization parameters like n_bits for weight quantizer + :param act_quant_params: quantization parameters like n_bits for activation quantizer + """ + for name, child_module in module.named_children(): + if isinstance(child_module, tuple(self.supported_types)): # nn.Conv1d + setattr(module, name, WrapperLinear(child_module, keys=self.serialization_keys, device=self.device)) + else: + self.quant_module_refactor(child_module) + + def quant_block_refactor(self, module: torch.nn.Module): + for name, child_module in module.named_children(): + if type(child_module) in self.specials: + if self.specials[type(child_module)] in [QuantBasicTransformerBlock, QuantAttnBlock]: + setattr(module, name, self.specials[type(child_module)](child_module, keys=self.serialization_keys, device=self.device)) + else: + setattr(module, name, self.specials[type(child_module)](child_module, keys=self.serialization_keys, device=self.device)) + else: + self.quant_block_refactor(child_module) + + def resume_cali_model(self): + pass + + def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): + for m in self.model.modules(): + if isinstance(m, (WrapperLinear, )): + m.set_quant_state(weight_quant, act_quant) + + def forward(self, x, timesteps=None, context=None): + return self.model(x, timesteps, context) + + def set_running_stat(self, running_stat: bool, sm_only=False): + for m in self.model.modules(): + if isinstance(m, QuantBasicTransformerBlock): + if sm_only: + m.attn1.act_quantizer_w.running_stat = running_stat + m.attn2.act_quantizer_w.running_stat = running_stat + else: + m.attn1.act_quantizer_q.running_stat = running_stat + m.attn1.act_quantizer_k.running_stat = running_stat + m.attn1.act_quantizer_v.running_stat = running_stat + m.attn1.act_quantizer_w.running_stat = running_stat + m.attn2.act_quantizer_q.running_stat = running_stat + m.attn2.act_quantizer_k.running_stat = running_stat + m.attn2.act_quantizer_v.running_stat = running_stat + m.attn2.act_quantizer_w.running_stat = running_stat + if isinstance(m, WrapperLinear) and not sm_only: + m.set_running_stat(running_stat) + + def recon_model(self, model): + """ + Block reconstruction. For the first and last layers, we can only apply layer reconstruction. 
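+
+        Note: the ``layer_reconstruction``/``block_reconstruction`` calls are currently
+        commented out, so this pass only walks the wrapped modules, logs the layers and
+        blocks that would be reconstructed, and counts them via the global ``idx``.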
+ """ + global idx + for name, module in model.named_children(): + # logger.info(f"{name} {isinstance(module, BaseQuantBlock)}") + if name == 'output_blocks': + logger.info("Finished calibrating input and mid blocks, saving temporary checkpoint...") + in_recon_done = True + # torch.save(qnn.state_dict(), os.path.join(outpath, "ckpt.pth")) + if name.isdigit() and int(name) >= 9: + logger.info(f"Saving temporary checkpoint at {name}...") + # torch.save(self.model.state_dict(), os.path.join(outpath, "ckpt.pth")) + + if isinstance(module, WrapperLinear): + if module.ignore_reconstruction is True: + logger.info('Ignore reconstruction of layer {}'.format(name)) + continue + else: + logger.info('Reconstruction for layer {}'.format(name)) + # layer_reconstruction(qnn, module, **kwargs) + idx += 1 + print("idx: ", idx) + elif isinstance(module, BaseWrapperBlock): + if module.ignore_reconstruction is True: + logger.info('Ignore reconstruction of block {}'.format(name)) + continue + else: + logger.info('Reconstruction for block {}'.format(name)) + # block_reconstruction(qnn, module, **kwargs) + idx += 1 + print("idx: ", idx) + else: + self.recon_model(module) + + def unwrap_model(self): + pass + + def quantize(self): + """Quantize the model and return the quantized model along with layer configurations. + the entry of AutoRound. + + Returns: + The quantized model and layer configurations. + """ + + # model refactor + self.specials = get_specials(self.quant_act) + self.quant_module_refactor(self.model) + self.quant_block_refactor(self.model) + + # get cali data + cali_xs, cali_ts, cali_cs = self.dntc_sample(self.cali_data_path) + logger.info(f"Calibration data shape: {cali_xs.shape} {cali_ts.shape} {cali_cs.shape}") + + if self.resume_w: + # set-max + # blabla + self.resume_cali_model() # include init forward + # resume_cali_model(qnn, opt.cali_ckpt, cali_data, False, cond=opt.cond) + else: + # RTN initialization for weight quantization + # logger.info("Quantizing model weight using RTN...") + self.set_quant_state(True, False) # enable weight quantization, disable act quantization + _ = self.model(cali_xs[:8].to(self.device), cali_ts[:8].to(self.device), cali_cs[:8].to(self.device)) + logger.info("RTN quantizing has done!") + + if self.tune: + # Adaround tuning for weight quantization + logger.info("Doing weight calibration...") + self.recon_model(self.model) + self.set_quant_state(weight_quant=True, act_quant=False) + + if self.quant_act: + # RTN initialization for weight quantization + logger.info("Doing activation calibration...") + # Initialize activation quantization parameters + self.set_quant_state(True, True) + with torch.no_grad(): + inds = np.random.choice(cali_xs.shape[0], 8, replace=False) + _ = self.model(cali_xs[inds].to(self.device), cali_ts[inds].to(self.device), cali_cs[inds].to(self.device)) + if self.running_stat_a: + logger.info('Running stat for activation quantization') + inds = np.arange(cali_xs.shape[0]) + np.random.shuffle(inds) + self.set_running_stat(True, self.rs_sm_only_a) + for i in trange(int(cali_xs.size(0) / 8)): + _ = self.model(cali_xs[inds[i * 8:(i + 1) * 8]].cuda(), + cali_ts[inds[i * 8:(i + 1) * 8]].cuda(), + cali_cs[inds[i * 8:(i + 1) * 8]].cuda()) + self.set_running_stat(False, self.rs_sm_only_a) + + if self.tune: + # Adaround tuning for activation quantization + pass + + self.quantized = True + self.ldm.model.diffusion_model = self.model + return self.ldm, self.layer_config + + + def set_layerwise_config(self, layer_config): + """ + Sets the layer-wise 
configuration based on the provided `layer_config`. + By default, only quantize layers in blocks. + + Args: + layer_config (dict): The configuration dictionary for each layer containing various configuration options. + + Returns: + bool: Returns True if there are quantized layers outside the blocks (e.g., lm-head), + otherwise returns False. + """ + # List of configuration keys + keys = self.serialization_keys + + # Iterate through all modules in the model + # supported_type = tuple(self.supported_types) + (ResBlock, BasicTransformerBlock) + for n, m in self.model.named_modules(): + + if not isinstance(m, tuple(self.supported_types) + (ResBlock, BasicTransformerBlock)): + continue + + layer_config[n] = {} + + # Skip unsupported types + if isinstance(m, tuple(self.supported_types)): + for key in keys: + if hasattr(self, key): + layer_config[n][key] = getattr(self, key) + setattr(m, key, layer_config[n][key]) + elif isinstance(m, ResBlock): + layer_config[n]["split"] = 0 + elif isinstance(m, BasicTransformerBlock): + if hasattr(self, "sm_abit"): + layer_config[n]["sm_abit"] = getattr(self, "sm_abit") + layer_config[n]["sm_always_zero_a"] = True + + def register_act_max_hook(self, model): + def get_act_max_hook(module, input, output): + if isinstance(input, (tuple, list)): + input = input[0] + if not hasattr(module, "act_max"): + module.act_max = torch.abs(input).max().item() + else: + module.act_max = max(torch.abs(input).max().item(), module.act_max) + + hook_handles = [] + + for n, m in model.named_modules(): + if hasattr(m, "act_dynamic") and m.act_dynamic == False and check_to_quantized(m): + hook = m.register_forward_hook(get_act_max_hook) + hook_handles.append(hook) + return hook_handles + + def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs): + """Save the quantized model to the specified output directory in the specified format. + + Args: + output_dir (str, optional): The directory to save the quantized model. Defaults to None. + format (str, optional): The format in which to save the model. Defaults to "auto_round". + inplace (bool, optional): Whether to modify the model in place. Defaults to True. + **kwargs: Additional keyword arguments specific to the export format. + + Returns: + object: The compressed model object. 
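+
+        Note: mirrors the multimodal variant and assumes a ``processor`` attribute and a
+        parent-class ``save_quantized`` implementation are provided by a subclass;
+        ``AdaRoundUnetDiffusion`` itself does not define either.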
+ """ + if self.processor is not None and not hasattr(self.processor, "chat_template"): + self.processor.chat_template = None + compressed_model = super().save_quantized( + output_dir=output_dir, format=format, inplace=inplace, processor=self.processor, **kwargs) + return compressed_model diff --git a/auto_round_diff/requirements.txt b/auto_round_diff/requirements.txt new file mode 100644 index 00000000..f2e8b86c --- /dev/null +++ b/auto_round_diff/requirements.txt @@ -0,0 +1,137 @@ +absl-py==2.2.2 +accelerate==1.0.1 +aiohappyeyeballs==2.4.4 +aiohttp==3.10.11 +aiosignal==1.3.1 +albumentations==0.4.3 +altair==5.4.1 +antlr4-python3-runtime==4.8 +async-timeout==5.0.1 +attrs==25.3.0 +blinker==1.8.2 +cachetools==5.5.2 +certifi==2025.1.31 +charset-normalizer==3.4.1 +click==8.1.8 +clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1 +-e git+https://github.com/Taited/clip-score@f49cce25fd3aa896b05822d41ed2670bda23be50#egg=clip_score +contourpy==1.1.1 +cycler==0.12.1 +diffusers==0.3.0 +einops==0.3.0 +filelock==3.16.1 +fonttools==4.57.0 +frozenlist==1.5.0 +fsspec==2025.3.0 +ftfy==6.2.3 +future==1.0.0 +gitdb==4.0.12 +GitPython==3.1.44 +google-auth==2.39.0 +google-auth-oauthlib==1.0.0 +grpcio==1.70.0 +huggingface-hub==0.30.2 +idna==3.10 +imageio==2.9.0 +imageio-ffmpeg==0.4.2 +imgaug==0.2.6 +importlib_metadata==8.5.0 +importlib_resources==6.4.5 +invisible-watermark==0.2.0 +Jinja2==3.1.6 +jsonschema==4.23.0 +jsonschema-specifications==2023.12.1 +kiwisolver==1.4.7 +kornia==0.6.9 +lazy_loader==0.4 +lmdb==1.3.0 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.7.5 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.1.0 +narwhals==1.35.0 +natsort==8.3.1 +networkx==3.1 +numpy==1.24.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvtx-cu12==12.1.105 +oauthlib==3.2.2 +omegaconf==2.1.1 +opencv-python==4.1.2.30 +opencv-python-headless==4.6.0.66 +packaging==24.2 +pandas==1.4.2 +Pillow==9.0.1 +pkgutil_resolve_name==1.3.10 +propcache==0.2.0 +protobuf==5.29.4 +psutil==7.0.0 +pudb==2019.2 +py-cpuinfo==9.0.0 +pyarrow==17.0.0 +pyasn1==0.6.1 +pyasn1_modules==0.4.2 +pydeck==0.9.1 +pyDeprecate==0.3.1 +Pygments==2.19.1 +pyparsing==3.1.4 +python-dateutil==2.9.0.post0 +pytorch-lightning==1.4.2 +pytz==2025.2 +PyWavelets==1.4.1 +PyYAML==6.0 +-e git+https://github.com/QianSharphen/Q-T2I.git@b9d6118cbff2273b075460fa8b46d893420b6bd4#egg=q_diffusion +referencing==0.35.1 +regex==2024.11.6 +requests==2.32.3 +requests-oauthlib==2.0.0 +rich==13.9.4 +rpds-py==0.20.1 +rsa==4.9.1 +safetensors==0.5.3 +scikit-image==0.20.0 +scipy==1.9.1 +seaborn==0.13.2 +six==1.16.0 +smmap==5.0.2 +streamlit==1.40.1 +streamlit-drawable-canvas==0.8.0 +sympy==1.13.3 +-e git+https://github.com/CompVis/taming-transformers.git@3ba01b241669f5ade541ce990f7650a3b8f65318#egg=taming_transformers +tenacity==9.0.0 +tensorboard==2.14.0 +tensorboard-data-server==0.7.2 +test_tube==0.7.5 +tifffile==2023.7.10 +tokenizers==0.12.1 +toml==0.10.2 +torch==2.4.1 +torch-fidelity==0.3.0 +torchaudio==2.4.1 +torchmetrics==0.6.0 +torchvision==0.19.1 +tornado==6.4.2 +tqdm==4.64.0 +transformers==4.22.2 +triton==3.0.0 +typing_extensions==4.13.2 +urllib3==2.2.3 +urwid==2.6.16 +watchdog==4.0.2 +wcwidth==0.2.13 +Werkzeug==3.0.6 
+yarl==1.15.2 +zipp==3.20.2 diff --git a/auto_round_diff/script/__init__.py b/auto_round_diff/script/__init__.py new file mode 100644 index 00000000..e3fdc07b --- /dev/null +++ b/auto_round_diff/script/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/auto_round_diff/script/diffusion.py b/auto_round_diff/script/diffusion.py new file mode 100644 index 00000000..386891c7 --- /dev/null +++ b/auto_round_diff/script/diffusion.py @@ -0,0 +1,613 @@ +import torch +import os +import gc, yaml +import cv2 +import sys +import numpy as np +import argparse +import datetime +import logging +from torch import autocast +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from imwatermark import WatermarkEncoder +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from itertools import islice +from contextlib import nullcontext +from einops import rearrange +from torchvision.utils import make_grid +import time +from ldm.util import instantiate_from_config +from auto_round_diff.utils import ( + clear_memory, + logger, + set_cuda_visible_devices, + get_device_and_parallelism + ) + + +def setup_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--disable_deterministic_algorithms", action='store_true', + help="disable torch deterministic algorithms.") + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--prompts_path", + type=str, + nargs="?", + help="dir to write results to", + # required=True + ) + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save individual samples. 
For speed measurements.", + ) + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--laion400m", + action='store_true', + help="uses the LAION400M model", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=2, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor", + ) + parser.add_argument( + "--n_samples", + type=int, + default=4, + help="how many samples to produce for each given prompt. A.k.a. batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v1-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="models/ldm/stable-diffusion-v1/model.ckpt", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + # linear quantization configs + # parser.add_argument( + # "--ptq", action="store_true", help="apply post-training quantization" + # ) + parser.add_argument( + "--quant_act", action="store_true", + help="if to quantize activations when ptq==True" + ) + parser.add_argument( + "--weight_bits", + type=int, + default=8, + help="int bit for weight quantization", + ) + parser.add_argument( + "--w_quant_granularity", + type=str, + default="channel_wise", + help="weight quantization granularity", + ) + parser.add_argument( + "--w_group_size", + type=int, + default=128, + help="group size for weight group quantization" + ) + parser.add_argument( + "--weight_asym", + action="store_true", + help="asymmetric quantization for weight" + ) + parser.add_argument( + "--data_type_w", + type=str, + default='int', + help="data type for weight quantization" + ) + parser.add_argument( + "--w_scale_method", + type=str, + default='max', + help='algorithm for initializing weight quant params' + ) + parser.add_argument( + "--act_quant_granularity", + type=str, + default="channel_wise", + help="weight quantization granularity", + ) + parser.add_argument( + "--act_group_size", + type=int, + default=128, + help="group size for weight group quantization" + ) + parser.add_argument( + "--act_bits", + type=int, + default=8, + help="int bit for activation quantization", + ) + 
parser.add_argument( + "--act_asym", + action="store_true", + help="asymmetric quantization for activation" + ) + parser.add_argument( + "--act_dynamic", + action="store_true", + help="use dynamic quantization for activation" + ) + parser.add_argument( + "--data_type_act", + type=str, + default='int', + help="data type for activation quantization" + ) + parser.add_argument( + "--act_scale_method", + type=str, + default='max', + help='algorithm for initializing activation quant params ' + ) + parser.add_argument( + "--enable_quanted_input", + type=bool, + default=True, + help="enable quanted input for construction" + ) + parser.add_argument( + "--cali_st", type=int, default=1, + help="number of timesteps used for calibration" + ) + parser.add_argument( + "--batch_size", type=int, default=32, + help="batch size for qdiff reconstruction" + ) + parser.add_argument( + "--cali_n", type=int, default=1024, + help="number of samples for each timestep for qdiff reconstruction" + ) + parser.add_argument( + "--cali_iters_w", type=int, default=20000, + help="number of iterations for each qdiff reconstruction" + ) + parser.add_argument('--cali_iters_a', default=5000, type=int, + help='number of iteration for LSQ') + parser.add_argument('--cali_lr', default=4e-4, type=float, + help='learning rate for LSQ') + parser.add_argument('--cali_p', default=2.4, type=float, + help='L_p norm minimization for LSQ') + parser.add_argument( + "--cali_ckpt", type=str, + help="path for calibrated model ckpt" + ) + parser.add_argument( + "--tune", action="store_true", + help="tune weights and activation using the adaround/autoround algorithm" + ) + parser.add_argument( + "--w_lr", + type=float, + default=4e-5, + help="learning rate for weight tuning" + ) + parser.add_argument( + "--a_lr", + type=float, + default=4e-4, + help="learning rate for activation tuning" + ) + parser.add_argument( + "--resume", action="store_true", + help="resume the calibrated qdiff model" + ) + parser.add_argument( + "--resume_w", action="store_true", + help="resume the calibrated qdiff model weights only" + ) + parser.add_argument( + "--cond", action="store_true", + help="whether to use conditional guidance" + ) + parser.add_argument( + "--no_grad_ckpt", action="store_true", + help="disable gradient checkpointing" + ) + parser.add_argument( + "--split", action="store_true", + help="use split strategy in skip connection" + ) + parser.add_argument( + "--cali_data_path", + type=str, + required=True, + help="cali data for quant" + ) + parser.add_argument( + "--running_stat", action="store_true", + help="use running statistics for act quantizers" + ) + parser.add_argument( + "--rs_sm_only", action="store_true", + help="use running statistics only for softmax act quantizers" + ) + parser.add_argument( + "--sm_abit",type=int, default=8, + help="attn softmax activation bit" + ) + parser.add_argument( + "--verbose", action="store_true", + help="print out info like quantized model arch" + ) + parser.add_argument( + "--round_type", + type=str, + default="adaround", + help="algorithm for round operation" + ) + parser.add_argument( + "--device", + "--devices", + default="0", + type=str, + help="the device to be used for tuning. " + "Currently, device settings support CPU, GPU, and HPU." + "The default is set to cuda:0," + "allowing for automatic detection and switch to HPU or CPU." 
+ "set --device 0,1,2 to use multiple cards.") + + args = parser.parse_args() + return args + +def set_logger(args): + + os.makedirs(args.outdir, exist_ok=True) + log_path = os.path.join(args.outdir, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) + os.makedirs(log_path) + + log_path = os.path.join(log_path, "run.log") + sh = logger.FileHandler(log_path) + + from auto_round_diff.utils import AutoRoundFormatter + sh.setFormatter(AutoRoundFormatter()) + logger.addHandler(sh) + +def load_model_from_config(config, ckpt, verbose=False): + logger.info(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + logger.info(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + logger.info("missing keys:") + logger.info(m) + if len(u) > 0 and verbose: + logger.info("unexpected keys:") + logger.info(u) + + model.cuda() + model.eval() + return model + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + +def sample(args, model): + os.makedirs(args.outdir, exist_ok=True) + outpath = os.path.join(args.outdir, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) + os.makedirs(outpath) + + logger.info("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "StableDiffusionV1" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + if args.plms: + print('plms') + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + batch_size = args.n_samples + n_rows = args.n_rows if args.n_rows > 0 else batch_size + if not args.from_file: + prompt = args.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + logger.info(f"reading prompts from {args.from_file}") + with open(args.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + # write config out + sampling_file = os.path.join(outpath, "sampling_config.yaml") + sampling_conf = vars(args) + with open(sampling_file, 'a+') as f: + yaml.dump(sampling_conf, f, default_flow_style=False) + if args.verbose: + logger.info("UNet model") + logger.info(model.model) + + start_code = None + if args.fixed_code: + start_code = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=model.device) + + precision_scope = autocast if args.precision=="autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(args.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if args.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [args.C, args.H // args.f, args.W // args.f] + samples_ddim, _ = sampler.sample(S=args.ddim_steps, + conditioning=c, + batch_size=args.n_samples, + shape=shape, + 
verbose=False, + unconditional_guidance_scale=args.scale, + unconditional_conditioning=uc, + eta=args.ddim_eta, + x_T=start_code) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image = x_samples_ddim + # x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + if not args.skip_save: + for x_sample in x_checked_image_torch: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + + if not args.skip_grid: + all_samples.append(x_checked_image_torch) + + if not args.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + toc = time.time() + + logger.info(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + +def tune(args): + + ##must set this before import torch + set_cuda_visible_devices(args.device) + device_str, use_auto_mapping = get_device_and_parallelism(args.device) + + if not args.disable_deterministic_algorithms: + torch.use_deterministic_algorithms(True, warn_only=True) + print("'torch.use_deterministic_algorithms' is turned on by default for reproducibility, " \ + "and can be turned off by setting the '--disable_deterministic_algorithms' parameter.") + + # load model + config = OmegaConf.load(f"{args.config}") + model = load_model_from_config(config, f"{args.ckpt}") + + from auto_round_diff.diffusion.autoround_diffusion import AdaRoundUnetDiffusion + model = model.eval() + + if args.round_type == 'autoround': + # round = AutoRoundDiffusion( + # model, + # prompts_path=args.prompts_path, # cali prompts + # weight_bit=args.weight_bit, + # quant_granularity=args.quant_granularity, + # weight_sym=not args.weight_asym, + # batch_size=args.batch_size, + # cali_iters_w=args.cali_iters_w, + # lr=args.lr, + # amp=not args.disable_amp, + # enable_quanted_input=not args.disable_quanted_input, + # truncation=args.truncation, + # nsamples=args.nsamples, + # low_gpu_mem_usage=args.low_gpu_mem_usage, + # device=device_str, + # seed=args.seed, + # gradient_accumulate_steps=args.gradient_accumulate_steps, + # scale_dtype=args.scale_dtype, + # layer_config=layer_config, + # template=args.template, + # enable_minmax_tuning=not args.disable_minmax_tuning, + # act_bits=args.act_bits, + # quant_nontext_module=args.quant_nontext_module, + # not_use_best_mse=args.not_use_best_mse, + # to_quant_block_names=args.to_quant_block_names, + # enable_torch_compile=enable_torch_compile, + # device_map=args.device_map, + # model_kwargs=model_kwargs + # ) + # model, _ = autoround.quantize() + # round = AutoRoundDiffusion() + pass + elif args.round_type == 'adaround': + round = AdaRoundUnetDiffusion( + model, + prompts_path=args.prompts_path, # cali prompts + weight_bits=args.weight_bits, + w_quant_granularity=args.w_quant_granularity, + w_group_size=args.w_group_size, + sym_w=not 
args.weight_asym, + data_type_w=args.data_type_w, + w_scale_method=args.w_scale_method, + tune=args.tune, + batch_size=args.batch_size, + cali_iters_w=args.cali_iters_w, + quant_act=args.quant_act, + act_bits=args.act_bits, + act_quant_granularity=args.w_quant_granularity, + act_group_size=args.w_group_size, + sym_act=not args.act_asym, + act_dynamic=args.act_dynamic, + data_type_act=args.data_type_act, + act_scale_method=args.act_scale_method, + running_stat=args.running_stat, + sm_abit=args.sm_abit, + cali_iters_a=args.cali_iters_a, + cali_n=args.cali_n, + cali_data_path=args.cali_data_path, + a_lr=args.a_lr, + rs_sm_only=args.rs_sm_only, + w_lr=args.w_lr, + enable_quanted_input=args.enable_quanted_input, + device=device_str, + seed=args.seed, + split=args.split, + resume_w=args.resume_w + ) + else: + raise NotImplementedError("This round algorithm has not been implemented yet.") + + model, _ = round.quantize() + + sample(args, model) + + model.eval() + clear_memory() + + \ No newline at end of file diff --git a/auto_round_diff/sign_sgd.py b/auto_round_diff/sign_sgd.py new file mode 100644 index 00000000..e4b94810 --- /dev/null +++ b/auto_round_diff/sign_sgd.py @@ -0,0 +1,389 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional + +# From PyTorch: +# +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) +# +# From Caffe2: +# +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. +# +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. +# +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. +# +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. +# +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain +# +# All contributions by Cruise LLC: +# Copyright (c) 2022 Cruise LLC. +# All rights reserved. +# +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. +# +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. +# +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. 
If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. +# +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + +__all__ = ["SignSGD", "sgd"] + + +class _RequiredParameter(object): + """Singleton class representing a required parameter for an Optimizer.""" + + def __repr__(self): + return "" + + +required = _RequiredParameter() + + +def _use_grad_for_differentiable(func): + def _use_grad(self, *args, **kwargs): + prev_grad = torch.is_grad_enabled() + try: + torch.set_grad_enabled(self.defaults["differentiable"]) + ret = func(self, *args, **kwargs) + finally: + torch.set_grad_enabled(prev_grad) + return ret + + return _use_grad + + +class SignSGD(Optimizer): + r"""Implements stochastic gradient descent (optionally with momentum). + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) + \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ + &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)}, + \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\ + &\hspace{10mm}\textbf{if} \: t > 1 \\ + &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\ + &\hspace{10mm}\textbf{else} \\ + &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\ + &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\ + &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\ + &\hspace{10mm}\textbf{else} \\[-1.ex] + &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex] + &\hspace{5mm}\textbf{else} \\[-1.ex] + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + Nesterov momentum is based on the formula from + `On the importance of initialization and momentum in deep learning`__. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + maximize (bool, optional): maximize the params based on the objective, instead of + minimizing (default: False) + foreach (bool, optional): whether foreach implementation of optimizer + is used (default: None) + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf + + .. note:: + The implementation of SGD with Momentum/Nesterov subtly differs from + Sutskever et. al. and implementations in some other frameworks. + + Considering the specific case of Momentum, the update can be written as + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ + p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, + \end{aligned} + + where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the + parameters, gradient, velocity, and momentum respectively. + + This is in contrast to Sutskever et. al. and + other frameworks which employ an update of the form + + .. math:: + \begin{aligned} + v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ + p_{t+1} & = p_{t} - v_{t+1}. + \end{aligned} + + The Nesterov version is analogously modified. 
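+    .. note::
+        Although the equations above are inherited from :class:`torch.optim.SGD`,
+        this optimizer applies only the *sign* of the (momentum-accumulated)
+        gradient in its update, i.e. ``param -= lr * sign(d_p)``
+        (see ``_single_tensor_sgd`` below), so ``lr`` directly sets the step
+        magnitude instead of scaling the raw gradient.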
+ """ + + def __init__( + self, + params, + lr=required, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + *, + maximize=False, + foreach: Optional[bool] = None, + differentiable=False + ): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + maximize=maximize, + foreach=foreach, + differentiable=differentiable, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(SignSGD, self).__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + group.setdefault("maximize", False) + group.setdefault("foreach", None) + group.setdefault("differentiable", False) + + @_use_grad_for_differentiable + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + d_p_list = [] + momentum_buffer_list = [] + has_sparse_grad = False + + for p in group["params"]: + if p.grad is not None: + params_with_grad.append(p) + d_p_list.append(p.grad) + if p.grad.is_sparse: + has_sparse_grad = True + + state = self.state[p] + if "momentum_buffer" not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state["momentum_buffer"]) + + sgd( + params_with_grad, + d_p_list, + momentum_buffer_list, + weight_decay=group["weight_decay"], + momentum=group["momentum"], + lr=group["lr"], + dampening=group["dampening"], + nesterov=group["nesterov"], + maximize=group["maximize"], + has_sparse_grad=has_sparse_grad, + foreach=group["foreach"], + ) + + # update momentum_buffers in state + for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list): + state = self.state[p] + state["momentum_buffer"] = momentum_buffer + + return loss + + +def sgd( + params: List[Tensor], + d_p_list: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + has_sparse_grad: bool = None, + foreach: bool = None, + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool +): + r"""Functional API that performs SGD algorithm computation. + + See :class:`~torch.optim.SGD` for details. 
+ """ + + if foreach is None: + # Placeholder for more complex foreach logic to be added when value is not set + foreach = False + + if foreach and torch.jit.is_scripting(): + raise RuntimeError("torch.jit.script not supported with foreach optimizers") + + # if foreach and not torch.jit.is_scripting(): + # func = _multi_tensor_sgd + # else: + func = _single_tensor_sgd + + func( + params, + d_p_list, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=nesterov, + has_sparse_grad=has_sparse_grad, + maximize=maximize, + ) + + +def _single_tensor_sgd( + params: List[Tensor], + d_p_list: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool +): + for i, param in enumerate(params): + d_p = d_p_list[i] if not maximize else -d_p_list[i] + + if weight_decay != 0: + d_p = d_p.add(param, alpha=weight_decay) + + if momentum != 0: + buf = momentum_buffer_list[i] + + if buf is None: + buf = torch.clone(d_p).detach() + momentum_buffer_list[i] = buf + else: + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + + param.add_(torch.sign(d_p), alpha=-lr) diff --git a/auto_round_diff/special_model_handler.py b/auto_round_diff/special_model_handler.py new file mode 100644 index 00000000..f602f9b3 --- /dev/null +++ b/auto_round_diff/special_model_handler.py @@ -0,0 +1,106 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama") # Limitations on batch_size + +SUPPORT_ONLY_TEXT_MODELS = [ + "phi3_v", + "cogvlm2", + "llava", + "qwen2_vl", + "deepseek_vl_v2", + "chatglm", + "idefics3" +] + +SPECIAL_SHARED_CACHE_KEYS = { + "Gemma3ForConditionalGeneration": ("position_embeddings_global", "position_embeddings_local")} +SPECIAL_SHARED_CACHE_KEYS["MiniMaxText01ForCausalLM"] = ("slope_rate",) + + +def _handle_special_model(model): + if model.config.model_type == "deepseek_vl_v2": + from functools import partial + model.forward = partial(_deepseek_vl2_forward, model) + return model + + +def _get_deepseek_vl2_multimodal_block(model, quant_vision=False): + model.forward = model.language.forward + block_names = [] + if quant_vision: + block_names.append([f"vision.blocks.{i}" for i in range(len(model.vision.blocks))]) + block_names.append([f"projector.layers.{i}" for i in range(len(model.projector.layers))]) + block_names.append([f"language.model.layers.{i}" for i in range(len(model.language.model.layers))]) + return block_names + + +SPECIAL_MULTIMODAL_BLOCK = { + "deepseek_vl_v2": _get_deepseek_vl2_multimodal_block +} + + +def _deepseek_vl2_forward( + model, + input_ids=None, + + position_ids=None, + attention_mask=None, + past_key_values=None, + inputs_embeds=None, + + images=None, + images_seq_mask=None, + images_spatial_crop=None, + + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + **kwargs +): + inputs_embeds = model.prepare_inputs_embeds( + input_ids=input_ids, + images=images, + images_seq_mask=images_seq_mask, + images_spatial_crop=images_spatial_crop, + ) + return model.language( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position) + + +def check_mllm_model_batch(model, batch_size, gradient_accumulate_steps=1): + """ + Checks model configuration to determine if it's necessary to limit bs to avoid potential input shape mismatches. + """ + for key in mllms_with_limited_bs: + if hasattr(model, "config") and key in model.config.model_type and batch_size != 1: + accumulate_steps = batch_size * gradient_accumulate_steps + print("To avoid the tensor concat mismatch problem, modified parameters to " \ + f"batch_size=1. As an alternative, set the gradient_accumulate_steps={accumulate_steps}") + return 1, accumulate_steps + return batch_size, gradient_accumulate_steps diff --git a/auto_round_diff/utils.py b/auto_round_diff/utils.py new file mode 100644 index 00000000..9c7563c9 --- /dev/null +++ b/auto_round_diff/utils.py @@ -0,0 +1,1465 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
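+# Editorial examples (illustrative, not part of the original patch) for helpers defined
+# in this module:
+#
+#     infer_bits_by_data_type("int4")    # -> 4, digits after a known dtype prefix
+#     infer_bits_by_data_type("mx_fp8")  # -> 8
+#     infer_bits_by_data_type("bf16")    # -> 16, unknown prefixes fall back to 16 bits
+#     detect_device("1")                 # -> "cuda:1" on a CUDA machine, otherwise
+#                                        #    hpu/xpu/cpu depending on availability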
+ +import copy +import logging +import os +import sys +import subprocess +from collections import UserDict +import re +import cpuinfo +import psutil +import torch +from torch.amp import autocast + +from functools import lru_cache +from packaging import version +import gc +from .special_model_handler import SPECIAL_MULTIMODAL_BLOCK, SPECIAL_SHARED_CACHE_KEYS +import transformers +from auto_round.export.export_to_gguf.config import GGUF_CONFIG + +shared_cache_keys = ("position_ids", "cache_position", "position_embeddings") + +supported_formats = ( + "auto_round", "auto_gptq", "auto_awq", "auto_round:auto_gptq", "auto_round:gptqmodel", "auto_round:auto_awq", + "itrex", "itrex_xpu", "fake" +) + +supported_formats = supported_formats + tuple(GGUF_CONFIG.keys()) + +supported_layer_types = (torch.nn.Linear, transformers.modeling_utils.Conv1D, torch.nn.Conv2d, torch.nn.Conv1d) + +supported_dtypes = ("int", "mx_fp", "fp", "nv_fp") + + +def infer_bits_by_data_type(data_type: str): + for supported_dtype in supported_dtypes: + if data_type.startswith(supported_dtype) and len(data_type) > len(supported_dtype): + ##first check the following two bits + suc_2str = data_type[len(supported_dtype):len(supported_dtype) + 2] + if str.isdigit(suc_2str): + return int(suc_2str) + if str.isdigit(data_type[len(supported_dtype)]): + return int(data_type[len(supported_dtype)]) + return 16 + + +@lru_cache(None) +def warning_once(self, msg: str): + self.warning(msg) + + +class AutoRoundFormatter(logging.Formatter): + grey = "\x1b[38;20m" + yellow = "\x1b[33;1m" + red = "\x1b[31;20m" + bold_red = "\x1b[31;1m" + reset = "\x1b[0m" + _format = "%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s" + + FORMATS = { + logging.DEBUG: grey + _format + reset, + logging.INFO: grey + _format + reset, + logging.WARNING: yellow + _format + reset, + logging.ERROR: bold_red + _format + reset, + logging.CRITICAL: bold_red + _format + reset + } + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt, "%Y-%m-%d %H:%M:%S") + return formatter.format(record) + + +logging.Logger.warning_once = warning_once +logger = logging.getLogger("autoround") +logger.setLevel(logging.INFO) +logger.propagate = False +# fh = logging.StreamHandler() +# fh.setFormatter(AutoRoundFormatter()) +# logger.addHandler(fh) + +import importlib +import transformers + + +class LazyImport(object): + """Lazy import python module till use.""" + + def __init__(self, module_name): + """Init LazyImport object. + + Args: + module_name (string): The name of module imported later + """ + self.module_name = module_name + self.module = None + + def __getattr__(self, name): + """Get the attributes of the module by name.""" + try: + self.module = importlib.import_module(self.module_name) + mod = getattr(self.module, name) + except: + spec = importlib.util.find_spec(str(self.module_name + "." 
+ name)) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + def __call__(self, *args, **kwargs): + """Call the function in that module.""" + function_name = self.module_name.split(".")[-1] + module_name = self.module_name.split(f".{function_name}")[0] + self.module = importlib.import_module(module_name) + function = getattr(self.module, function_name) + return function(*args, **kwargs) + + +auto_gptq = LazyImport("auto_gptq") +htcore = LazyImport("habana_frameworks.torch.core") + + +def is_optimum_habana_available(): + from transformers.utils.import_utils import is_optimum_available + + return is_optimum_available() and importlib.util.find_spec("optimum.habana") is not None + + +def get_module(module, key): + """Get module from model by key name. + + Args: + module (torch.nn.Module): original model + key (str): module name to be replaced + """ + name_list = key.split(".") + for name in name_list: + module = getattr(module, name, None) + return module + + +def set_module(model, key, new_module): + """Set new module into model by key name. + + Args: + model (torch.nn.Module): original model + key (str): module name to be replaced + new_module (torch.nn.Module): new module to be inserted + """ + module = model + name_list = key.split(".") + for name in name_list[:-1]: + if hasattr(module, name): + module = getattr(module, name) + setattr(module, name_list[-1], new_module) + + +def get_scale_shape(weight, group_size): + """Computes the shape of the scale tensor for quantization based on the weight tensor and group size. + + Args: + weight (torch.Tensor): The weight tensor of the layer. + group_size (int): The size of the groups for quantization. + + Returns: + The shape of the scale tensor to be used for quantization. + """ + if group_size == -1 or weight.shape[1] < group_size: + shape = weight.shape[0] + else: + shape = weight.shape[0] * ((weight.shape[1] + group_size - 1) // group_size) + + return shape + + +def unsupport_meta_device(model): + """Checks if the model is a valid model for auto_round. + + Args: + model: The model to be checked. + + Returns: + bool: True if the model is valid, False otherwise. + """ + target_device = None + for param in model.parameters(): + if target_device is None: + target_device = param.device + if param.device != target_device: + if param.device.type == 'meta' or target_device.type == 'meta': + return True + if target_device.type == 'meta': + if hasattr(model, "path"): + return False + else: + return True + return False + + +def to_device(input, device=torch.device("cpu")): + """Moves input data to the specified device. + + Args: + input: The input data to be moved. + device: The target device. + + Returns: + The input data on the specified device. + """ + if input is None: + return None + if isinstance(input, torch.Tensor): + return input.to(device) + if isinstance(input, dict) or isinstance(input, UserDict): + for inp in input.keys(): + input[inp] = to_device(input[inp], device) + + elif isinstance(input, list) or isinstance(input, tuple): + if len(input) == 0: + return input + input_res = [] + for inp in input: + input_res.append(to_device(inp, device)) + if isinstance(input, tuple): + input_res = tuple(input_res) + input = input_res + + return input + + +def mv_module_from_gpu(module, low_cpu_mem_usage=False): + """Moves module from gpu to cpu or meta if low_cpu_mem_usage is true. + + Args: + module: The module to be moved. + low_cpu_mem_usage: Whether to use low CPU memory. If true, move module to meta. 
+ + Returns: + The module on the specified device. + """ + if hasattr(module, "device"): + target_device = "meta" if low_cpu_mem_usage else "cpu" + if module.device.type == target_device: + return module + else: + return module.to(target_device) + else: + if low_cpu_mem_usage: + return module.to('meta') + else: + return module.to('cpu') + + +def to_dtype(input, dtype=torch.float32): + """Moves input data to the specified data type. + + Args: + input: The input data to be moved. + dtype: The target data type. + + Returns: + The input data on the specified data type. + """ + if input is None: + return None + if isinstance(input, torch.Tensor): + return input.to(dtype) + if isinstance(input, dict) or isinstance(input, UserDict): + for inp in input.keys(): + input[inp] = to_dtype(input[inp], dtype) + + elif isinstance(input, list) or isinstance(input, tuple): + if len(input) == 0: + return input + input_res = [] + for inp in input: + input_res.append(to_dtype(inp, dtype)) + if isinstance(input, tuple): + input_res = tuple(input_res) + input = input_res + + return input + + +def check_is_cpu(device): + """Check if the device is a CPU. + + Args: + device: The device to be checked. + + Returns: + bool: True if the device is a CPU, False otherwise. + """ + return device == torch.device("cpu") or device == "cpu" + + +def get_common_prefix(paths): + # Split each path into components and find the common prefix + split_paths = [path.split('.') for path in paths] + common_prefix = split_paths[0] + for path in split_paths[1:]: + common_prefix = [comp for comp, other in zip(common_prefix, path) if comp == other] + return '.'.join(common_prefix) + + +def extract_block_names_to_str(quant_block_list): + if not isinstance(quant_block_list, (list, tuple)): + return None + # Extract common prefix for each list + prefixes = [get_common_prefix(blocks) for blocks in quant_block_list] + # Join prefixes into a single string + return ','.join(prefixes) + + +def find_matching_blocks(model, all_blocks, to_quant_block_names): + """ + Find and return matching blocks in the model based on to_quant_block_names. + + Args: + model: The model (not used in this specific function but kept for completeness). + all_blocks: List of lists, where each inner list contains full block names in the model. + to_quant_block_names: Comma-separated string of target block names to match. + + Returns: + target_blocks: List of lists containing full paths of matching blocks in the model. + """ + if not to_quant_block_names: + return all_blocks + to_quant_block_list = to_quant_block_names + if isinstance(to_quant_block_names, list) or isinstance(to_quant_block_names, tuple): + return to_quant_block_names + if isinstance(to_quant_block_names, str): + to_quant_block_list = [name.strip() for name in to_quant_block_names.split(",")] + target_blocks = [] + for block_list in all_blocks: + matched_sublist = [] + for name in to_quant_block_list: + matches = [block for block in block_list if re.search(name, block)] + if matches: + matched_sublist.extend(matches) + if matched_sublist: + target_blocks.append(matched_sublist) + if not target_blocks: + raise ValueError("No block names matched. Please check the input for to_quant_block_name," \ + "or set to_quant_block_name to None to automatically match quantizable blocks.") + return target_blocks + + +def get_block_names(model, quant_vision=False): + """Get the block names for transformers-like networks. + + Args: + model: The model. 
+ + Returns: + block_names: A list whose elements are list of block's layer names + """ + + def _get_llm_block_names(model): + block_names = [] + target_modules = [] + for n, m in model.named_modules(): + if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: + target_modules.append((n, m)) + break ## only find the first modulelist, may be not robust + for i, target_m in enumerate(target_modules): + block_names.append([]) + for n, m in target_m[1].named_children(): + block_names[i].append(target_m[0] + "." + n) + return block_names + + def _get_vlm_block_names(model, quant_vision=False): + if hasattr(model, "config") and model.config.model_type in SPECIAL_MULTIMODAL_BLOCK.keys(): + return SPECIAL_MULTIMODAL_BLOCK.get(model.config.model_type)(model, quant_vision=quant_vision) + block_names = [] + target_modules = [] + vision_blocks_tuple = ("vision", "visual", "image", "img") + last_block_name = "" + for n, m in model.named_modules(): + if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: + if quant_vision or all(key not in n.lower() for key in (vision_blocks_tuple)): + if last_block_name and last_block_name in n: + continue + target_modules.append((n, m)) + last_block_name = n + for i, target_m in enumerate(target_modules): + block_names.append([]) + for n, m in target_m[1].named_children(): + block_names[i].append(target_m[0] + "." + n) + return block_names + + if quant_vision or not is_pure_text_model(model): + return _get_vlm_block_names(model, quant_vision=quant_vision) + else: + return _get_llm_block_names(model) + + +def collect_best_params(block): + params = {} + for n, m in block.named_modules(): + if hasattr(m, "orig_layer"): + params[n] = {} + for key in m.params.keys(): + params[n][key] = copy.deepcopy(m.params[key].data) + return params + + +def block_forward(block, input_ids, input_others, amp=False, amp_dtype=torch.float16, device=torch.device("cpu")): + """Performs a forward pass through a block with the given inputs. + + Args: + block: The block to perform the forward pass on. + input_ids: The input IDs. + input_others: A dictionary containing other input data. + amp: A boolean indicating whether to use automatic mixed precision. + amp_dtype: The data type for automatic mixed precision. + device: The target device. + + Returns: + output: The output of the forward pass. + """ + if input_ids.device != device: + input_ids = to_device(input_ids, device) + input_others = to_device(input_others, device) + input_tuple = input_others.pop("positional_inputs", None) + if "alibi" in input_others.keys() and input_others["alibi"] is not None: + alibi = input_others["alibi"] + input_others["alibi"] = alibi.reshape(-1, alibi.shape[2], alibi.shape[3]) + if amp: + with autocast(device_type=device.split(":")[0], dtype=amp_dtype): # pragma: no cover + output = block(input_ids, *input_tuple, **input_others) + else: + output = block(input_ids, *input_tuple, **input_others) + if isinstance(output, list) or isinstance(output, tuple): + output = output[0] + return output + + +def check_to_quantized(config): + """Checks if the configuration is valid for quantization. + + Args: + config (dict or object): The configuration to check. It can be either a + dictionary with a 'bits' key or an object with a 'bits' attribute. + + Returns: + bool: True if the configuration is valid for quantization (bits <= 8), + False otherwise. 
+ """ + if isinstance(config, dict): + bits = int(config.get("bits", 16)) + act_bits = int(config.get("act_bits", 16)) + else: + bits = int(config.bits) if hasattr(config, "bits") else 16 + act_bits = int(config.act_bits) if hasattr(config, "act_bits") else 16 + + return bits <= 8 or act_bits <= 8 + + +def detect_device_count(): + """Detects the number of available computation devices. + + This function checks if CUDA is available. If it is, it returns the count + of available CUDA devices. If not, it attempts to import the Habana + device framework to return the count of Habana devices. If the import + fails or no devices are found, it returns 0. + + Returns: + int: The number of available devices (CUDA or Habana). + """ + if torch.cuda.is_available(): + return torch.cuda.device_count() + else: + try: + import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401 + return hthpu.device_count() + except ImportError: + return 0 + + +def detect_device(device=None): + """Detects the appropriate computation device. + + This function determines the device to use for computations. It can take + a specific device index or default to 'auto'. The function checks for + available devices in the following order: CUDA, Habana, and finally CPU. + + Args: + device (str, int, or torch.device, optional): The desired device. + If 'auto' or None, the function will determine the best device + automatically. + + Returns: + str: The device to use for computations, formatted as a string. + """ + + def is_valid_digit(s): + try: + num = int(s) + return 0 <= num + except: + return False + + dev_idx = None + if is_valid_digit(device): + dev_idx = int(device) + device = "auto" + if device is None or device == "auto": + if torch.cuda.is_available(): + device = torch.device("cuda") + # logger.info("Using GPU device") + elif is_optimum_habana_available(): # pragma: no cover + device = torch.device("hpu") + # logger.info("Using HPU device") + elif torch.xpu.is_available(): # pragma: no cover + device = torch.device("xpu") + # Use CPU as a fallback + else: + device = torch.device("cpu") + # logger.info("Using CPU device") + if dev_idx is not None and str(device) != "cpu": + device = str(device) + f":{dev_idx}" + return str(device) + elif isinstance(device, torch.device): + device = str(device) + return device + + +class CpuInfo(object): + """Get CPU Info.""" + + def __init__(self): + """Get whether the cpu numerical format is bf16, the number of sockets, cores and cores per socket.""" + self._bf16 = False + self._vnni = False + info = cpuinfo.get_cpu_info() + if "arch" in info and "X86" in info["arch"]: + cpuid = cpuinfo.CPUID() + max_extension_support = cpuid.get_max_extension_support() + if max_extension_support >= 7: + ecx = cpuid._run_asm( + b"\x31\xC9", # xor ecx, ecx + b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\x89\xC8" b"\xC3", # mov eax, 7 # cpuid # mov ax, cx # ret + ) + self._vnni = bool(ecx & (1 << 11)) + eax = cpuid._run_asm( + b"\xB9\x01\x00\x00\x00", # mov ecx, 1 + b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\xC3", # mov eax, 7 # cpuid # ret + ) + self._bf16 = bool(eax & (1 << 5)) + if "arch" in info and "ARM" in info["arch"]: # pragma: no cover + self._sockets = 1 + else: + self._sockets = self.get_number_of_sockets() + self._cores = psutil.cpu_count(logical=False) + self._cores_per_socket = int(self._cores / self._sockets) + + @property + def bf16(self): + """Get whether it is bf16.""" + return self._bf16 + + @property + def vnni(self): + """Get whether it is vnni.""" + return self._vnni + + @property + def 
cores_per_socket(self): + """Get the cores per socket.""" + return self._cores_per_socket + + def get_number_of_sockets(self) -> int: + """Get number of sockets in platform.""" + cmd = "cat /proc/cpuinfo | grep 'physical id' | sort -u | wc -l" + if psutil.WINDOWS: + cmd = r'wmic cpu get DeviceID | C:\Windows\System32\find.exe /C "CPU"' + elif psutil.MACOS: # pragma: no cover + cmd = "sysctl -n machdep.cpu.core_count" + + with subprocess.Popen( + args=cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=False, + ) as proc: + proc.wait() + if proc.stdout: + for line in proc.stdout: + return int(line.decode("utf-8", errors="ignore").strip()) + return 0 + + +def is_local_path(path): + """Checks if a given path exists locally. + + Args: + path (str): The path to check. + + Returns: + bool: True if the path exists locally, False otherwise. + """ + format_list = ("json", "txt",) + flag = None + for x in format_list: + flag = True if x in path else flag + return flag and os.path.exists(path) + + +def convert_dtype_str2torch(str_dtype): + """Converts a string dtype to its corresponding PyTorch dtype. + + Args: + str_dtype (str): The string representation of the dtype. + + Returns: + torch.dtype: The PyTorch dtype. + + Raises: + AssertionError: If the input str_dtype is unsupported. + """ + if isinstance(str_dtype, torch.dtype) or str_dtype is None: + return str_dtype + if str_dtype == "int8": + return torch.int8 + elif str_dtype == "fp32" or str_dtype == "float32" or str_dtype == "auto": + return torch.float + elif str_dtype == "fp16" or str_dtype == "float16": + return torch.float16 + elif str_dtype == "bf16" or str_dtype == "bfloat16": + return torch.bfloat16 + else: + assert False, "Unsupported str dtype {} to torch dtype".format(str_dtype) + + +def convert_dtype_torch2str(dtype): + """Converts a PyTorch dtype to its corresponding string representation. + + Args: + dtype: PyTorch dtype or str. The dtype to convert. + + Returns: + str: The string representation of the dtype. + + Raises: + AssertionError: If the input dtype is unsupported. + """ + if isinstance(dtype, str) or dtype is None: + return dtype + if dtype == torch.int8: + return "int8" + elif dtype == torch.float: + return "fp32" + elif dtype == torch.float16: + return "fp16" + elif dtype == torch.bfloat16: + return "bf16" + elif isinstance(dtype, str) and dtype in ["int8", "fp32", "fp16", "bf16"]: + return dtype + else: + assert False, "Unsupported pytorch dtype {} to str dtype".format(dtype) + + +def convert_dtype_torch2str_hf(dtype): + """Converts a PyTorch dtype to its corresponding huggingface string dtype, e.g. torch.float32 -> 'float32'. + + Args: + dtype: PyTorch dtype or str. The dtype to convert. + + Returns: + str: The string representation of the dtype. + + Raises: + AssertionError: If the input str_dtype is unsupported. + """ + if dtype is None: + return dtype + if isinstance(dtype, str): + if "float" not in dtype and "int" not in dtype: + dtype = convert_dtype_str2torch(dtype) + else: + return dtype + str_dtype = str(dtype) + if "." not in str_dtype: + assert False, "Unsupported pytorch dtype {} to huggingface str dtype".format(dtype) + str_dtype = str_dtype.split(".")[1] + return str_dtype + + +def check_memory_availability(device, inputs, weight, org_seqlen, org_bs): + """Checks the availability of memory on the specified device for processing inputs using a given weight tensor. + + Args: + device (str): The device type ('cuda' for GPU or 'hpu' for HPU). 
+ inputs (torch.Tensor): Input tensor. + weight (torch.Tensor): Weight tensor. + org_seqlen (int): Original sequence length. + org_bs (int): Original batch size. + + Returns: + tuple: A tuple containing availability status (bool), modified sequence length (int), + and modified batch size (int). + """ + weight_memory = weight.numel() * weight.element_size() + if "cuda" in device: + current_gpu_index = torch.cuda.current_device() + total_memory = torch.cuda.get_device_properties(current_gpu_index).total_memory + used_memory = torch.cuda.memory_allocated(current_gpu_index) + free_space = total_memory - used_memory + elif "hpu" in device: # pragma: no cover + current_hpu_index = torch.hpu.current_device() + free_space = torch.hpu.memory_reserved(current_hpu_index) + else: + return True, org_seqlen, org_bs + + free_space = free_space - weight_memory * 10 # for min_max_scale & grad usage + seqlen = org_seqlen + bs = org_bs + in_feature = weight.shape[1] + out_feature = weight.shape[0] + while seqlen >= 128: + input_size = bs * seqlen * in_feature + output_size = bs * seqlen * out_feature + input_output_memory = 2 * (input_size * inputs.element_size() + output_size * inputs.element_size()) + if input_output_memory < free_space: + return True, seqlen, bs + seqlen = seqlen // 2 + bs = 1 + + return False, seqlen, bs + + +def get_layer_names_in_block(model, supported_types=(torch.nn.Linear, + transformers.modeling_utils.Conv1D), quant_block_list=None): + """Retrieves the names of layers within each block of the model. + + Returns: + list: A list of strings, where each string is the name of a layer + within a block of the model. + """ + for n, m in model.named_modules(): + if isinstance(m, supported_types): + m.tmp_name = n + layers_in_block = [] + if bool(quant_block_list): + all_blocks = quant_block_list + else: + all_blocks = get_block_names(model) + for block_names in all_blocks: + for block_name in block_names: + block = get_module(model, block_name) + for n, m in block.named_modules(): + if hasattr(m, "tmp_name"): + layers_in_block.append(m.tmp_name) + for n, m in model.named_modules(): + if hasattr(m, "tmp_name"): + delattr(m, "tmp_name") + return layers_in_block + + +def is_autoround_exllamav2_available(): + """Checks if the AutoRound ExLlamaV2 kernels are available. + + Returns: + bool: + True if the AutoRound ExLlamaV2 kernels are available, False otherwise. 
+ """ + res = True + try: + from autoround_exllamav2_kernels import gemm_half_q_half, make_q_matrix + except ImportError as e: + res = False + return res + + +@lru_cache(None) +def is_hpu_supported(): # pragma: no cover + try: + import habana_frameworks.torch.core as htcore # pylint: disable=E0401 + except ImportError as e: + return False + return True + + +def get_library_version(library_name): + from packaging.version import Version + python_vesion = Version(sys.version.split()[0]) + if python_vesion < Version("3.8"): + import warnings + warnings.filterwarnings('ignore', category=DeprecationWarning) + import pkg_resources # pylint: disable=E0401 + try: + version = pkg_resources.get_distribution(library_name).version + return version + except pkg_resources.DistributionNotFound: + return f"{library_name} is not installed" + else: + import importlib.metadata # pylint: disable=E0401 + try: + version = importlib.metadata.version(library_name) + return version + except importlib.metadata.PackageNotFoundError: + return f"{library_name} is not installed" + + +def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False): + """ + Configures and returns a QuantLinear class based on the specified backend and parameters. + + Args: + backend (str): The backend to be used for quantization. Supported values include "qigen", "triton", "marlin", + "exllama", and "cuda". + bits (int, optional): The number of bits for quantization. Default is 4. + group_size (int, optional): The group size for quantization. Default is 128. + sym (bool, optional): Flag indicating whether to use symmetric quantization. Default is False. + + Returns: + class: The dynamically imported QuantLinear class configured according to the specified parameters. + """ + use_triton = True + if bits not in [2, 4, 8]: + use_triton = False + disable_exllamav2 = True + disable_exllamav1 = False + disable_marlin = True + use_qigen = False + if "qigen" in backend: + use_triton = False + use_qigen = True + elif "triton" in backend: + use_triton = True + elif "marlin" in backend and sym: + use_triton = False + disable_marlin = False + elif "exllama" in backend: ##need v1 code to export + use_triton = True ##same with triton + disable_marlin = True + elif "cuda" in backend: + use_triton = False + disable_marlin = True + disable_exllamav2 = True + disable_exllamav1 = True + if use_triton: + from auto_round.export.export_to_autogptq.qlinear_triton import QuantLinear + return QuantLinear + try: + import auto_gptq # pylint: disable=E0401 + except: + logger.error(f"please install auto_gptq via 'pip install auto-gptq' to support exporting to {backend}") + exit() + + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear # pylint: disable=E0401 + version = get_library_version("auto_gptq") + from packaging.version import Version + if Version(version) < Version("0.7.2"): + QuantLinear = dynamically_import_QuantLinear( + use_triton=use_triton, + desc_act=False, + group_size=group_size, + bits=bits, + disable_exllama=disable_exllamav1, + disable_exllamav2=disable_exllamav2, + use_qigen=use_qigen, + disable_marlin=disable_marlin, + ) + else: + QuantLinear = dynamically_import_QuantLinear( # pylint: disable=E1123 + use_triton=use_triton, + desc_act=False, + group_size=group_size, + bits=bits, + disable_exllama=disable_exllamav1, + disable_exllamav2=disable_exllamav2, + use_qigen=use_qigen, + use_marlin=not disable_marlin, + ) + return QuantLinear + + +def _clear_memory_for_cpu_and_cuda(tensor=None): + if isinstance(tensor, 
list): + for i in range(len(tensor)): + tensor[i] = None + if tensor is not None: + del tensor + gc.collect() + torch.cuda.empty_cache() + + +def clear_memory(tensor=None): + if is_hpu_supported(): + # hpu does not have empty_cache + return + else: + _clear_memory_for_cpu_and_cuda(tensor) + + +def compare_versions(v1, v2): + return version.parse(v1) >= version.parse(v2) + + +def torch_version_at_least(version_string): + return compare_versions(torch.__version__, version_string) + + +TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE = torch_version_at_least("2.5.99") +TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0") +TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0") +TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0") + + +# Note on HPU usage: +# There are two modes available for enabling auto-round on HPU: +# 1. Compile Mode +# 1) Use PyTorch version ≥ 2.4 (Intel® Gaudi® v1.18 or later) +# 2) Set `PT_HPU_LAZY_MODE=0` and `PT_ENABLE_INT64_SUPPORT=1` +# The compile mode can speed up quantization process but still in experimental stage. +# 2. Lazy Mode (By default) + + +def is_hpu_lazy_mode(): + return os.getenv("PT_HPU_LAZY_MODE") != "0" + + +def _use_hpu_compile_mode(): + return TORCH_VERSION_AT_LEAST_2_4 and not is_hpu_lazy_mode() + + +def compile_func_on_hpu(func): + if _use_hpu_compile_mode(): + return torch.compile(func, backend="hpu_backend") + return func + + +def compile_func_on_cuda_or_cpu(func): + return torch.compile(func) + + +def compile_func(fun, device): + if "hpu" in str(device): + return compile_func_on_hpu(fun) ## use auto by default + else: + return compile_func_on_cuda_or_cpu(fun) + + +def is_numba_available(): # pragma: no cover + """Check if Numba is available.""" + try: + import numba + + return True + except ImportError: + return False + + +def _is_tbb_installed(): # pragma: no cover + import importlib.metadata + + try: + importlib.metadata.version("tbb") + return True + except importlib.metadata.PackageNotFoundError: + return False + + +def _is_tbb_configured(): # pragma: no cover + try: + from numba.np.ufunc.parallel import _check_tbb_version_compatible + + # check if TBB is present and compatible + _check_tbb_version_compatible() + + return True + except ImportError as e: + logger.warning_once(f"TBB not available: {e}") + return False + + +def is_tbb_available(): # pragma: no cover + """Check if TBB is available.""" + if not _is_tbb_installed(): + logger.warning_once("TBB is not installed, please install it with `pip install tbb`.") + return False + if not _is_tbb_configured(): + logger.warning_once( + ( + "TBB is installed but not configured correctly. \n" + "Please add the TBB library path to `LD_LIBRARY_PATH`, " + "for example: `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/`." + ) + ) + return False + return True + + +def can_pack_with_numba(): # pragma: no cover + """Check if Numba and TBB are available for packing. + + To pack tensor with Numba, both Numba and TBB are required, and TBB should be configured correctly. + """ + if not is_numba_available(): + logger.warning_once("Numba is not installed, please install it with `pip install numba`.") + return False + if not is_tbb_available(): + return False + return True + + +def get_fp_layer_names(model, fp_layers): + """Identifies and returns layers in the model to exclude from quantization. 
+ + This function processes a comma-separated list of fully precision (FP) layers, + matches them to the names of layers in the model, and returns a list of such + layers to exclude from quantization. + + Args: + model (torch.nn.Module): The model whose layers will be inspected. + fp_layers (str): A comma-separated string of layer names to be excluded + from quantization. Whitespace is ignored in this string. + + Returns: + list: A list of layer names that match the specified FP layers or are + subcomponents of those layers. + """ + fp_layers = fp_layers.replace(" ", "").split(",") + all_layer_names = [] + for n, m in model.named_modules(): + if isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D)): + all_layer_names.append(n) + not_to_quantized_layers = [] + + for fp_layer in fp_layers: + if fp_layer == "": + continue + if fp_layer in all_layer_names: + not_to_quantized_layers.append(fp_layer) + continue + if fp_layer[-1].isdigit(): + fp_layer = fp_layer + "." ##tricky setting + for name in all_layer_names: + if fp_layer in name: + not_to_quantized_layers.append(name) + + return not_to_quantized_layers + + +def check_awq_gemm_compatibility(model, bits, group_size, sym, layer_configs=None): + """Checks if a model is compatible with the AutoAWQ GEMM kernel. + + Args: + model: The model object to evaluate, typically a PyTorch model. + bits (int): The number of bits for quantization (must be 4 for compatibility). + group_size (int): The group size for quantization. + sym (bool): Whether symmetric quantization is used (not utilized in the current function logic). + layer_configs (dict, optional): A dictionary mapping layer names to configurations, where each + configuration can specify a custom number of bits for the layer. + + Returns: + tuple: A tuple containing: + - bool: `True` if the model is compatible, `False` otherwise. + - str: An error message describing why the model is incompatible, or an empty string if compatible. 
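+
+    Example (illustrative; ``model`` stands for an already-loaded transformers model):
+        >>> compatible, reason = check_awq_gemm_compatibility(model, bits=4, group_size=128, sym=True)
+        >>> if not compatible:
+        ...     print(f"Skipping AWQ GEMM export: {reason}")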
+ """ + if bits != 4: + return False, f"AutoAWQ GEMM kernel only supports 4 bits" + for n, m in model.named_modules(): + if isinstance(m, transformers.modeling_utils.Conv1D): + return False, "AutoAWQ GEMM kernel does not support conv1d" + + layer_names = get_layer_names_in_block(model) + for layer_name in layer_names: + if layer_configs is not None and layer_name in layer_configs.keys() and layer_configs[layer_name].get("bits", + bits) > 8: + continue + + layer = get_module(model, layer_name) + if layer.in_features % group_size != 0: + return False, f"Layer {layer_name} in_features is not multiple of group_size {group_size}" + if layer.out_features % (32 // bits) != 0: + return False, f"Layer {layer_name} out_features is not multiple of 32 // bits" + + return True, "" + + +def get_device_and_parallelism(device): + from auto_round.utils import detect_device + devices = device.replace(" ", "").split(',') + if all(s.isdigit() for s in devices) and len(devices) > 1 and torch.cuda.is_available(): + device = "cuda" + parallelism = True + elif all(s.isdigit() for s in devices) and len(devices) > 1 and torch.xpu.is_available(): + device = "xpu" + parallelism = False + # pragma: no cover + elif device == "auto": + device = detect_device(device) + parallelism = True + else: + device = detect_device(device) + parallelism = False + return device, parallelism + + +def set_cuda_visible_devices(device): + devices = device.replace(" ", "").split(',') + if all(s.isdigit() for s in devices): + if "CUDA_VISIBLE_DEVICES" in os.environ: + current_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"] + current_visible_devices = current_visible_devices.split(',') + indices = [int(device) for device in devices] + try: + pick_device = [current_visible_devices[i] for i in indices] + except: + raise ValueError( + "Invalid '--device' value: It must be smaller than the number of available devices." + " For example, with CUDA_VISIBLE_DEVICES=4,5, " + "--device 0,1 is valid, but --device 4,5 is not supported.") + visible_devices = ','.join(pick_device) + os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices + else: + os.environ["CUDA_VISIBLE_DEVICES"] = device + + +def is_debug_mode(): + """Checks if the Python interpreter is running in debug mode. + + Returns: + bool: True if debugging is enabled, False otherwise. 
+ """ + return sys.gettrace() is not None or sys.flags.debug == 1 + + +def get_layer_features(layer): + """Extracts input and output feature dimensions for supported layers.""" + if isinstance(layer, torch.nn.Linear): + return layer.in_features, layer.out_features + elif isinstance(layer, transformers.pytorch_utils.Conv1D): # TODO: Verify correctness + return layer.weight.shape[0], layer.weight.shape[1] + return None, None # Unsupported layer type + + +def _gguf_args_check(args): + from auto_round.utils import logger + from auto_round.export.export_to_gguf.config import GGUF_CONFIG + + formats = args.format.lower().replace(' ', '').split(",") + formats = sorted(formats, key=lambda x: len(x)) + pattern = re.compile("q\d_k") + pre_dq_format = "" + for format in GGUF_CONFIG: + if format in formats: + try: + from auto_round.export.export_to_gguf.convert import Model + except: + raise ImportError( + f"Please use the latest gguf-py for {format}, you can use the following command to install it:\n" + "git clone https://github.com/ggml-org/llama.cpp.git && cd llama.cpp/gguf-py && pip install .") + sys.exit(-1) + if re.search(pattern, format): + if pre_dq_format and re.search(pattern, format).group() not in pre_dq_format: + logger.error(f"Cannot eport {pre_dq_format} and {format} at the same time.") + sys.exit(-1) + else: + pre_dq_format = format + + if os.path.isdir(args.model): + from pathlib import Path + from auto_round.export.export_to_gguf.convert import Model + hparams = Model.load_hparams(Path(args.model)) + model_architecture = hparams["architectures"][0] + try: + model_class = Model.from_model_architecture(model_architecture) + except NotImplementedError: + logger.error(f"Model {model_architecture} is not supported to export GGUF format") + sys.exit(1) + + if re.search(pattern, format) and ("hidden_size" in hparams and hparams["hidden_size"] % 256 != 0): + model_name = args.model.split('/') + model_name = model_name[-1] if model_name[-1] else model_name[-2] + hidden_size = hparams["hidden_size"] + logger.error( + f"Currently only support pure mode for format: {format}. 
" + f"{model_name} is not supported, cause hidden_size({hidden_size}) % 256 !=0") + sys.exit(-1) + + unsupport_list, reset_list = [], [] + gguf_config = GGUF_CONFIG[format] + for k, v in gguf_config.items(): + if getattr(args, k) != v: + unsupport_list.append(f"{k}={getattr(args, k)}") + reset_list.append(f"{k}={v}") + setattr(args, k, v) + if len(unsupport_list) > 0: + logger.error( + f"format {format} does not support for {', '.join(unsupport_list)}," + f" reset to {', '.join(reset_list)}.") + logger.info(f"export format {format}, sym = {not args.asym}, group_size = {args.group_size}") + + return args + + +def _to_model_dtype(model, model_dtype): + if model_dtype is not None: + try: + if model_dtype == "float16" or model_dtype == "fp16": + model = model.to(torch.float16) + elif model_dtype == "bfloat16" or model_dtype == "bfp16" or model_dtype == "bf16": + model = model.to(torch.bfloat16) + elif model_dtype == "float32" or model_dtype == "fp32": + model = model.to(torch.float32) + except: + logger.error("please use more device to fit the device or just use one device") + exit() + return model + + +def llm_load_model( + pretrained_model_name_or_path, + torch_dtype="auto", + use_auto_mapping=True, + trust_remote_code=True, + model_dtype=None, + device="cpu", + low_cpu_mem_mode=0, + low_cpu_mem_tmp_dir=None, + **kwargs): + from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM + + is_glm = bool(re.search("chatglm", pretrained_model_name_or_path.lower())) + low_cpu_mem_usage = False + + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + + model_cls = AutoModel if is_glm else AutoModelForCausalLM + + if low_cpu_mem_tmp_dir is None: + low_cpu_mem_tmp_dir = "low_cpu_mem_tmp" + if low_cpu_mem_mode == 2: + from auto_round.low_cpu_mem.utils import load_model_with_hooks + model = load_model_with_hooks( + pretrained_model_name_or_path, + model_cls, + device=device, + clean_weight=True, + saved_path=low_cpu_mem_tmp_dir, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code) + elif low_cpu_mem_mode == 1: + from auto_round.low_cpu_mem.utils import load_empty_model + low_cpu_mem_usage = True + model = load_empty_model( + pretrained_model_name_or_path, + model_cls, + device=device, + saved_path=low_cpu_mem_tmp_dir, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code) + else: + if _use_hpu_compile_mode(): + model = model_cls.from_pretrained( + pretrained_model_name_or_path, low_cpu_mem_usage=True, torch_dtype=torch_dtype, + attn_implementation="eager", + trust_remote_code=trust_remote_code, device_map="auto" if use_auto_mapping else None + ) + else: + model = model_cls.from_pretrained( + pretrained_model_name_or_path, low_cpu_mem_usage=True, torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, device_map="auto" if use_auto_mapping else None + ) + + model = model.eval() + model = _to_model_dtype(model, model_dtype) + + return model, tokenizer, low_cpu_mem_usage + + +def mllm_load_model( + pretrained_model_name_or_path, + torch_dtype="auto", + use_auto_mapping=True, + trust_remote_code=True, + model_dtype=None, + **kwargs): + import json + import transformers + from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM + from huggingface_hub import HfApi, HfFileSystem + + if os.path.isdir(pretrained_model_name_or_path): + config = json.load(open(os.path.join(pretrained_model_name_or_path, "config.json"))) + else: + hf_file = HfFileSystem() + config = 
json.load(hf_file.open(pretrained_model_name_or_path + "/config.json")) + + if "model_type" in config: + model_type = config["model_type"] + else: + model_type = None + + processor, image_processor = None, None + if "deepseek_vl_v2" == model_type: + from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM # pylint: disable=E0401 + processor = DeepseekVLV2Processor.from_pretrained(pretrained_model_name_or_path) + tokenizer = processor.tokenizer + model: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + torch_dtype=torch_dtype, + device_map="auto" if use_auto_mapping else None) + else: + architectures = config["architectures"][0] + if architectures == "LlavaLlamaForCausalLM": + from llava.model.builder import load_pretrained_model # pylint: disable=E0401 + tokenizer, model, image_processor, _ = load_pretrained_model( + pretrained_model_name_or_path, + model_base=None, + model_name=pretrained_model_name_or_path, + torch_dtype=torch_dtype) + else: + if hasattr(transformers, architectures): + cls = getattr(transformers, architectures) + else: + cls = AutoModelForCausalLM + model = cls.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + torch_dtype=torch_dtype, + device_map="auto" if use_auto_mapping else None) + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + processor = AutoProcessor.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + + model = model.eval() + model = _to_model_dtype(model, model_dtype) + + return model, processor, tokenizer, image_processor + + +def is_pure_text_model(model): + """verify on: phi-3.5, Mistral-Small-3.1, gemma-3, qwen2-vl, """ + if hasattr(model, "config") and hasattr(model.config, "vision_config"): + return False + if hasattr(model.__class__, "main_input_name") and model.__class__.main_input_name != "input_ids": + return False + for module in model.modules(): + if hasattr(module.__class__, "main_input_name") and module.__class__.main_input_name != "input_ids": + return False + if "vision" in str(module.__class__).lower(): + return False + if "image" in str(module.__class__).lower(): + return False + if "img" in str(module.__class__).lower(): + return False + return True + + +def reset_params(inputs): + """ + Resets specific input parameters to avoid saving the key-value cache during fine-tuning. + + Args: + inputs (dict): Dictionary of model inputs. + + Modifies: + inputs (dict): Sets "use_cache" to False if the key is present. + """ + if "use_cache" in inputs.keys(): # Not storing kv cache + inputs['use_cache'] = False + + +def check_skippable_keywords(key): + """ + Prints a reminder if a key is not stored during quantization fine-tuning. + """ + skippable_cache_keys = ("past_key_value",) + for cache_key in skippable_cache_keys: + if cache_key not in key: + return True + return False + + +def init_cache(positional_inputs, inputs): + """ + Initializes special model inputs by adding positional inputs if missing. + + Args: + positional_inputs (list): List of positional inputs to add to inputs. + inputs (dict): Dictionary of model inputs. + + Modifies: + inputs (dict): Adds "positional_inputs" key if not present. 
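+
+    Example (illustrative):
+        >>> inputs = {"input_ids": [[1, 2, 3]]}
+        >>> init_cache([], inputs)
+        >>> inputs["positional_inputs"]
+        []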
+ """ + if "positional_inputs" not in inputs: # for chatglm Series + inputs["positional_inputs"] = [] + for idx, item in enumerate(positional_inputs): + inputs["positional_inputs"] = to_device(positional_inputs) + + +def get_shared_keys(model): + """ + Retrieves shared keys from the model's state dictionary. + + Args: + model (torch.nn.Module): The model to retrieve shared keys from. + + Returns: + tuple: tuple of shared keys. + """ + shared_keys = shared_cache_keys + shared_keys += SPECIAL_SHARED_CACHE_KEYS.get(model.__class__.__name__, ()) + return shared_keys + + +def get_model_dtype(model_dtype, default="auto"): + if model_dtype is None or model_dtype == "auto": + model_dtype = default + elif model_dtype in ["bf16", "bfloat16"]: + model_dtype = "bfloat16" + elif model_dtype in ["f16", "float16", "fp16"]: + model_dtype = "float16" + elif model_dtype in ["f32", "float32", "fp32"]: + model_dtype = "float32" + else: + logger.warning(f"Unable to identify model_dtype {model_dtype}, reset to default model_dtype {default}") + model_dtype = default + return model_dtype + + +def filter_quantization_config(quantization_config): + default_dict = {"amp": True, "batch_size": 8, "data_type": int, "dataset": "NeelNanda/pile-10k", + "enable_minmax_tuning": True, "enable_norm_bias_tuning": False, "enable_quanted_input": True, + "gradient_accumulate_steps": 1, "iters": 200, "low_gpu_mem_usage": False, "nsamples": 128, + "scale_dtype": "torch.float16", "seqlen": 2048} + iters = quantization_config.get("iters", 200) + + default_dict["lr"] = 1.0 / iters if iters > 0 else 5e-3 + default_dict["minmax_lr"] = default_dict["lr"] + + for key in default_dict: + if key in quantization_config and default_dict[key] == quantization_config[key]: + quantization_config.pop(key) + for k in list(quantization_config.keys()): + if quantization_config[k] is None: + quantization_config.pop(k) + + if quantization_config.get("act_bits", 16) >= 16: + quantization_config.pop("act_bits", None) + quantization_config.pop("act_data_type", None) + quantization_config.pop("act_dynamic", None) + quantization_config.pop("act_sym", None) + quantization_config.pop("act_group_size", None) + diff --git a/auto_round_diff/version.py b/auto_round_diff/version.py new file mode 100644 index 00000000..41f5016d --- /dev/null +++ b/auto_round_diff/version.py @@ -0,0 +1,17 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Intel® auto-round: An open-source Python library +supporting popular model weight only compression based on signround.""" + +__version__ = "0.5.1" diff --git a/auto_round_diff/wrapper_block.py b/auto_round_diff/wrapper_block.py new file mode 100644 index 00000000..69554b44 --- /dev/null +++ b/auto_round_diff/wrapper_block.py @@ -0,0 +1,398 @@ +import logging +from types import MethodType +import torch as th +from torch import einsum +import torch.nn as nn +from einops import rearrange, repeat + +from qdiff.quant_layer import QuantModule, UniformAffineQuantizer, StraightThrough +from ldm.modules.diffusionmodules.openaimodel import AttentionBlock, ResBlock, TimestepBlock, checkpoint +from ldm.modules.diffusionmodules.openaimodel import QKMatMul, SMVMatMul +from ldm.modules.attention import BasicTransformerBlock +from ldm.modules.attention import exists, default +from wrapper_layer import WrapperLinear +from ddim.models.diffusion import ResnetBlock, AttnBlock, nonlinearity + +logger = logging.getLogger("autoround") + +class BaseWrapperBlock(nn.Module): + """ + Base implementation of block structures for all networks. + """ + def __init__(self, act_quant_params: dict = {}): + super().__init__() + self.use_weight_quant = False + self.use_act_quant = False + # initialize quantizer + + self.act_quantizer = UniformAffineQuantizer(**act_quant_params) + + self.ignore_reconstruction = False + + def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): + # setting weight quantization here does not affect actual forward pass + self.use_weight_quant = weight_quant + self.use_act_quant = act_quant + for m in self.modules(): + if isinstance(m, WrapperLinear): + m.set_quant_state(weight_quant, act_quant) + + +class WapperResBlock(BaseWrapperBlock, TimestepBlock): + def __init__( + self, res: ResBlock, act_quant_params: dict = {}): + super().__init__(act_quant_params) + self.channels = res.channels + self.emb_channels = res.emb_channels + self.dropout = res.dropout + self.out_channels = res.out_channels + self.use_conv = res.use_conv + self.use_checkpoint = res.use_checkpoint + self.use_scale_shift_norm = res.use_scale_shift_norm + + self.in_layers = res.in_layers + + self.updown = res.updown + + self.h_upd = res.h_upd + self.x_upd = res.x_upd + + self.emb_layers = res.emb_layers + self.out_layers = res.out_layers + + self.skip_connection = res.skip_connection + + def forward(self, x, emb=None, split=0): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + if split != 0 and self.skip_connection.split == 0: + return checkpoint( + self._forward, (x, emb, split), self.parameters(), self.use_checkpoint + ) + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb, split=0): + # print(f"x shape {x.shape} emb shape {emb.shape}") + if emb is None: + assert(len(x) == 2) + x, emb = x + assert x.shape[2] == x.shape[3] + + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + if split != 0: + return self.skip_connection(x, split=split) + h + return self.skip_connection(x) + h + + +class WapperQKMatMul(BaseWrapperBlock): + def __init__( + self, act_quant_params: dict = {}): + super().__init__(act_quant_params) + self.scale = None + self.use_act_quant = False + self.act_quantizer_q = UniformAffineQuantizer(**act_quant_params) + self.act_quantizer_k = UniformAffineQuantizer(**act_quant_params) + + def forward(self, q, k): + if self.use_act_quant: + quant_q = self.act_quantizer_q(q * self.scale) + quant_k = self.act_quantizer_k(k * self.scale) + weight = th.einsum( + "bct,bcs->bts", quant_q, quant_k + ) + else: + weight = th.einsum( + "bct,bcs->bts", q * self.scale, k * self.scale + ) + return weight + + def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): + self.use_act_quant = act_quant + + +class WapperSMVMatMul(BaseWrapperBlock): + def __init__( + self, act_quant_params: dict = {}, sm_abit=8): + super().__init__(act_quant_params) + self.use_act_quant = False + self.act_quantizer_v = UniformAffineQuantizer(**act_quant_params) + act_quant_params_w = act_quant_params.copy() + act_quant_params_w['n_bits'] = sm_abit + act_quant_params_w['symmetric'] = False + act_quant_params_w['always_zero'] = True + self.act_quantizer_w = UniformAffineQuantizer(**act_quant_params_w) + + def forward(self, weight, v): + if self.use_act_quant: + a = th.einsum("bts,bcs->bct", self.act_quantizer_w(weight), self.act_quantizer_v(v)) + else: + a = th.einsum("bts,bcs->bct", weight, v) + return a + + def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): + self.use_act_quant = act_quant + + +class WapperAttentionBlock(BaseWrapperBlock): + def __init__( + self, attn: AttentionBlock, act_quant_params: dict = {}): + super().__init__(act_quant_params) + self.channels = attn.channels + self.num_heads = attn.num_heads + self.use_checkpoint = attn.use_checkpoint + self.norm = attn.norm + self.qkv = attn.qkv + + self.attention = attn.attention + + self.proj_out = attn.proj_out + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! 
+ #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def cross_attn_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + if self.use_act_quant: + quant_q = self.act_quantizer_q(q) + quant_k = self.act_quantizer_k(k) + sim = einsum('b i d, b j d -> b i j', quant_q, quant_k) * self.scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -th.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + if self.use_act_quant: + out = einsum('b i j, b j d -> b i d', self.act_quantizer_w(attn), self.act_quantizer_v(v)) + else: + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class WapperBasicTransformerBlock(BaseWrapperBlock): + def __init__( + self, tran: BasicTransformerBlock, act_quant_params: dict = {}, + sm_abit: int = 8): + super().__init__(act_quant_params) + self.attn1 = tran.attn1 + self.ff = tran.ff + self.attn2 = tran.attn2 + + self.norm1 = tran.norm1 + self.norm2 = tran.norm2 + self.norm3 = tran.norm3 + self.checkpoint = tran.checkpoint + # self.checkpoint = False + + # logger.info(f"quant attn matmul") + self.attn1.act_quantizer_q = UniformAffineQuantizer(**act_quant_params) + self.attn1.act_quantizer_k = UniformAffineQuantizer(**act_quant_params) + self.attn1.act_quantizer_v = UniformAffineQuantizer(**act_quant_params) + + self.attn2.act_quantizer_q = UniformAffineQuantizer(**act_quant_params) + self.attn2.act_quantizer_k = UniformAffineQuantizer(**act_quant_params) + self.attn2.act_quantizer_v = UniformAffineQuantizer(**act_quant_params) + + act_quant_params_w = act_quant_params.copy() + act_quant_params_w['n_bits'] = sm_abit + act_quant_params_w['always_zero'] = True + self.attn1.act_quantizer_w = UniformAffineQuantizer(**act_quant_params_w) + self.attn2.act_quantizer_w = UniformAffineQuantizer(**act_quant_params_w) + + self.attn1.forward = MethodType(cross_attn_forward, self.attn1) + self.attn2.forward = MethodType(cross_attn_forward, self.attn2) + self.attn1.use_act_quant = False + self.attn2.use_act_quant = False + + def forward(self, x, context=None): + # print(f"x shape {x.shape} context shape {context.shape}") + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + if context is None: + assert(len(x) == 2) + x, context = x + + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): + self.attn1.use_act_quant = act_quant + self.attn2.use_act_quant = act_quant + + # setting weight quantization here does not affect actual forward pass + self.use_weight_quant = weight_quant + self.use_act_quant = act_quant + for m in self.modules(): + if isinstance(m, QuantModule): + m.set_quant_state(weight_quant, act_quant) + + +# the two classes below are for 
DDIM CIFAR +class WapperResnetBlock(BaseWrapperBlock): + def __init__( + self, res: ResnetBlock, act_quant_params: dict = {}): + super().__init__(act_quant_params) + self.in_channels = res.in_channels + self.out_channels = res.out_channels + self.use_conv_shortcut = res.use_conv_shortcut + + self.norm1 = res.norm1 + self.conv1 = res.conv1 + self.temb_proj = res.temb_proj + self.norm2 = res.norm2 + self.dropout = res.dropout + self.conv2 = res.conv2 + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = res.conv_shortcut + else: + self.nin_shortcut = res.nin_shortcut + + + def forward(self, x, temb=None, split=0): + if temb is None: + assert(len(x) == 2) + x, temb = x + + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x, split=split) + out = x + h + return out + + +class WapperAttnBlock(BaseWrapperBlock): + def __init__( + self, attn: AttnBlock, act_quant_params: dict = {}, sm_abit=8): + super().__init__(act_quant_params) + self.in_channels = attn.in_channels + + self.norm = attn.norm + self.q = attn.q + self.k = attn.k + self.v = attn.v + self.proj_out = attn.proj_out + + self.act_quantizer_q = UniformAffineQuantizer(**act_quant_params) + self.act_quantizer_k = UniformAffineQuantizer(**act_quant_params) + self.act_quantizer_v = UniformAffineQuantizer(**act_quant_params) + + act_quant_params_w = act_quant_params.copy() + act_quant_params_w['n_bits'] = sm_abit + self.act_quantizer_w = UniformAffineQuantizer(**act_quant_params_w) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h*w) # b,c,hw + if self.use_act_quant: + q = self.act_quantizer_q(q) + k = self.act_quantizer_k(k) + w_ = th.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + if self.use_act_quant: + v = self.act_quantizer_v(v) + w_ = self.act_quantizer_w(w_) + h_ = th.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + out = x + h_ + return out + + +def get_specials(quant_act=False): + specials = { + ResBlock: WapperResBlock, + BasicTransformerBlock: WapperBasicTransformerBlock, + ResnetBlock: WapperResnetBlock, + AttnBlock: WapperAttnBlock, + } + if quant_act: + specials[QKMatMul] = WapperQKMatMul + specials[SMVMatMul] = WapperSMVMatMul + else: + specials[AttentionBlock] = WapperAttentionBlock + return specials \ No newline at end of file diff --git a/auto_round_diff/wrapper_layer.py b/auto_round_diff/wrapper_layer.py new file mode 100644 index 00000000..9f62bdf7 --- /dev/null +++ b/auto_round_diff/wrapper_layer.py @@ -0,0 +1,530 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import copy +import torch +from torch.functional import F +import transformers +from auto_round_diff.data_type import get_quant_func, get_static_quant_func +from .utils import ( + check_to_quantized, + get_scale_shape, + set_module, + logger +) +import logging +logger = logging.getLogger("autoround") + +def reshape_and_pad_tensor(v, group_size=-1): + """Reshapes the tensor based on the group size. + + Args: + v (torch.Tensor): The input tensor to be reshaped. + group_size (int, optional): The number of elements to group together. + + Returns: + torch.Tensor: The reshaped tensor. If padding is applied, the padded tensor is returned. + """ + if group_size == -1 or v.shape[1] < group_size: + return v + if v.shape[1] % group_size == 0: + v = v.reshape(-1, group_size) + else: + pad_len = (v.shape[1] + group_size - 1) // group_size * group_size - v.shape[1] + v = torch.nn.functional.pad(v, (0, pad_len)) + v = v.reshape(-1, group_size) + return v + + +class WrapperLinear(torch.nn.Module): + """A wrapper for linear/conv1d layers to enable quantization and tuning. + + This module wraps an existing linear or conv1d layer and provides additional functionality + for quantization, parameter tuning, and activation/bias normalization. + + Args: + orig_layer (torch.nn.Module): The original layer to be wrapped (linear or conv1d). + enable_minmax_tuning (bool): Whether to enable min-max scale tuning. + enable_norm_bias_tuning (bool): Whether to enable normalization and tuning of the bias term. + device (str): Device on which to run computations (e.g., 'cpu' or 'cuda'). + """ + + def __init__(self, orig_layer, keys, device='cpu', **kwargs): + """Initializes the WrapperLinear module. + + Args: + orig_layer (torch.nn.Module): The original layer to wrap. + enable_minmax_tuning (bool): Whether to enable min-max scale tuning. + enable_norm_bias_tuning (bool): Whether to enable normalization and tuning for the bias term. + device (str): The computation device, such as 'cpu' or 'cuda'. 
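+            keys (Iterable[str]): names of quantization settings expected as attributes on
+                ``orig_layer``; entries containing a ``w_``/``weight_`` or ``a_``/``act_`` marker are
+                copied into the wrapper's weight and activation quantization parameter dicts.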
+ """ + super(WrapperLinear, self).__init__() + self.keys = keys + self.orig_layer = orig_layer + self.split = 0 + self.output_device = device + self.enable_act_quant = self.orig_layer.quant_act + self.q_scale_thresh = 1e-5 + self._init_quant_params_and_func() + self.use_weight_quant = False + self.use_act_quant = False + + if isinstance(self.orig_layer, torch.nn.Conv2d): + self.orig_kwargs = dict(stride=self.orig_layer.stride, padding=self.orig_layer.padding, + dilation=self.orig_layer.dilation, groups=self.orig_layer.groups) + self.orig_forward = F.conv2d + elif isinstance(self.orig_layer, torch.nn.Conv1d): + self.orig_kwargs = dict(stride=self.orig_layer.stride, padding=self.orig_layer.padding, + dilation=self.orig_layer.dilation, groups=self.orig_layer.groups) + self.orig_forward = F.conv1d + else: + self.orig_kwargs = dict() + self.orig_forward = F.linear + + def _init_quant_params_and_func(self): + self.w_quant_params, self.act_quant_params = {"inited": False, "scale": None, "zp": None, }, {"inited": False, "scale": None, "zp": None} + orig_layer = self.orig_layer + for key in self.keys: + if hasattr(self.orig_layer, key): + if bool(re.search(r'w_|weight_', key)): + new_key = re.sub(r'w_|weight_', '', key) + self.w_quant_params[new_key] = getattr(orig_layer, key) + elif bool(re.search(r'a_|act_', key)): + new_key = re.sub(r'a_|act_', '', key) + self.act_quant_params[new_key] = getattr(orig_layer, key) + + orig_weight = getattr(orig_layer, "get_weight", lambda: orig_layer.weight)() + if isinstance(self.orig_layer, torch.nn.Conv1d): + orig_weight = orig_weight.t() + + self.weight_quant_func, self.data_type_w = get_static_quant_func(orig_layer.data_type_w, orig_layer.weight_bits, + orig_layer.sym_w) + + if self.enable_act_quant: + self.act_quant_func, self.data_type_act = get_static_quant_func(orig_layer.data_type_act, + orig_layer.act_bits, + orig_layer.sym_act) + + def _qdq_weight(self): + """Quantizes and dequantizes weights with tuning parameters. + + Args: + value (torch.Tensor): Value added for rounding for tuning. + min_scale (torch.Tensor): Minimum scale for the min value of quantization. + max_scale (torch.Tensor): Maximum scale for the max value of quantization. + + Returns: + tuple: Quantized weight, scale, and zero point. + """ + weight = self.orig_layer.weight + if isinstance(self.orig_layer, torch.nn.Conv1d): + weight = weight.t() + + if self.split != 0: + weight_0, self.w_quant_params["scale"], self.w_quant_params["zp"], self.w_quant_params["inited"] = self.weight_quant_func(weight[:, :self.split, ...], **self.w_quant_params) + weight_1, self.w_quant_params_0["scale"], self.w_quant_params_0["zp"], self.w_quant_params_0["inited"] = self.weight_quant_func(weight[:, self.split:, ...], **self.w_quant_params_0) + weight_q = torch.cat([weight_0, weight_1], dim=1) + else: + weight_q, self.w_quant_params["scale"], self.w_quant_params["zp"], self.w_quant_params["inited"] = self.weight_quant_func(weight, **self.w_quant_params) + + weight_q = weight_q.to(weight.dtype) + + if isinstance(self.orig_layer, torch.nn.Conv1d): + weight_q = weight_q.t() + return weight_q + + def _qdq_act(self, x): + """Quantizes and dequantizes activations. + + Args: + x (torch.Tensor): Input activations. + act_max_scale (torch.Tensor): Maximum scale for the act_max + act_max (torch.Tensor, optional): Maximum value for activation quantization. Defaults to None. + + Returns: + tuple: Quantized activation, scale, and zero point. 
+ """ + if self.split != 0: + x, scale, zp = self.act_quant_func(x[:, :self.split, :, :], **self.act_quant_params) + x_0, scale_0, zp_0 = self.act_quant_func(x[:, self.split:, :, :], **self.act_quant_params_0) + x = torch.cat([x, x_0], dim=1) + + self.act_quant_params['scale'], self.act_quant_params['zp'] = scale, zp + self.act_quant_params0['zp'], self.act_quant_params0['zp'] = scale_0, zp_0 + else: + x, scale, zp = self.act_quant_func(x, **self.act_quant_params) + self.act_quant_params['scale'], self.act_quant_params['zp'] = zp + + return x, scale, zp, x_0, scale_0, zp_0 + + def _qdq_bias(self, bias, bias_v): + """Quantizes and dequantizes bias. + + Args: + bias (torch.Tensor): Bias tensor to be quantized. + bias_v (torch.Tensor): Value added for rounding for tuning. + + Returns: + tuple: Quantized bias, scale, and zero point. + """ + bias_bits = 4 ## hard code + bias_group_size = -1 + bias, scale, zp = self.bias_quant_func(bias, bits=bias_bits, group_size=bias_group_size, v=bias_v, + q_scale_thresh=self.q_scale_thresh) + return bias, scale, zp + + def unwrapper(self, best_params): + """Restores the original layer by applying the best tuning parameters. + + Args: + best_params (dict): Dictionary containing the best tuning parameters. + + Returns: + torch.nn.Module: The unwrapped and restored original layer. + """ + best_params = best_params or {} + v = best_params.get('value', torch.tensor(0.0)).to(self.device) + min_scale = best_params.get('min_scale', torch.tensor(1.0)).to(self.device) + max_scale = best_params.get('max_scale', torch.tensor(1.0)).to(self.device) + + if self.orig_layer.weight.device.type == 'meta': + self.orig_layer.to(self.device) + ##unwrapper weight + qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale) + + self.orig_layer.weight.data.copy_(qdq_weight) + self.orig_layer.weight.grad = None + + shape = qdq_weight.shape + if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D): + shape = qdq_weight.t().shape + + def _set_dict_attr(attr_dict, attr_name): + for key in attr_dict.keys(): + if key == attr_name: + setattr(self.orig_layer, attr_name, attr_dict[key].reshape(shape[0], -1).to("cpu")) + else: + name = "w_" + key + setattr(self.orig_layer, name, attr_dict[key].to("cpu")) + + if isinstance(scale, dict): + _set_dict_attr(scale, "scale") + else: + self.orig_layer.scale = scale.reshape(shape[0], -1).to("cpu") + + if zp is not None: + if isinstance(zp, dict): + _set_dict_attr(zp, "zp") + else: + zp = zp.reshape(shape[0], -1) + self.orig_layer.zp = zp.to("cpu") if zp is not None else None + else: + self.orig_layer.zp = None + + ##unwrapper bias + if self.enable_norm_bias_tuning and "bias_v" in best_params.keys(): ##fake quant + bias_v = best_params["bias_v"].to(self.device) + bias = self.orig_layer.bias + if bias is not None and bias.device.type == 'meta': + bias = self.orig_layer.get_bias().to(self.device) + bias, _, _ = self._qdq_bias(bias, bias_v) + self.orig_layer.bias.grad = None + self.orig_layer.bias.data.copy_(bias) + + if hasattr(self.orig_layer, 'update'): + self.orig_layer.update() + self.orig_layer.to('meta') + + ##unwrapper act + if self.enable_act_quant: + if not self.orig_layer.act_dynamic: + act_max_scale = best_params.get('act_max_scale', torch.tensor(1.0)).to(self.device) + act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None + tmp_shape = (1) + if self.orig_layer.act_group_size > 1: + tmp_shape = (1, self.orig_layer.act_group_size) + _, act_scale, _ = 
self._qdq_act(torch.zeros(tmp_shape).to(self.device), + act_max_scale=self.act_max_scale, act_max=act_max) + self.orig_layer.act_max = torch.tensor(self.orig_layer.act_max * act_max_scale.item()).to("cpu") + self.orig_layer.act_scale = act_scale.to("cpu") + + self.orig_layer.q_scale_thresh = self.q_scale_thresh + self.orig_layer.data_type = self.data_type + + self.orig_layer.act_data_type = self.act_data_type + self.orig_layer.act_quant_func = self.act_quant_func + wrapper_layer = WrapperWALayer(self.orig_layer) + return wrapper_layer + + return self.orig_layer + + def set_split(self): + self.w_quant_params_0 = copy.deepcopy(self.w_quant_params) + self.act_quant_params_0 = copy.deepcopy(self.act_quant_params) + + def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): + self.use_weight_quant = weight_quant + self.use_act_quant = act_quant + + def set_running_stat(self, running_stat_a: bool): + self.act_quant_params["running_stat_a"] = running_stat_a + if self.split != 0: + self.act_quant_params_0["running_stat_a"] = running_stat_a + + def forward(self, x, split=0): + """Executes the forward pass with quantized weights and optional bias/activation quantization. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the wrapped layer. + """ + x = x.to(self.output_device) + + if split != 0 and self.split != 0: + assert(split == self.split) + elif split != 0: + logger.info(f"split at {split}!") + self.split = split + self.set_split() + + if self.enable_act_quant and self.use_act_quant: + x, _, _, _, _, _ = self._qdq_act(x) + if self.use_weight_quant: + weight = self._qdq_weight() + else: + weight = self.org_weight + output = self.orig_forward(x, weight, self.orig_layer.bias, **self.orig_kwargs).to(self.output_device) + + return output + + +class WrapperWALayer(torch.nn.Module): + def __init__(self, orig_layer): + super(WrapperWALayer, self).__init__() + self.orig_layer = orig_layer + self.act_quant_func = self.orig_layer.act_quant_func + + def forward(self, x): + act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None + x, _, _ = self.orig_layer.act_quant_func(x, bits=self.orig_layer.act_bits, + group_size=self.orig_layer.group_size, + scale_dtype=self.orig_layer.scale_dtype, + q_scale_thresh=self.orig_layer.q_scale_thresh, + data_type=self.orig_layer.act_data_type, + tensor_max=act_max) + return self.orig_layer.forward(x) + + +class WrapperLayerNorm(torch.nn.Module): + """A wrapper for layer normalization with quantized weights. + + This class wraps a given layer normalization module and applies quantization without round + to its weights. The quantization is parameterized by the number of bits and + an optional group size. 
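+
+    Example (illustrative sketch; assumes the ``auto_round`` data-type helpers are importable):
+        >>> ln = torch.nn.LayerNorm(768)
+        >>> wrapped = WrapperLayerNorm(ln, bit=4, group_size=-1, device="cpu")
+        >>> out = wrapped(torch.randn(2, 16, 768))  # weight fake-quantized, tunable via ``wrapped.v``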
+ """ + + def __init__(self, orig_layer, bit=4, group_size=-1, device="cpu"): + super(WrapperLayerNorm, self).__init__() + self.orig_layer = orig_layer + self.bits = bit + self.group_size = group_size + self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device + self.output_device = device + weight_dtype = torch.float32 + self.q_scale_thresh = 1e-5 + self.v = torch.nn.Parameter( + reshape_and_pad_tensor( + torch.zeros(self.orig_layer.weight.shape, device=self.device, dtype=weight_dtype), + self.group_size), + requires_grad=True) + self.params = {"v": self.v} + from auto_round.data_type.int import quant_tensor_asym_wo_round + self.quant_func = quant_tensor_asym_wo_round + + def unwrapper(self, best_params): + if best_params is None: + return self.orig_layer + v = best_params['v'] + weight_q, _, _ = self.quant_func(self.orig_layer.weight, self.bits, self.group_size, + v, q_scale_thresh=self.q_scale_thresh) + self.orig_layer.q_scale_thresh = self.q_scale_thresh + self.orig_layer.weight.data.copy_(weight_q) + return self.orig_layer + + def forward(self, input): + input = input.to(self.device) + weight_q, _, _ = self.quant_func(self.orig_layer.weight, self.bits, self.group_size, + self.v, q_scale_thresh=self.q_scale_thresh) + import torch.nn.functional as F + return F.layer_norm( + input, self.orig_layer.normalized_shape, weight_q, self.orig_layer.bias, self.orig_layer.eps).to( + self.output_device) + + +class WrapperLlamaNorm(torch.nn.Module): + """A wrapper for Llama normalization in HF with fake quantized weights without rounding. + + This class wraps a given layer normalization module and applies quantization without rounding + to its weights. The quantization is parameterized by the number of bits and + an optional group size. 
+ """ + + def __init__(self, orig_layer, bit=4, group_size=-1, device="cpu"): + super(WrapperLlamaNorm, self).__init__() + self.orig_layer = orig_layer + self.bits = bit + self.group_size = group_size + self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device + self.output_device = device + weight_dtype = torch.float32 + self.q_scale_thresh = 1e-5 + self.v = torch.nn.Parameter( + reshape_and_pad_tensor( + torch.zeros(self.orig_layer.weight.shape, device=self.device, dtype=weight_dtype), + self.group_size), + requires_grad=True) + self.params = {"v": self.v} + from auto_round.data_type.int import quant_tensor_asym_wo_round + self.quant_func = quant_tensor_asym_wo_round + + def unwrapper(self, best_params): + if best_params is None: + return self.orig_layer + v = best_params['v'] + weight_q, _, _ = self.quant_func(self.orig_layer.weight, self.bits, self.group_size, + v, q_scale_thresh=self.q_scale_thresh) + self.orig_layer.q_scale_thresh = self.q_scale_thresh + self.orig_layer.weight.data.copy_(weight_q) + return self.orig_layer + + def forward(self, hidden_states): + hidden_states = hidden_states.to(self.device) + weight_q, _, _ = self.quant_func(self.orig_layer.weight, self.bits, self.group_size, + self.v, q_scale_thresh=self.q_scale_thresh) + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.orig_layer.variance_epsilon) + return (weight_q * hidden_states.to(input_dtype)).to(self.output_device) + + +norm_mapping = {} +norm_mapping["LayerNorm"] = WrapperLayerNorm +norm_mapping["LlamaRMSNorm"] = WrapperLlamaNorm +norm_mapping["Qwen2RMSNorm"] = WrapperLlamaNorm +norm_mapping["Phi3RMSNorm"] = WrapperLlamaNorm +norm_mapping["MistralRMSNorm"] = WrapperLlamaNorm + + +class WrapperMultiblock(torch.nn.Module): + """A wrapper for a list of modules to be act as a single block. + + Args: + module_list: The list of modules to wrap. + """ + + def __init__(self, module_list): + super(WrapperMultiblock, self).__init__() + self.layers = torch.nn.ModuleList(module_list) + + def forward(self, x, **kwargs): + hidden_states = x + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer(hidden_states, **kwargs) + hidden_states = layer_outputs + if isinstance(hidden_states, tuple) or isinstance(hidden_states, list): + hidden_states = layer_outputs[0] + return hidden_states + + +def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device='cpu', **kwargs): + """Wraps the layers in the given block with a custom Wrapper module. + + Args: + block: The input block containing linear and conv1d layers to be wrapped. + enable_minmax_tuning: A boolean indicating whether min-max tuning is enabled. + + Returns: + list: A list of names of the wrapped layers and unwrapped layers. 
+ """ + quantized_layers = [] + unquantized_layers = [] + for n, m in block.named_modules(): + if isinstance(m, (torch.nn.Linear, transformers.modeling_utils.Conv1D)): + if not check_to_quantized(m): + unquantized_layers.append(n) + continue + new_m = WrapperLinear(m, enable_minmax_tuning=enable_minmax_tuning, + enable_norm_bias_tuning=enable_norm_bias_tuning, device=device, + **kwargs, + ) + set_module(block, n, new_m) + quantized_layers.append(n) + + if enable_norm_bias_tuning: + if "norm" in m.__class__.__name__.lower(): + if m.__class__.__name__ in norm_mapping.keys(): + wrapper_layer_class = norm_mapping[m.__class__.__name__] + new_m = wrapper_layer_class(m, device=device) + setattr(block, n, new_m) + elif "RMSNorm" in m.__class__.__name__: + logger.warning_once( + f"use LlamaRMSNorm to wrap {m.__class__.__name__}, please check the correctness yourself") + wrapper_layer_class = norm_mapping["LlamaRMSNorm"] + new_m = wrapper_layer_class(m, device=device) + setattr(block, n, new_m) + else: + logger.warning_once(f"{m.__class__.__name__} is not supported") + + return quantized_layers, unquantized_layers + + +@torch.no_grad() +def unwrapper_layer(model, layer, layer_name, best_params): + """Unwraps the WrapperLinear and WrapperTransformerConv1d modules in the given block. + + Args: + block: The input block containing wrapped modules to be unwrapped. + vs: A dictionary of scaling parameters for the wrapped modules. + min_scales: A dictionary of minimum scaling values for the wrapped modules. + max_scales: A dictionary of maximum scaling values for the wrapped modules. + """ + + if hasattr(layer, "orig_layer"): + orig_layer = layer.unwrapper(best_params) + orig_layer = orig_layer.to("cpu") + set_module(model, layer_name, orig_layer) + + +@torch.no_grad() +def unwrapper_block(block, best_params): + """Unwraps the WrapperLinear and WrapperTransformerConv1d modules in the given block. + + Args: + block: The input block containing wrapped modules to be unwrapped. + vs: A dictionary of scaling parameters for the wrapped modules. + min_scales: A dictionary of minimum scaling values for the wrapped modules. + max_scales: A dictionary of maximum scaling values for the wrapped modules. 
+ """ + for n, m in block.named_modules(): + if hasattr(m, "orig_layer"): + if n in best_params.keys(): + best_param = best_params[n] + else: + best_param = None + orig_layer = m.unwrapper(best_param) + set_module(block, n, orig_layer) diff --git a/configs/autoencoder/autoencoder_kl_16x16x16.yaml b/configs/autoencoder/autoencoder_kl_16x16x16.yaml new file mode 100644 index 00000000..5f1d10ec --- /dev/null +++ b/configs/autoencoder/autoencoder_kl_16x16x16.yaml @@ -0,0 +1,54 @@ +model: + base_learning_rate: 4.5e-6 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: "val/rec_loss" + embed_dim: 16 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 0.000001 + disc_weight: 0.5 + + ddconfig: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [16] + dropout: 0.0 + + +data: + target: main.DataModuleFromConfig + params: + batch_size: 12 + wrap: True + train: + target: ldm.data.imagenet.ImageNetSRTrain + params: + size: 256 + degradation: pil_nearest + validation: + target: ldm.data.imagenet.ImageNetSRValidation + params: + size: 256 + degradation: pil_nearest + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 2 diff --git a/configs/autoencoder/autoencoder_kl_32x32x4.yaml b/configs/autoencoder/autoencoder_kl_32x32x4.yaml new file mode 100644 index 00000000..ab8b36fe --- /dev/null +++ b/configs/autoencoder/autoencoder_kl_32x32x4.yaml @@ -0,0 +1,53 @@ +model: + base_learning_rate: 4.5e-6 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: "val/rec_loss" + embed_dim: 4 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 0.000001 + disc_weight: 0.5 + + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 12 + wrap: True + train: + target: ldm.data.imagenet.ImageNetSRTrain + params: + size: 256 + degradation: pil_nearest + validation: + target: ldm.data.imagenet.ImageNetSRValidation + params: + size: 256 + degradation: pil_nearest + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 2 diff --git a/configs/autoencoder/autoencoder_kl_64x64x3.yaml b/configs/autoencoder/autoencoder_kl_64x64x3.yaml new file mode 100644 index 00000000..5e3db5c4 --- /dev/null +++ b/configs/autoencoder/autoencoder_kl_64x64x3.yaml @@ -0,0 +1,54 @@ +model: + base_learning_rate: 4.5e-6 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: "val/rec_loss" + embed_dim: 3 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 0.000001 + disc_weight: 0.5 + + ddconfig: + double_z: True + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + + +data: + target: main.DataModuleFromConfig + params: + batch_size: 12 + wrap: True + train: + target: 
ldm.data.imagenet.ImageNetSRTrain + params: + size: 256 + degradation: pil_nearest + validation: + target: ldm.data.imagenet.ImageNetSRValidation + params: + size: 256 + degradation: pil_nearest + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 2 diff --git a/configs/autoencoder/autoencoder_kl_8x8x64.yaml b/configs/autoencoder/autoencoder_kl_8x8x64.yaml new file mode 100644 index 00000000..5ccd09d3 --- /dev/null +++ b/configs/autoencoder/autoencoder_kl_8x8x64.yaml @@ -0,0 +1,53 @@ +model: + base_learning_rate: 4.5e-6 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: "val/rec_loss" + embed_dim: 64 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 0.000001 + disc_weight: 0.5 + + ddconfig: + double_z: True + z_channels: 64 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [16,8] + dropout: 0.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 12 + wrap: True + train: + target: ldm.data.imagenet.ImageNetSRTrain + params: + size: 256 + degradation: pil_nearest + validation: + target: ldm.data.imagenet.ImageNetSRValidation + params: + size: 256 + degradation: pil_nearest + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + benchmark: True + accumulate_grad_batches: 2 diff --git a/configs/cifar10.yml b/configs/cifar10.yml new file mode 100644 index 00000000..0e48f55f --- /dev/null +++ b/configs/cifar10.yml @@ -0,0 +1,50 @@ +data: + dataset: "CIFAR10" + image_size: 32 + channels: 3 + logit_transform: false + uniform_dequantization: false + gaussian_dequantization: false + random_flip: true + rescaled: true + num_workers: 4 + +model: + type: "simple" + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 2] + num_res_blocks: 2 + attn_resolutions: [16, ] + dropout: 0.1 + var_type: fixedlarge + ema_rate: 0.9999 + ema: True + resamp_with_conv: True + +diffusion: + beta_schedule: linear + beta_start: 0.0001 + beta_end: 0.02 + num_diffusion_timesteps: 1000 + +training: + batch_size: 128 + n_epochs: 10000 + n_iters: 5000000 + snapshot_freq: 5000 + validation_freq: 2000 + +sampling: + batch_size: 64 + last_only: True + +optim: + weight_decay: 0.000 + optimizer: "Adam" + lr: 0.0002 + beta1: 0.9 + amsgrad: false + eps: 0.00000001 + grad_clip: 1.0 diff --git a/configs/latent-diffusion/celebahq-ldm-vq-4.yaml b/configs/latent-diffusion/celebahq-ldm-vq-4.yaml new file mode 100644 index 00000000..89b3df4f --- /dev/null +++ b/configs/latent-diffusion/celebahq-ldm-vq-4.yaml @@ -0,0 +1,86 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + image_size: 64 + channels: 3 + monitor: val/loss_simple_ema + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + # note: this isn\t actually the resolution but + # the downsampling factor, i.e. 
this corresnponds to + # attention on spatial resolution 8,16,32, as the + # spatial reolution of the latents is 64 for f4 + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ckpt_path: models/first_stage_models/vq-f4/model.ckpt + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 48 + num_workers: 5 + wrap: false + train: + target: taming.data.faceshq.CelebAHQTrain + params: + size: 256 + validation: + target: taming.data.faceshq.CelebAHQValidation + params: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/configs/latent-diffusion/cin-ldm-vq-f8.yaml b/configs/latent-diffusion/cin-ldm-vq-f8.yaml new file mode 100644 index 00000000..b8cd9e2e --- /dev/null +++ b/configs/latent-diffusion/cin-ldm-vq-f8.yaml @@ -0,0 +1,98 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 32 + channels: 4 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 256 + attention_resolutions: + #note: this isn\t actually the resolution but + # the downsampling factor, i.e. 
this corresnponds to + # attention on spatial resolution 8,16,32, as the + # spatial reolution of the latents is 32 for f8 + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 512 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 4 + n_embed: 16384 + ckpt_path: configs/first_stage_models/vq-f8/model.yaml + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.ClassEmbedder + params: + embed_dim: 512 + key: class_label +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + num_workers: 12 + wrap: false + train: + target: ldm.data.imagenet.ImageNetTrain + params: + config: + size: 256 + validation: + target: ldm.data.imagenet.ImageNetValidation + params: + config: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/configs/latent-diffusion/cin256-v2.yaml b/configs/latent-diffusion/cin256-v2.yaml new file mode 100644 index 00000000..b7c1aa24 --- /dev/null +++ b/configs/latent-diffusion/cin256-v2.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 0.0001 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 64 + channels: 3 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 192 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 5 + num_heads: 1 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 512 + + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.ClassEmbedder + params: + n_classes: 1001 + embed_dim: 512 + key: class_label diff --git a/configs/latent-diffusion/ffhq-ldm-vq-4.yaml b/configs/latent-diffusion/ffhq-ldm-vq-4.yaml new file mode 100644 index 00000000..1899e30f --- /dev/null +++ b/configs/latent-diffusion/ffhq-ldm-vq-4.yaml @@ -0,0 +1,85 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + image_size: 64 + channels: 3 + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + # note: this isn\t actually the resolution but + # the downsampling factor, i.e. 
this corresnponds to + # attention on spatial resolution 8,16,32, as the + # spatial reolution of the latents is 64 for f4 + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ckpt_path: configs/first_stage_models/vq-f4/model.yaml + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 42 + num_workers: 5 + wrap: false + train: + target: taming.data.faceshq.FFHQTrain + params: + size: 256 + validation: + target: taming.data.faceshq.FFHQValidation + params: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml b/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml new file mode 100644 index 00000000..c4ca66c1 --- /dev/null +++ b/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml @@ -0,0 +1,85 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + image_size: 64 + channels: 3 + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + # note: this isn\t actually the resolution but + # the downsampling factor, i.e. 
this corresnponds to + # attention on spatial resolution 8,16,32, as the + # spatial reolution of the latents is 64 for f4 + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + ckpt_path: configs/first_stage_models/vq-f4/model.yaml + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 48 + num_workers: 5 + wrap: false + train: + target: ldm.data.lsun.LSUNBedroomsTrain + params: + size: 256 + validation: + target: ldm.data.lsun.LSUNBedroomsValidation + params: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml b/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml new file mode 100644 index 00000000..18dc8c2d --- /dev/null +++ b/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml @@ -0,0 +1,91 @@ +model: + base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0155 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + loss_type: l1 + first_stage_key: "image" + cond_stage_key: "image" + image_size: 32 + channels: 4 + cond_stage_trainable: False + concat_mode: False + scale_by_std: True + monitor: 'val/loss_simple_ema' + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [10000] + cycle_lengths: [10000000000000] + f_start: [1.e-6] + f_max: [1.] + f_min: [ 1.] 
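+      # note: with ldm.lr_scheduler.LambdaLinearScheduler the optimizer lr is multiplied by a
+      # per-step factor that ramps linearly from f_start to f_max over warm_up_steps and then
+      # decays linearly toward f_min over the rest of cycle_lengths; with f_max = f_min = 1.
+      # as set here, the factor simply stays at 1.0 once warm-up is finished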
+ + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 192 + attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 + num_res_blocks: 2 + channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 + num_heads: 8 + use_scale_shift_norm: True + resblock_updown: True + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: "val/rec_loss" + ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" + ddconfig: + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: "__is_unconditional__" + +data: + target: main.DataModuleFromConfig + params: + batch_size: 96 + num_workers: 5 + wrap: False + train: + target: ldm.data.lsun.LSUNChurchesTrain + params: + size: 256 + validation: + target: ldm.data.lsun.LSUNChurchesValidation + params: + size: 256 + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + + trainer: + benchmark: True \ No newline at end of file diff --git a/configs/latent-diffusion/txt2img-1p4B-eval.yaml b/configs/latent-diffusion/txt2img-1p4B-eval.yaml new file mode 100644 index 00000000..8e331cbf --- /dev/null +++ b/configs/latent-diffusion/txt2img-1p4B-eval.yaml @@ -0,0 +1,71 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 32 + channels: 4 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + num_heads: 8 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 1280 + use_checkpoint: true + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.BERTEmbedder + params: + n_embed: 1280 + n_layer: 32 diff --git a/configs/retrieval-augmented-diffusion/768x768.yaml b/configs/retrieval-augmented-diffusion/768x768.yaml new file mode 100644 index 00000000..b51b1d83 --- /dev/null +++ b/configs/retrieval-augmented-diffusion/768x768.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 0.0001 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.015 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: jpg + cond_stage_key: nix + image_size: 48 + channels: 16 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_by_std: false + scale_factor: 0.22765929 + unet_config: + target: 
ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 48 + in_channels: 16 + out_channels: 16 + model_channels: 448 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + use_scale_shift_norm: false + resblock_updown: false + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 768 + use_checkpoint: true + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: val/rec_loss + embed_dim: 16 + ddconfig: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 16 + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: torch.nn.Identity \ No newline at end of file diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml new file mode 100644 index 00000000..d4effe56 --- /dev/null +++ b/configs/stable-diffusion/v1-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/configs/stable-diffusion/v1-inpainting-inference.yaml b/configs/stable-diffusion/v1-inpainting-inference.yaml new file mode 100644 index 00000000..f9eec37d --- /dev/null +++ b/configs/stable-diffusion/v1-inpainting-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 7.5e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid # important + monitor: val/loss_simple_ema + scale_factor: 0.18215 + finetune_keys: null + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 2500 ] # NOTE for resuming. 
use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 9 # 4 data + 4 downscaled image + 1 mask + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/fp8_res_ada.pth b/fp8_res_ada.pth new file mode 100644 index 00000000..0b4bf4a6 Binary files /dev/null and b/fp8_res_ada.pth differ diff --git a/ldm/__init__.py b/ldm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/globalvar.py b/ldm/globalvar.py new file mode 100644 index 00000000..6b6d2d67 --- /dev/null +++ b/ldm/globalvar.py @@ -0,0 +1,19 @@ +# collect diffusion input for calibration +global diffusion_input_list +diffusion_input_list = [] + +def appendInput(value): + diffusion_input_list.append(value) + +def getInputList(): + return diffusion_input_list + + +## collect quantization error for correction +global data_error_t_list +data_error_t_list = [] ## storing model out['mean'] and quantization error and its step t in format (data, error, t) +def append(value): + data_error_t_list.append(value) + +def getList(): + return data_error_t_list \ No newline at end of file diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py new file mode 100644 index 00000000..be39da9c --- /dev/null +++ b/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. 
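+    each list entry i defines one cycle: the multiplier ramps linearly from f_start[i] to
+    f_max[i] over warm_up_steps[i], then follows a half-cosine from f_max[i] down to f_min[i]
+    over the remaining cycle_lengths[i] steps (see schedule below)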
+ """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/ldm/models/__init__.py b/ldm/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py new file mode 100644 index 00000000..6a9c4f45 --- /dev/null +++ b/ldm/models/autoencoder.py @@ -0,0 +1,443 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + 
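# 1x1 convs bridge z_channels and the codebook embedding dim: quant_conv before quantization,
+        # post_quant_conv (below) maps quantized codes back to z_channels for the decoder
+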
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, 
prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
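+        # rescale the randomly projected map to [-1, 1] so it matches the image value range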
+ return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, 
reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py new file mode 100644 index 00000000..67e98b9d --- /dev/null +++ b/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest 
config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if 
self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': 
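+            # render the human-readable class names as an image so they can be logged alongside the inputs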
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py new file mode 100644 index 00000000..475626a9 --- /dev/null +++ b/ldm/models/diffusion/ddim.py @@ -0,0 +1,257 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img.to('cpu')], 'pred_x0': [img.to('cpu')], 'ts': []} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img.to('cpu')) + intermediates['pred_x0'].append(pred_x0.to('cpu')) + intermediates['ts'].append(ts.to('cpu')) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py new file mode 100644 index 00000000..a668ed6b --- /dev/null +++ b/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1527 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only +from omegaconf import ListConfig + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not 
change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. 
* 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None, restore=True): + # print("==================test ema==================") + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema and restore: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + else: + print(f"{context}: Keep using EMA weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) 
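+ # Note: the closed-form DDPM posterior q(x_{t-1} | x_t, x_0) computed below combines the predicted x_0 (x_recon) with x_t via the posterior_mean_coef1/2 buffers registered in register_schedule().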
+ + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError(f"unknown loss type '{self.loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x =
x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + 
conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = 
torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = 
self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. 
reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is expected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left positions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo can't deal with multiple model outputs check this never happens + + o = torch.stack(output_list, 
axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + # import ipdb; ipdb.set_trace() + # import ldm.globalvar as globalvar + # globalvar.appendInput((x_noisy, t, cond['c_crossattn'][0])) + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
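+ # The optional branch below snaps the predicted x_0 onto the first-stage codebook before forming the posterior; it assumes the first stage exposes a .quantize() method, i.e. a VQ autoencoder rather than a KL autoencoder.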
+ if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + 
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + # todo: get 
null label from cond_stage_model + raise NotImplementedError() + c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) + return c + + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] 
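+ # Mask convention used by p_sample_loop: 1 keeps the (re-noised) original latent, 0 marks the region to be generated, so only the zeroed center square is filled in for the inpainting/outpainting logs below.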
+ with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
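+ # to_rgb projects an arbitrary number of channels to 3 via a fixed random 1x1 convolution (self.colorize) and rescales to [-1, 1]; it is only used when logging map-like conditionings (ismap(xc)) in log_images.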
+ return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs + + +class LatentInpaintDiffusion(LatentDiffusion): + def __init__( + self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + finetune_keys=None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + self.concat_keys = concat_keys + + + @torch.no_grad() + def get_input( + self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False + ): + # note: restricted to non-trainable encoders currently + assert ( + not self.cond_stage_trainable + ), "trainable cond stages not yet supported for inpainting" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + cc = ( + rearrange(batch[ck], "b h w c -> b c h w") + .to(memory_format=torch.contiguous_format) + .float() + ) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py 
new file mode 100644 index 00000000..7427f38c --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 00000000..8c0bb484 --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1184 @@ +import torch +import torch.nn.functional as F +import math + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. 
+ + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. 
* (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 
0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. 
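# --- Editor's illustrative sketch (not part of this patch) ---------------------------
# Two conversions performed by this wrapper, shown standalone with hypothetical tensors:
# (1) turning a "v"-parameterized output into a noise prediction (as in noise_pred_fn),
# and (2) the classifier-free guidance mix applied when guidance_scale != 1.
import torch

x = torch.randn(2, 4, 8, 8)                       # hypothetical latent batch
alpha_t, sigma_t = torch.tensor(0.8), torch.tensor(0.6)

v_out = torch.randn_like(x)                       # hypothetical "v" model output
eps_from_v = alpha_t * v_out + sigma_t * x        # eps = alpha_t * v + sigma_t * x

eps_uncond, eps_cond = torch.randn_like(x), torch.randn_like(x)
guidance_scale = 7.5                              # hypothetical scale
eps_cfg = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
# -------------------------------------------------------------------------------------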
+ """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) 
+ t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. 
If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. 
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. 
/ r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. 
The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. 
+ model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. 
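# --- Editor's illustrative sketch (not part of this patch) ---------------------------
# The controller below pairs a lower-order and a higher-order update and treats their
# disagreement as a local error estimate E; a step is accepted when E <= 1 and the
# logSNR step h is rescaled by roughly theta * E^(-1/order). A minimal scalar version
# of that accept/rescale rule (all values hypothetical; the real loop also caps h at
# lambda_0 - lambda_s and uses an RMS norm over each sample):
import torch

def _accept_and_rescale(x_lower, x_higher, x_prev, h, order, atol=0.0078, rtol=0.05, theta=0.9):
    delta = torch.maximum(torch.full_like(x_lower, atol),
                          rtol * torch.maximum(x_lower.abs(), x_prev.abs()))
    E = ((x_higher - x_lower) / delta).abs().max()
    return bool(E <= 1.), theta * h * float(torch.float_power(E, -1. / order))

accepted, h_new = _accept_and_rescale(torch.tensor(0.51), torch.tensor(0.50),
                                      torch.tensor(0.60), h=0.05, order=3)
# -------------------------------------------------------------------------------------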
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). 
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advice for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. 
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. 
+ if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order,] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. 
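# --- Editor's illustrative sketch (not part of this patch) ---------------------------
# What the indexing trick below does: append trailing singleton axes so a per-sample
# scalar of shape [N] broadcasts against an image batch of shape [N, C, H, W].
import torch

v_demo = torch.tensor([0.5, 2.0])                     # hypothetical per-sample coefficients, shape [2]
x_demo = torch.randn(2, 3, 8, 8)
v4 = v_demo[(...,) + (None,) * (x_demo.dim() - 1)]    # shape [2, 1, 1, 1], same trick as below
assert v4.shape == (2, 1, 1, 1)
scaled = v4 * x_demo                                  # broadcasts over C, H, W
# -------------------------------------------------------------------------------------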
+ """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 00000000..2c42d6f9 --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,82 @@ +"""SAMPLING ONLY.""" + +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type="noise", + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py new file mode 100644 index 00000000..8e764784 --- /dev/null +++ b/ldm/models/diffusion/plms.py @@ -0,0 +1,240 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for 
PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img.to('cpu')], 'pred_x0': [img.to('cpu')], 'ts': [], 'cond': [], 'uncond': []} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img.to('cpu')) + intermediates['pred_x0'].append(pred_x0.to('cpu')) + intermediates['ts'].append(ts.to('cpu')) + intermediates['cond'].append(cond.to('cpu')) + intermediates['uncond'].append(unconditional_conditioning.to('cpu')) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + # import ipdb; ipdb.set_trace() + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3rd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/ldm/modules/__init__.py b/ldm/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py new file mode 100644 index 00000000..0c8003f4 --- /dev/null +++ b/ldm/modules/attention.py @@ -0,0 +1,287 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.qk_matmul = CrossQKMatMul(self.scale) + self.smv_matmul = CrossSMVMatMul() + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + # sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = self.qk_matmul(q, k) + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + # out = einsum('b i j, b j d -> b i d', attn, v) + out = self.smv_matmul(attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class CrossQKMatMul(nn.Module): + + def __init__(self, scale): + super().__init__() + self.scale = scale + + def forward(self, q, k): + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + return sim + + +class CrossSMVMatMul(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, attn, v): + out = einsum('b i j, b j d -> b i d', attn, v) + return out + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in \ No newline at end of file diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py new file mode 100644 index 00000000..ba1637a6 --- /dev/null +++ b/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,838 @@ +# pytorch_diffusion + derived encoder decoder +import logging +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + +logger = logging.getLogger(__name__) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
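+    Concretely, for half_dim = embedding_dim // 2 the code below builds
+    frequencies w_k = exp(-k * ln(10000) / (half_dim - 1)), k = 0..half_dim-1,
+    and returns cat([sin(t * w), cos(t * w)], dim=1) for each timestep t,
+    zero-padding one extra column when embedding_dim is odd.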
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = 
torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + logger.info(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in 
range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle 
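+        # bottleneck at the lowest resolution: ResnetBlock -> attention -> ResnetBlock,
+        # mirroring the middle section of Model above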
+ self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + logger.info("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = 
None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = 
nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + logger.info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x 
+ else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 00000000..3a5c3987 --- /dev/null +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1001 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + 
self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, split=0): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb, split=split) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. 
+ :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb, split=0): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb, split), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb, split=0): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + if split>0: + return self.skip_connection(x,split=split) + h + else: + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKMatMul(nn.Module): + + def __init__(self): + super().__init__() + self.scale = None + + def forward(self, q, k): + weight = th.einsum( + "bct,bcs->bts", q * self.scale, k * self.scale + ) + return weight + + +class SMVMatMul(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, weight, v): + a = th.einsum("bts,bcs->bct", weight, v) + return a + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + self.qkv_matmul = QKMatMul() + self.smv_matmul = SMVMatMul() + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
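+        Note that the 1/sqrt(sqrt(ch)) scale is applied to both q and k before
+        the matmul, which is more stable in float16 than dividing afterwards.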
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + # weight = th.einsum( + # "bct,bcs->bts", q * scale, k * scale + # ) # More stable with f16 than dividing afterwards + if self.qkv_matmul.scale is not None: + assert self.qkv_matmul.scale == scale + else: + self.qkv_matmul.scale = scale + weight = self.qkv_matmul(q, k) + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + # a = th.einsum("bts,bcs->bct", weight, v) + a = self.smv_matmul(weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + self.split = False + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + 
self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. 
+ """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + # import ipdb; ipdb.set_trace() + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + if self.split: + split = h.shape[1] + else: + split = 0 + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, split=split) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + 
AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. 
+ """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py new file mode 100644 index 00000000..a952e6c4 --- /dev/null +++ b/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,267 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is 
{eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
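    :param repeat_only: if True, skip the sinusoidal embedding and simply repeat the raw timestep values across all dim channels.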
+ """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. 
+ """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py new file mode 100644 index 00000000..f2b8ef90 --- /dev/null +++ b/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. 
Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py new file mode 100644 index 00000000..7d49592c --- /dev/null +++ b/ldm/modules/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + # print(key) + # try: + # print(self.m_name2s_name[key]) + # except: + # print(self.m_name2s_name.keys()) + shadow_key = key.replace('.model.','.') if '.model.' in key else key + m_param[key].data.copy_(shadow_params[self.m_name2s_name[shadow_key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. 
+ """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py new file mode 100644 index 00000000..567c0703 --- /dev/null +++ b/ldm/modules/encoders/modules.py @@ -0,0 +1,234 @@ +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel +import kornia + +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)"""
+    def __init__(self, device="cuda", vq_interface=True, max_length=77):
+        super().__init__()
+        from transformers import BertTokenizerFast  # TODO: add to requirements
+        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+        self.device = device
+        self.vq_interface = vq_interface
+        self.max_length = max_length
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        return tokens
+
+    @torch.no_grad()
+    def encode(self, text):
+        tokens = self(text)
+        if not self.vq_interface:
+            return tokens
+        return None, None, [None, None, tokens]
+
+    def decode(self, text):
+        return text
+
+
+class BERTEmbedder(AbstractEncoder):
+    """Uses the BERT tokenizer model and adds some transformer encoder layers"""
+    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
+                 device="cuda", use_tokenizer=True, embedding_dropout=0.0):
+        super().__init__()
+        self.use_tknz_fn = use_tokenizer
+        if self.use_tknz_fn:
+            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
+        self.device = device
+        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
+                                              emb_dropout=embedding_dropout)
+
+    def forward(self, text):
+        if self.use_tknz_fn:
+            tokens = self.tknz_fn(text)  # .to(self.device)
+        else:
+            tokens = text
+        z = self.transformer(tokens, return_embeddings=True)
+        return z
+
+    def encode(self, text):
+        # output of length 77
+        return self(text)
+
+
+class SpatialRescaler(nn.Module):
+    def __init__(self,
+                 n_stages=1,
+                 method='bilinear',
+                 multiplier=0.5,
+                 in_channels=3,
+                 out_channels=None,
+                 bias=False):
+        super().__init__()
+        self.n_stages = n_stages
+        assert self.n_stages >= 0
+        assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
+        self.multiplier = multiplier
+        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+        self.remap_output = out_channels is not None
+        if self.remap_output:
+            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
+            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
+
+    def forward(self, x):
+        for stage in range(self.n_stages):
+            x = self.interpolator(x, scale_factor=self.multiplier)
+
+        if self.remap_output:
+            x = self.channel_mapper(x)
+        return x
+
+    def encode(self, x):
+        return self(x)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+        super().__init__()
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length
+        self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(input_ids=tokens)
+
+        z = outputs.last_hidden_state
+        return z
+
+    def encode(self, text):
+        return self(text)
+
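# A minimal usage sketch for FrozenCLIPEmbedder (illustrative only, assuming the
# pretrained "openai/clip-vit-large-patch14" weights can be fetched from the
# Hugging Face hub): the embedder maps a list of prompts to a
# [batch, max_length, hidden] tensor, e.g. [1, 77, 768] for ViT-L/14, which the
# UNet consumes as cross-attention context.
#
#     embedder = FrozenCLIPEmbedder(device="cpu")
#     ctx = embedder.encode(["a photograph of an astronaut riding a horse"])
#     print(ctx.shape)  # torch.Size([1, 77, 768])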
+ +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + return self.model.encode_image(self.preprocess(x)) + + +if __name__ == "__main__": + from ldm.util import count_params + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) \ No newline at end of file diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py new file mode 100644 index 00000000..5fc15bf9 --- /dev/null +++ b/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + 
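        # Input projection: a plain Linear followed by GELU, or a GEGLU gate
        # (two linear branches, one gating the other through GELU) when glu=True.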
project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = 
einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - 
len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits 
= nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/ldm/util.py b/ldm/util.py new file mode 100644 index 00000000..c3ce1ac5 --- /dev/null +++ b/ldm/util.py @@ -0,0 +1,206 @@ +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial +import logging + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + +logger = logging.getLogger(__name__) + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Can't encode string for logging. 
Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
+ ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/models/first_stage_models/kl-f16/config.yaml b/models/first_stage_models/kl-f16/config.yaml new file mode 100644 index 00000000..661921cf --- /dev/null +++ b/models/first_stage_models/kl-f16/config.yaml @@ -0,0 +1,44 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: val/rec_loss + embed_dim: 16 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 1.0e-06 + disc_weight: 0.5 + ddconfig: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 16 + dropout: 0.0 +data: + target: main.DataModuleFromConfig + params: + batch_size: 6 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/models/first_stage_models/kl-f32/config.yaml b/models/first_stage_models/kl-f32/config.yaml new file mode 100644 index 00000000..7b642b13 --- /dev/null +++ b/models/first_stage_models/kl-f32/config.yaml @@ -0,0 +1,46 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: val/rec_loss + embed_dim: 64 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 1.0e-06 + disc_weight: 0.5 + ddconfig: + double_z: true + z_channels: 64 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 16 + - 8 + dropout: 0.0 +data: + target: main.DataModuleFromConfig + params: + batch_size: 6 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git 
a/models/first_stage_models/kl-f4/config.yaml b/models/first_stage_models/kl-f4/config.yaml new file mode 100644 index 00000000..85cfb3e9 --- /dev/null +++ b/models/first_stage_models/kl-f4/config.yaml @@ -0,0 +1,41 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: val/rec_loss + embed_dim: 3 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 1.0e-06 + disc_weight: 0.5 + ddconfig: + double_z: true + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 +data: + target: main.DataModuleFromConfig + params: + batch_size: 10 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/models/first_stage_models/kl-f8/config.yaml b/models/first_stage_models/kl-f8/config.yaml new file mode 100644 index 00000000..921aa425 --- /dev/null +++ b/models/first_stage_models/kl-f8/config.yaml @@ -0,0 +1,42 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.AutoencoderKL + params: + monitor: val/rec_loss + embed_dim: 4 + lossconfig: + target: ldm.modules.losses.LPIPSWithDiscriminator + params: + disc_start: 50001 + kl_weight: 1.0e-06 + disc_weight: 0.5 + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 +data: + target: main.DataModuleFromConfig + params: + batch_size: 4 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/models/first_stage_models/vq-f16/config.yaml b/models/first_stage_models/vq-f16/config.yaml new file mode 100644 index 00000000..91c74549 --- /dev/null +++ b/models/first_stage_models/vq-f16/config.yaml @@ -0,0 +1,49 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.VQModel + params: + embed_dim: 8 + n_embed: 16384 + ddconfig: + double_z: false + z_channels: 8 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 16 + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_start: 250001 + disc_weight: 0.75 + disc_num_layers: 2 + codebook_weight: 1.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 14 + num_workers: 20 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/models/first_stage_models/vq-f4-noattn/config.yaml b/models/first_stage_models/vq-f4-noattn/config.yaml new file mode 100644 index 00000000..f8e499fa --- /dev/null +++ b/models/first_stage_models/vq-f4-noattn/config.yaml @@ -0,0 +1,46 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.VQModel + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + + ddconfig: + attn_type: none + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 
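      # ch: 128 with ch_mult [1, 2, 4] gives two downsampling steps in the encoder,
      # i.e. the f=4 latent resolution this "vq-f4-noattn" autoencoder is named after.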
+ ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_start: 11 + disc_weight: 0.75 + codebook_weight: 1.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 8 + num_workers: 12 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + crop_size: 256 diff --git a/models/first_stage_models/vq-f4/config.yaml b/models/first_stage_models/vq-f4/config.yaml new file mode 100644 index 00000000..7d8cef32 --- /dev/null +++ b/models/first_stage_models/vq-f4/config.yaml @@ -0,0 +1,45 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.VQModel + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_start: 0 + disc_weight: 0.75 + codebook_weight: 1.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 8 + num_workers: 16 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + crop_size: 256 diff --git a/models/first_stage_models/vq-f8-n256/config.yaml b/models/first_stage_models/vq-f8-n256/config.yaml new file mode 100644 index 00000000..8519e13d --- /dev/null +++ b/models/first_stage_models/vq-f8-n256/config.yaml @@ -0,0 +1,48 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.VQModel + params: + embed_dim: 4 + n_embed: 256 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_start: 250001 + disc_weight: 0.75 + codebook_weight: 1.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 10 + num_workers: 20 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/models/first_stage_models/vq-f8/config.yaml b/models/first_stage_models/vq-f8/config.yaml new file mode 100644 index 00000000..efd6801c --- /dev/null +++ b/models/first_stage_models/vq-f8/config.yaml @@ -0,0 +1,48 @@ +model: + base_learning_rate: 4.5e-06 + target: ldm.models.autoencoder.VQModel + params: + embed_dim: 4 + n_embed: 16384 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: false + disc_in_channels: 3 + disc_num_layers: 2 + disc_start: 1 + disc_weight: 0.6 + codebook_weight: 1.0 +data: + target: 
main.DataModuleFromConfig + params: + batch_size: 10 + num_workers: 20 + wrap: true + train: + target: ldm.data.openimages.FullOpenImagesTrain + params: + size: 384 + crop_size: 256 + validation: + target: ldm.data.openimages.FullOpenImagesValidation + params: + size: 384 + crop_size: 256 diff --git a/models/ldm/bsr_sr/config.yaml b/models/ldm/bsr_sr/config.yaml new file mode 100644 index 00000000..861692a8 --- /dev/null +++ b/models/ldm/bsr_sr/config.yaml @@ -0,0 +1,80 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0155 + log_every_t: 100 + timesteps: 1000 + loss_type: l2 + first_stage_key: image + cond_stage_key: LR_image + image_size: 64 + channels: 3 + concat_mode: true + cond_stage_trainable: false + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 6 + out_channels: 3 + model_channels: 160 + attention_resolutions: + - 16 + - 8 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 2 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: torch.nn.Identity +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + wrap: false + num_workers: 12 + train: + target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain + params: + size: 256 + degradation: bsrgan_light + downscale_f: 4 + min_crop_f: 0.5 + max_crop_f: 1.0 + random_crop: true + validation: + target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation + params: + size: 256 + degradation: bsrgan_light + downscale_f: 4 + min_crop_f: 0.5 + max_crop_f: 1.0 + random_crop: true diff --git a/models/ldm/celeba256/config.yaml b/models/ldm/celeba256/config.yaml new file mode 100644 index 00000000..a12f4e9d --- /dev/null +++ b/models/ldm/celeba256/config.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 64 + channels: 3 + cond_stage_trainable: false + concat_mode: false + monitor: val/loss + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 48 + num_workers: 5 + wrap: false + train: + target: ldm.data.faceshq.CelebAHQTrain + params: + size: 256 + validation: + target: ldm.data.faceshq.CelebAHQValidation + params: + size: 256 diff --git a/models/ldm/cin256/config.yaml 
b/models/ldm/cin256/config.yaml new file mode 100644 index 00000000..9bc1b456 --- /dev/null +++ b/models/ldm/cin256/config.yaml @@ -0,0 +1,80 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 32 + channels: 4 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 256 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 512 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 4 + n_embed: 16384 + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.ClassEmbedder + params: + embed_dim: 512 + key: class_label +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + num_workers: 12 + wrap: false + train: + target: ldm.data.imagenet.ImageNetTrain + params: + config: + size: 256 + validation: + target: ldm.data.imagenet.ImageNetValidation + params: + config: + size: 256 diff --git a/models/ldm/ffhq256/config.yaml b/models/ldm/ffhq256/config.yaml new file mode 100644 index 00000000..0ddfd1b9 --- /dev/null +++ b/models/ldm/ffhq256/config.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 64 + channels: 3 + cond_stage_trainable: false + concat_mode: false + monitor: val/loss + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 42 + num_workers: 5 + wrap: false + train: + target: ldm.data.faceshq.FFHQTrain + params: + size: 256 + validation: + target: ldm.data.faceshq.FFHQValidation + params: + size: 256 diff --git a/models/ldm/inpainting_big/config.yaml b/models/ldm/inpainting_big/config.yaml new file mode 100644 index 00000000..da5fd5ea --- /dev/null +++ b/models/ldm/inpainting_big/config.yaml @@ -0,0 +1,67 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0205 + log_every_t: 100 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + 
cond_stage_key: masked_image + image_size: 64 + channels: 3 + concat_mode: true + monitor: val/loss + scheduler_config: + target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler + params: + verbosity_interval: 0 + warm_up_steps: 1000 + max_decay_steps: 50000 + lr_start: 0.001 + lr_max: 0.1 + lr_min: 0.0001 + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 7 + out_channels: 3 + model_channels: 256 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_heads: 8 + resblock_updown: true + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + ddconfig: + attn_type: none + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: ldm.modules.losses.contperceptual.DummyLoss + cond_stage_config: __is_first_stage__ diff --git a/models/ldm/layout2img-openimages256/config.yaml b/models/ldm/layout2img-openimages256/config.yaml new file mode 100644 index 00000000..9e1dc15f --- /dev/null +++ b/models/ldm/layout2img-openimages256/config.yaml @@ -0,0 +1,81 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0205 + log_every_t: 100 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + cond_stage_key: coordinates_bbox + image_size: 64 + channels: 3 + conditioning_key: crossattn + cond_stage_trainable: true + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 128 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 3 + context_dim: 512 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.BERTEmbedder + params: + n_embed: 512 + n_layer: 16 + vocab_size: 8192 + max_seq_len: 92 + use_tokenizer: false + monitor: val/loss_simple_ema +data: + target: main.DataModuleFromConfig + params: + batch_size: 24 + wrap: false + num_workers: 10 + train: + target: ldm.data.openimages.OpenImagesBBoxTrain + params: + size: 256 + validation: + target: ldm.data.openimages.OpenImagesBBoxValidation + params: + size: 256 diff --git a/models/ldm/lsun_beds256/config.yaml b/models/ldm/lsun_beds256/config.yaml new file mode 100644 index 00000000..1a50c766 --- /dev/null +++ b/models/ldm/lsun_beds256/config.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 64 + channels: 3 + cond_stage_trainable: false + concat_mode: false + monitor: val/loss + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + 
out_channels: 3 + model_channels: 224 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 48 + num_workers: 5 + wrap: false + train: + target: ldm.data.lsun.LSUNBedroomsTrain + params: + size: 256 + validation: + target: ldm.data.lsun.LSUNBedroomsValidation + params: + size: 256 diff --git a/models/ldm/lsun_churches256/config.yaml b/models/ldm/lsun_churches256/config.yaml new file mode 100644 index 00000000..424d0914 --- /dev/null +++ b/models/ldm/lsun_churches256/config.yaml @@ -0,0 +1,92 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0155 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + cond_stage_key: image + image_size: 32 + channels: 4 + cond_stage_trainable: false + concat_mode: false + scale_by_std: true + monitor: val/loss_simple_ema + scheduler_config: + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: + - 10000 + cycle_lengths: + - 10000000000000 + f_start: + - 1.0e-06 + f_max: + - 1.0 + f_min: + - 1.0 + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 192 + attention_resolutions: + - 1 + - 2 + - 4 + - 8 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 2 + - 4 + - 4 + num_heads: 8 + use_scale_shift_norm: true + resblock_updown: true + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: '__is_unconditional__' + +data: + target: main.DataModuleFromConfig + params: + batch_size: 96 + num_workers: 5 + wrap: false + train: + target: ldm.data.lsun.LSUNChurchesTrain + params: + size: 256 + validation: + target: ldm.data.lsun.LSUNChurchesValidation + params: + size: 256 diff --git a/models/ldm/semantic_synthesis256/config.yaml b/models/ldm/semantic_synthesis256/config.yaml new file mode 100644 index 00000000..1a721cff --- /dev/null +++ b/models/ldm/semantic_synthesis256/config.yaml @@ -0,0 +1,59 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0205 + log_every_t: 100 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + cond_stage_key: segmentation + image_size: 64 + channels: 3 + concat_mode: true + cond_stage_trainable: true + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 6 + out_channels: 3 + model_channels: 128 + attention_resolutions: + - 32 + - 16 + - 8 + num_res_blocks: 2 + channel_mult: + - 1 + - 4 + - 8 + num_heads: 8 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: 
+ embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.SpatialRescaler + params: + n_stages: 2 + in_channels: 182 + out_channels: 3 diff --git a/models/ldm/semantic_synthesis512/config.yaml b/models/ldm/semantic_synthesis512/config.yaml new file mode 100644 index 00000000..8faded2e --- /dev/null +++ b/models/ldm/semantic_synthesis512/config.yaml @@ -0,0 +1,78 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0205 + log_every_t: 100 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + cond_stage_key: segmentation + image_size: 128 + channels: 3 + concat_mode: true + cond_stage_trainable: true + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 128 + in_channels: 6 + out_channels: 3 + model_channels: 128 + attention_resolutions: + - 32 + - 16 + - 8 + num_res_blocks: 2 + channel_mult: + - 1 + - 4 + - 8 + num_heads: 8 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + monitor: val/rec_loss + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.SpatialRescaler + params: + n_stages: 2 + in_channels: 182 + out_channels: 3 +data: + target: main.DataModuleFromConfig + params: + batch_size: 8 + wrap: false + num_workers: 10 + train: + target: ldm.data.landscapes.RFWTrain + params: + size: 768 + crop_size: 512 + segmentation_to_float32: true + validation: + target: ldm.data.landscapes.RFWValidation + params: + size: 768 + crop_size: 512 + segmentation_to_float32: true diff --git a/models/ldm/text2img256/config.yaml b/models/ldm/text2img256/config.yaml new file mode 100644 index 00000000..3f54a015 --- /dev/null +++ b/models/ldm/text2img256/config.yaml @@ -0,0 +1,77 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 64 + channels: 3 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 192 + attention_resolutions: + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 5 + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 640 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.BERTEmbedder + params: + n_embed: 640 + n_layer: 32 +data: + target: main.DataModuleFromConfig + params: + 
batch_size: 28 + num_workers: 10 + wrap: false + train: + target: ldm.data.previews.pytorch_dataset.PreviewsTrain + params: + size: 256 + validation: + target: ldm.data.previews.pytorch_dataset.PreviewsValidation + params: + size: 256
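Note on usage (not part of the diff itself): the configs added above follow the upstream CompVis latent-diffusion layout, i.e. a model.target class path plus model.params, and an optional data section consumed by main.DataModuleFromConfig. As a minimal sketch of how one of these files is typically loaded, assuming the upstream ldm package and omegaconf are importable and that a matching checkpoint has been downloaded next to the config (the model.ckpt path below is an illustrative assumption, not a file introduced by this change):

# Minimal sketch: build a model from one of the configs added in this patch.
# Assumes the upstream CompVis `ldm` package is on the path; paths are assumptions.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config_path = "models/ldm/text2img256/config.yaml"   # any config from this patch
ckpt_path = "models/ldm/text2img256/model.ckpt"      # assumed checkpoint location

config = OmegaConf.load(config_path)

# instantiate_from_config resolves model.target
# (e.g. ldm.models.diffusion.ddpm.LatentDiffusion) and passes model.params as kwargs.
model = instantiate_from_config(config.model)

# Model-zoo checkpoints typically store weights under the "state_dict" key.
state = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state["state_dict"], strict=False)
model.eval()

The same pattern applies to the first-stage autoencoder configs (kl-f*/vq-f*): only the target class and params differ, so a single loader covers every file added here.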