diff --git a/bins/sgmse/inference.py b/bins/sgmse/inference.py
new file mode 100644
index 00000000..d3bdf2ad
--- /dev/null
+++ b/bins/sgmse/inference.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+
+import torch
+
+from models.sgmse.dereverberation.dereverberation_inference import (
+    DereverberationInference,
+)
+from utils.util import load_config
+
+
+def build_inference(args, cfg):
+    supported_inference = {
+        "dereverberation": DereverberationInference,
+    }
+
+    inference_class = supported_inference[cfg.model_type]
+    inference = inference_class(args, cfg)
+    return inference
+
+
+def build_parser():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=True,
+        help="JSON/YAML file for configurations.",
+    )
+    parser.add_argument(
+        "--checkpoint_path",
+        type=str,
+    )
+    parser.add_argument(
+        "--test_dir",
+        type=str,
+        required=True,
+        help="Directory containing the test data (must have subdirectory noisy/)",
+    )
+    parser.add_argument(
+        "--corrector_steps", type=int, default=1, help="Number of corrector steps"
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default=None,
+        help="Output dir for saving generated results",
+    )
+    parser.add_argument(
+        "--snr",
+        type=float,
+        default=0.33,
+        help="SNR value for (annealed) Langevin dynamics.",
+    )
+    parser.add_argument("--N", type=int, default=50, help="Number of reverse steps")
+    parser.add_argument("--local_rank", default=0, type=int)
+    return parser
+
+
+def main():
+    # Parse arguments
+    args = build_parser().parse_args()
+
+    # Parse config
+    cfg = load_config(args.config)
+    if torch.cuda.is_available():
+        args.local_rank = torch.device("cuda")
+    else:
+        args.local_rank = torch.device("cpu")
+    print("args: ", args)
+
+    # Build inference
+    inferencer = build_inference(args, cfg)
+
+    # Run inference
+    inferencer.inference()
+
+
+if __name__ == "__main__":
+    main()
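For context, the entry point above is normally driven through `egs/sgmse/dereverberation/run.sh`, but it can also be exercised directly. A minimal sketch (the checkpoint and directory paths are placeholders, not files shipped with the repo):

```python
# Programmatic use of the inference entry point defined above.
import torch

from bins.sgmse.inference import build_parser, build_inference
from utils.util import load_config

args = build_parser().parse_args(
    [
        "--config", "egs/sgmse/dereverberation/exp_config.json",
        "--checkpoint_path", "ckpts/sgmse/checkpoint.pt",  # placeholder path
        "--test_dir", "data/wsj0reverb/test",              # placeholder path
        "--output_dir", "result/dereverb",
    ]
)
# Mirror main(): local_rank doubles as the target device here.
args.local_rank = torch.device("cuda" if torch.cuda.is_available() else "cpu")
build_inference(args, load_config(args.config)).inference()
```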
diff --git a/bins/sgmse/preprocess.py b/bins/sgmse/preprocess.py
new file mode 100644
index 00000000..22e1e860
--- /dev/null
+++ b/bins/sgmse/preprocess.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import faulthandler
+
+faulthandler.enable()
+import os
+import argparse
+from multiprocessing import cpu_count
+
+from utils.util import load_config
+from preprocessors.processor import preprocess_dataset
+
+
+def preprocess(cfg):
+    """Preprocess raw data of single or multiple datasets (in cfg.dataset)
+
+    Args:
+        cfg (dict): dictionary that stores configurations
+    """
+    # Specify the output root path to save the processed data
+    output_path = cfg.preprocess.processed_dir
+    os.makedirs(output_path, exist_ok=True)
+
+    # Split train and test sets
+    for dataset in cfg.dataset:
+        print("Preprocess {}...".format(dataset))
+
+        preprocess_dataset(
+            dataset,
+            cfg.dataset_path[dataset],
+            output_path,
+            cfg.preprocess,
+            cfg.task_type,
+            is_custom_dataset=dataset in cfg.use_custom_dataset,
+        )
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config", default="config.json", help="JSON file for configurations."
+    )
+    parser.add_argument("--num_workers", type=int, default=int(cpu_count()))
+    args = parser.parse_args()
+    cfg = load_config(args.config)
+    preprocess(cfg)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/bins/sgmse/train_sgmse.py b/bins/sgmse/train_sgmse.py
new file mode 100644
index 00000000..11a7a004
--- /dev/null
+++ b/bins/sgmse/train_sgmse.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+
+import torch
+
+from models.sgmse.dereverberation.dereverberation_Trainer import DereverberationTrainer
+from utils.util import load_config
+
+
+def build_trainer(args, cfg):
+    supported_trainer = {
+        "dereverberation": DereverberationTrainer,
+    }
+
+    trainer_class = supported_trainer[cfg.model_type]
+    trainer = trainer_class(args, cfg)
+    return trainer
+
+
+def cuda_relevant(deterministic=False):
+    torch.cuda.empty_cache()
+    # TF32 on Ampere and above
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.allow_tf32 = True
+    # Deterministic
+    torch.backends.cudnn.deterministic = deterministic
+    torch.backends.cudnn.benchmark = not deterministic
+    torch.use_deterministic_algorithms(deterministic)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config",
+        help="JSON file for configurations.",
+        required=True,
+    )
+    parser.add_argument(
+        "--num_workers", type=int, default=4, help="Number of dataloader workers."
+    )
+    parser.add_argument(
+        "--exp_name",
+        type=str,
+        help="A specific name to note the experiment",
+        required=True,
+    )
+    parser.add_argument(
+        "--log_level", default="warning", help="logging level (debug, info, warning)"
+    )
+    parser.add_argument("--stdout_interval", default=5, type=int)
+    parser.add_argument("--local_rank", default=0, type=int)
+    args = parser.parse_args()
+    cfg = load_config(args.config)
+    cfg.exp_name = args.exp_name
+    args.log_dir = os.path.join(cfg.log_dir, args.exp_name)
+    os.makedirs(args.log_dir, exist_ok=True)
+    # Data Augmentation
+    if cfg.preprocess.data_augment:
+        new_datasets_list = []
+        for dataset in cfg.preprocess.data_augment:
+            new_datasets = [
+                # f"{dataset}_pitch_shift",
+                # f"{dataset}_formant_shift",
+                f"{dataset}_equalizer",
+                f"{dataset}_time_stretch",
+            ]
+            new_datasets_list.extend(new_datasets)
+        cfg.dataset.extend(new_datasets_list)
+
+    # CUDA settings
+    cuda_relevant()
+
+    # Build trainer
+    trainer = build_trainer(args, cfg)
+
+    trainer.train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/config/sgmse.json b/config/sgmse.json
new file mode 100644
index 00000000..000eb3ab
--- /dev/null
+++ b/config/sgmse.json
@@ -0,0 +1,42 @@
+{
+    "base_config": "config/base.json",
+    "dataset": [
+        "wsj0reverb"
+    ],
+    "task_type": "sgmse",
+    "preprocess": {
+        "dummy": false,
+        "num_frames": 256,
+        "normalize": "noisy",
+        "hop_length": 128,
+        "n_fft": 510,
+        "spec_abs_exponent": 0.5,
+        "spec_factor": 0.15,
+        "use_spkid": false,
+        "use_uv": false,
+        "use_frame_pitch": false,
+        "use_phone_pitch": false,
+        "use_frame_energy": false,
+        "use_phone_energy": false,
+        "use_mel": false,
+        "use_audio": false,
+        "use_label": false,
+        "use_one_hot": false
+    },
+    "model": {
+        "sgmse": {
+            "backbone": "ncsnpp",
+            "sde": "ouve",
+            "gpus": 1
+        }
+    },
+    "train": {
+        "batch_size": 8,
+        "lr": 1e-4,
+        "ema_decay": 0.999,
+        "t_eps": 3e-2,
+        "num_eval_files": 20
+    }
+}
\ No newline at end of file
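The `base_config` key chains configurations: `egs/sgmse/dereverberation/exp_config.json` inherits from `config/sgmse.json`, which inherits from `config/base.json`. A minimal sketch of the depth-first merge this implies — an assumption about how `utils.util.load_config` behaves, not a verified copy of it (real Amphion configs also allow `//` comments, which the real loader must strip before parsing):

```python
import json


def merge_dicts(base: dict, override: dict) -> dict:
    # Depth-first merge: values from `override` win; nested dicts are merged.
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge_dicts(out[key], value)
        else:
            out[key] = value
    return out


def load_config_sketch(path: str) -> dict:
    # Hypothetical stand-in for utils.util.load_config.
    with open(path) as f:
        cfg = json.load(f)
    base_path = cfg.pop("base_config", None)
    return merge_dicts(load_config_sketch(base_path), cfg) if base_path else cfg
```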
diff --git a/egs/sgmse/README.md b/egs/sgmse/README.md
new file mode 100644
index 00000000..4765f4a1
--- /dev/null
+++ b/egs/sgmse/README.md
@@ -0,0 +1,98 @@
+# Amphion Speech Enhancement and Dereverberation with Diffusion-based Generative Models Recipe
+
+![Diffusion process for speech dereverberation](../../imgs/sgmse/diffusion_process.png)
+
+This recipe contains the PyTorch implementation of the following 2023 paper, adapted from [sgmse](https://github.com/sp-uhh/sgmse):
+- Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, Timo Gerkmann. [*"Speech Enhancement and Dereverberation with Diffusion-Based Generative Models"*](https://ieeexplore.ieee.org/abstract/document/10149431), IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023.
+
+You can use any sgmse architecture with any dataset you want. There are three steps in total:
+
+1. Data preparation
+2. Training
+3. Inference
+
+> **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
+> ```bash
+> cd Amphion
+> ```
+
+## 1. Data Preparation
+
+You can train the model with any datasets. Amphion's supported open-source datasets are detailed [here](../../datasets/README.md).
+
+### Configuration
+
+Specify the dataset path in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.
+
+```json
+"dataset": [
+    "wsj0reverb"
+],
+"dataset_path": {
+    // TODO: Fill in your dataset path
+    "wsj0reverb": ""
+},
+"preprocess": {
+    "processed_dir": "",
+    "sample_rate": 16000
+},
+```
+
+## 2. Training
+
+### Configuration
+
+We provide the default hyperparameters in `exp_config.json`. They work on a single NVIDIA 24 GB GPU. You can adjust them based on your GPU machines.
+
+```json
+"train": {
+    // TODO: Fill in your checkpoint path
+    "checkpoint": "",
+    "adam": {
+        "lr": 1e-4
+    },
+    "ddp": false,
+    "batch_size": 8,
+    "epochs": 200000,
+    "save_checkpoints_steps": 800,
+    "save_summary_steps": 1000,
+    "max_steps": 1000000,
+    "ema_decay": 0.999,
+    "valid_interval": 800,
+    "t_eps": 3e-2,
+    "num_eval_files": 20
+}
+```
+
+### Run
+
+Run `run.sh` with the training stage (set `--stage 2`):
+
+```bash
+sh egs/sgmse/dereverberation/run.sh --stage 2
+```
+
+> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, e.g., `--gpu "0,1,2,3"`.
+
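+For reference, each training step draws a random diffusion time, perturbs the clean spectrogram along the SDE, and regresses the score network against the injected noise. Below is a condensed sketch of the `_step`/`_loss` logic from `dereverberation_Trainer.py` (illustrative, not a drop-in module):
+
+```python
+import torch
+
+
+def score_matching_step(model, batch, t_eps=3e-2):
+    # batch["X"]/batch["Y"]: clean / noisy complex spectrograms (B, 1, F, T)
+    x, y = batch["X"], batch["Y"]
+    # t ~ U(t_eps, T): avoid t=0, where the perturbation variance vanishes
+    t = torch.rand(x.shape[0], device=x.device) * (model.sde.T - t_eps) + t_eps
+    mean, std = model.sde.marginal_prob(x, t, y)
+    z = torch.randn_like(x)
+    sigmas = std[:, None, None, None]
+    perturbed = mean + sigmas * z
+    score = model(perturbed, t, y)
+    # Denoising score matching: drive err = score * sigma + z to zero
+    err = score * sigmas + z
+    return torch.mean(0.5 * torch.sum(err.abs() ** 2, dim=(1, 2, 3)))
+```
+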
+## 3. Inference
+
+### Run
+
+Run `run.sh` with the inference stage (set `--stage 3`):
+
+```bash
+sh egs/sgmse/dereverberation/run.sh --stage 3 \
+    --checkpoint_path [your path] \
+    --test_dir [your path] \
+    --output_dir [your path]
+```
diff --git a/egs/sgmse/dereverberation/exp_config.json b/egs/sgmse/dereverberation/exp_config.json
new file mode 100644
index 00000000..cf5c4872
--- /dev/null
+++ b/egs/sgmse/dereverberation/exp_config.json
@@ -0,0 +1,69 @@
+{
+    "base_config": "config/sgmse.json",
+    "model_type": "dereverberation",
+    "dataset": [
+        "wsj0reverb"
+    ],
+    "dataset_path": {
+        // TODO: Fill in your dataset path
+        "wsj0reverb": ""
+    },
+    "log_dir": "",
+    "preprocess": {
+        "processed_dir": "",
+        "sample_rate": 16000
+    },
+    "model": {
+        "sgmse": {
+            "backbone": "ncsnpp",
+            "sde": "ouve",
+            "ncsnpp": {
+                "scale_by_sigma": true,
+                "nonlinearity": "swish",
+                "nf": 128,
+                "ch_mult": [1, 1, 2, 2, 2, 2, 2],
+                "num_res_blocks": 2,
+                "attn_resolutions": [16],
+                "resamp_with_conv": true,
+                "conditional": true,
+                "fir": true,
+                "fir_kernel": [1, 3, 3, 1],
+                "skip_rescale": true,
+                "resblock_type": "biggan",
+                "progressive": "output_skip",
+                "progressive_input": "input_skip",
+                "progressive_combine": "sum",
+                "init_scale": 0.0,
+                "fourier_scale": 16,
+                "image_size": 256,
+                "embedding_type": "fourier",
+                "dropout": 0.0,
+                "centered": true
+            },
+            "ouve": {
+                "theta": 1.5,
+                "sigma_min": 0.05,
+                "sigma_max": 0.5,
+                "N": 1000
+            },
+            "gpus": 1
+        }
+    },
+    "train": {
+        "checkpoint": "",
+        "adam": {
+            "lr": 1e-4
+        },
+        "ddp": false,
+        "batch_size": 8,
+        "epochs": 200000,
+        "save_checkpoints_steps": 800,
+        "save_summary_steps": 1000,
+        "max_steps": 1000000,
+        "ema_decay": 0.999,
+        "valid_interval": 800,
+        "t_eps": 3e-2,
+        "num_eval_files": 20
+    }
+}
diff --git a/egs/sgmse/dereverberation/run.sh b/egs/sgmse/dereverberation/run.sh
new file mode 100644
index 00000000..b1aa00f4
--- /dev/null
+++ b/egs/sgmse/dereverberation/run.sh
@@ -0,0 +1,97 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+######## Build Experiment Environment ###########
+exp_dir=$(cd `dirname $0`; pwd)
+work_dir=$(dirname $(dirname $(dirname $exp_dir)))
+
+export WORK_DIR=$work_dir
+export PYTHONPATH=$work_dir
+export PYTHONIOENCODING=UTF-8
+export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:100"
+
+######## Parse the Given Parameters from the Command ###########
+options=$(getopt -o c:n:s: --long gpu:,config:,name:,stage:,checkpoint:,resume_type:,main_process_port:,infer_mode:,infer_datasets:,infer_feature_dir:,infer_audio_dir:,infer_expt_dir:,infer_output_dir:,checkpoint_path:,test_dir:,output_dir: -- "$@")
+eval set -- "$options"
+export CUDA_VISIBLE_DEVICES="0"
+while true; do
+    case $1 in
+        # Experimental Configuration File
+        -c | --config) shift; exp_config=$1 ; shift ;;
+        # Experimental Name
+        -n | --name) shift; exp_name=$1 ; shift ;;
+        # Running Stage
+        -s | --stage) shift; running_stage=$1 ; shift ;;
+        # Visible GPU machines. The default value is "0".
+        --gpu) shift; gpu=$1 ; shift ;;
+        --checkpoint_path) shift; checkpoint_path=$1 ; shift ;;
+        --test_dir) shift; test_dir=$1 ; shift ;;
+        --output_dir) shift; output_dir=$1 ; shift ;;
+        # [Only for Training] The specific checkpoint path that you want to resume from.
+        --checkpoint) shift; checkpoint=$1 ; shift ;;
+        # [Only for Training] `main_process_port` for multi-GPU training
+        --main_process_port) shift; main_process_port=$1 ; shift ;;
+
+        # [Only for Inference] The inference mode
+        --infer_mode) shift; infer_mode=$1 ; shift ;;
+        # [Only for Inference] The inferenced datasets
+        --infer_datasets) shift; infer_datasets=$1 ; shift ;;
+        # [Only for Inference] The feature dir for inference
+        --infer_feature_dir) shift; infer_feature_dir=$1 ; shift ;;
+        # [Only for Inference] The audio dir for inference
+        --infer_audio_dir) shift; infer_audio_dir=$1 ; shift ;;
+        # [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
+        --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
+        # [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
+        --infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
+
+        --) shift ; break ;;
+        *) echo "Invalid option: $1"; exit 1 ;;
+    esac
+done
+
+
+### Value check ###
+if [ -z "$running_stage" ]; then
+    echo "[Error] Please specify the running stage"
+    exit 1
+fi
+
+if [ -z "$exp_config" ]; then
+    exp_config="${exp_dir}"/exp_config.json
+fi
+echo "Experimental Configuration File: $exp_config"
+
+if [ -z "$gpu" ]; then
+    gpu="0"
+fi
+
+if [ -z "$main_process_port" ]; then
+    main_process_port=29500
+fi
+echo "Main Process Port: $main_process_port"
+
+######## Features Extraction ###########
+if [ $running_stage -eq 1 ]; then
+    CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/sgmse/preprocess.py \
+        --config "$exp_config" \
+        --num_workers 8
+fi
+
+######## Training ###########
+if [ $running_stage -eq 2 ]; then
+    CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/sgmse/train_sgmse.py \
+        --config "$exp_config" \
+        --exp_name "$exp_name" \
+        --log_level info
+fi
+
+######## Inference ###########
+if [ $running_stage -eq 3 ]; then
+    CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/sgmse/inference.py \
+        --config="$exp_config" \
+        --checkpoint_path="$checkpoint_path" \
+        --test_dir="$test_dir" \
+        --output_dir="$output_dir"
+fi
\ No newline at end of file
diff --git a/env.sh b/env.sh
index 10ef7ff1..1722c4ae 100644
--- a/env.sh
+++ b/env.sh
@@ -28,5 +28,6 @@
 pip install phonemizer==3.2.1 pypinyin==0.48.0
 pip install black==24.1.1
+pip install torch-ema ninja
 # Uninstall nvidia-cublas-cu11 if there exist some bugs about CUDA version
 # pip uninstall nvidia-cublas-cu11
diff --git a/imgs/sgmse/diffusion_process.png b/imgs/sgmse/diffusion_process.png
new file mode 100644
index 00000000..6f8a6db0
Binary files /dev/null and b/imgs/sgmse/diffusion_process.png differ
diff --git a/models/base/base_trainer.py b/models/base/base_trainer.py
index 8782216d..5c18274c 100644
--- a/models/base/base_trainer.py
+++ b/models/base/base_trainer.py
@@ -78,9 +78,11 @@ def __init__(self, args, cfg):
         self.criterion = self.build_criterion()
         if isinstance(self.criterion, dict):
             for key, value in self.criterion.items():
-                self.criterion[key].cuda(args.local_rank)
+                if isinstance(value, torch.nn.Module):
+                    self.criterion[key].cuda(args.local_rank)
         else:
-            self.criterion.cuda(self.args.local_rank)
+            if isinstance(self.criterion, torch.nn.Module):
+                self.criterion.cuda(self.args.local_rank)

         # optimizer
         self.optimizer = self.build_optimizer()
diff --git a/models/sgmse/dereverberation/__init__.py b/models/sgmse/dereverberation/__init__.py
new file mode 100644
index 00000000..e69de29b
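The `ScoreModel` added below wires a registered backbone and SDE together: the DNN sees the perturbed state `x` and the noisy conditioning spectrogram `y` stacked along the channel axis, and the negated network output is used as the score. A toy shape check, with a hypothetical stand-in for the real NCSN++ backbone (the actual model operates on complex STFT tensors):

```python
import torch
from torch import nn


class DummyBackbone(nn.Module):
    # Hypothetical stand-in for NCSNpp: (B, 2, F, T) + t -> (B, 1, F, T).
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=1)

    def forward(self, x, t):  # t is ignored by the dummy
        return self.conv(x)


dnn = DummyBackbone()
x = torch.randn(4, 1, 256, 256)  # perturbed state x_t
y = torch.randn(4, 1, 256, 256)  # noisy conditioning spectrogram
t = torch.rand(4)                # diffusion times
score = -dnn(torch.cat([x, y], dim=1), t)  # same pattern as ScoreModel.forward
assert score.shape == x.shape
```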
diff --git a/models/sgmse/dereverberation/dereverberation.py b/models/sgmse/dereverberation/dereverberation.py
new file mode 100644
index 00000000..4d607096
--- /dev/null
+++ b/models/sgmse/dereverberation/dereverberation.py
@@ -0,0 +1,26 @@
+import torch
+from torch import nn
+
+from modules.sgmse.sdes import SDERegistry
+from modules.sgmse.shared import BackboneRegistry
+
+
+class ScoreModel(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+        # Initialize backbone DNN
+        dnn_cls = BackboneRegistry.get_by_name(cfg.backbone)
+        dnn_cfg = cfg[cfg.backbone]
+        self.dnn = dnn_cls(**dnn_cfg)
+        # Initialize SDE
+        sde_cls = SDERegistry.get_by_name(cfg.sde)
+        sde_cfg = cfg[cfg.sde]
+        self.sde = sde_cls(**sde_cfg)
+
+    def forward(self, x, t, y):
+        # Concatenate y as an extra channel
+        dnn_input = torch.cat([x, y], dim=1)
+        score = -self.dnn(dnn_input, t)
+        return score
diff --git a/models/sgmse/dereverberation/dereverberation_Trainer.py b/models/sgmse/dereverberation/dereverberation_Trainer.py
new file mode 100644
index 00000000..ccf97263
--- /dev/null
+++ b/models/sgmse/dereverberation/dereverberation_Trainer.py
@@ -0,0 +1,214 @@
+import os
+import warnings
+from math import ceil
+
+import torch
+from torch.utils.data import DataLoader
+from torch_ema import ExponentialMovingAverage
+
+from models.base.base_trainer import BaseTrainer
+from models.sgmse.dereverberation.dereverberation import ScoreModel
+from models.sgmse.dereverberation.dereverberation_dataset import Specs
+from modules.sgmse import sampling
+from utils.sgmse_util.inference import evaluate_model
+
+
+class DereverberationTrainer(BaseTrainer):
+    def __init__(self, args, cfg):
+        BaseTrainer.__init__(self, args, cfg)
+        self.cfg = cfg
+        self.save_config_file()
+        self.ema = ExponentialMovingAverage(
+            self.model.parameters(), decay=self.cfg.train.ema_decay
+        )
+        self._error_loading_ema = False
+        self.t_eps = self.cfg.train.t_eps
+        self.num_eval_files = self.cfg.train.num_eval_files
+        self.data_loader = self.build_data_loader()
+
+        checkpoint = self.load_checkpoint()
+        if checkpoint:
+            self.load_model(checkpoint)
+
+    def build_dataset(self):
+        return Specs
+
+    def load_checkpoint(self):
+        model_path = self.cfg.train.checkpoint
+        if not model_path or not os.path.exists(model_path):
+            self.logger.info("No checkpoint to load or checkpoint path does not exist.")
+            return None
+        if not self.cfg.train.ddp or self.args.local_rank == 0:
+            self.logger.info(f"(Re)store from {model_path}")
+        checkpoint = torch.load(model_path, map_location="cpu")
+        if "ema" in checkpoint:
+            try:
+                self.ema.load_state_dict(checkpoint["ema"])
+            except Exception:
+                self._error_loading_ema = True
+                warnings.warn("EMA state_dict could not be loaded from checkpoint!")
+        return checkpoint
+
+    def build_data_loader(self):
+        Dataset = self.build_dataset()
+        train_set = Dataset(self.cfg, subset="train", shuffle_spec=True)
+        train_loader = DataLoader(
+            train_set,
+            batch_size=self.cfg.train.batch_size,
+            num_workers=self.args.num_workers,
+            pin_memory=False,
+            shuffle=True,
+        )
+        self.valid_set = Dataset(self.cfg, subset="valid", shuffle_spec=False)
+        valid_loader = DataLoader(
+            self.valid_set,
+            batch_size=self.cfg.train.batch_size,
+            num_workers=self.args.num_workers,
+            pin_memory=False,
+            shuffle=False,
+        )
+        data_loader = {"train": train_loader, "valid": valid_loader}
+        return data_loader
+
+    def build_optimizer(self):
+        optimizer = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam)
+        return optimizer
+
+    def build_scheduler(self):
+        return None
+        # return
ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau) + + def build_singers_lut(self): + return None + + def write_summary(self, losses, stats): + for key, value in losses.items(): + self.sw.add_scalar(key, value, self.step) + + def write_valid_summary(self, losses, stats): + for key, value in losses.items(): + self.sw.add_scalar(key, value, self.step) + + def _loss(self, err): + losses = torch.square(err.abs()) + loss = torch.mean(0.5 * torch.sum(losses.reshape(losses.shape[0], -1), dim=-1)) + return loss + + def build_criterion(self): + return self._loss + + def get_state_dict(self): + state_dict = { + "model": self.model.state_dict(), + "optimizer": self.optimizer.state_dict(), + "step": self.step, + "epoch": self.epoch, + "batch_size": self.cfg.train.batch_size, + "ema": self.ema.state_dict(), + } + if self.scheduler is not None: + state_dict["scheduler"] = self.scheduler.state_dict() + return state_dict + + def load_model(self, checkpoint): + self.step = checkpoint["step"] + self.epoch = checkpoint["epoch"] + + self.model.load_state_dict(checkpoint["model"]) + self.optimizer.load_state_dict(checkpoint["optimizer"]) + if "scheduler" in checkpoint and self.scheduler is not None: + self.scheduler.load_state_dict(checkpoint["scheduler"]) + if "ema" in checkpoint: + self.ema.load_state_dict(checkpoint["ema"]) + + def build_model(self): + self.model = ScoreModel(self.cfg.model.sgmse) + return self.model + + def get_pc_sampler( + self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs + ): + N = self.model.sde.N if N is None else N + sde = self.model.sde.copy() + sde.N = N + + kwargs = {"eps": self.t_eps, **kwargs} + if minibatch is None: + return sampling.get_pc_sampler( + predictor_name, + corrector_name, + sde=sde, + score_fn=self.model, + y=y, + **kwargs, + ) + else: + M = y.shape[0] + + def batched_sampling_fn(): + samples, ns = [], [] + for i in range(int(ceil(M / minibatch))): + y_mini = y[i * minibatch : (i + 1) * minibatch] + sampler = sampling.get_pc_sampler( + predictor_name, + corrector_name, + sde=sde, + score_fn=self.model, + y=y_mini, + **kwargs, + ) + sample, n = sampler() + samples.append(sample) + ns.append(n) + samples = torch.cat(samples, dim=0) + return samples, ns + + return batched_sampling_fn + + def _step(self, batch): + x = batch["X"] + y = batch["Y"] + + t = ( + torch.rand(x.shape[0], device=x.device) * (self.model.sde.T - self.t_eps) + + self.t_eps + ) + mean, std = self.model.sde.marginal_prob(x, t, y) + + z = torch.randn_like(x) + sigmas = std[:, None, None, None] + perturbed_data = mean + sigmas * z + score = self.model(perturbed_data, t, y) + + err = score * sigmas + z + loss = self.criterion(err) + return loss + + def train_step(self, batch): + loss = self._step(batch) + + # Backward pass and optimization + self.optimizer.zero_grad() # reset gradient + loss.backward() + self.optimizer.step() + + # Update the EMA of the model parameters + self.ema.update(self.model.parameters()) + + self.write_summary({"train_loss": loss.item()}, {}) + return {"train_loss": loss.item()}, {}, loss.item() + + def eval_step(self, batch, batch_idx): + self.ema.store(self.model.parameters()) + self.ema.copy_to(self.model.parameters()) + loss = self._step(batch) + self.write_valid_summary({"valid_loss": loss.item()}, {}) + if batch_idx == 0 and self.num_eval_files != 0: + pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files) + self.write_valid_summary( + {"pesq": pesq, "si_sdr": si_sdr, "estoi": estoi}, {} + ) + print(f" pesq={pesq}, 
si_sdr={si_sdr}, estoi={estoi}")
+        if self.ema.collected_params is not None:
+            self.ema.restore(self.model.parameters())
+        return {"valid_loss": loss.item()}, {}, loss.item()
diff --git a/models/sgmse/dereverberation/dereverberation_dataset.py b/models/sgmse/dereverberation/dereverberation_dataset.py
new file mode 100644
index 00000000..00b39262
--- /dev/null
+++ b/models/sgmse/dereverberation/dereverberation_dataset.py
@@ -0,0 +1,114 @@
+import os
+from glob import glob
+from os.path import join
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torchaudio import load
+
+
+class Specs:
+    def __init__(self, cfg, subset, shuffle_spec):
+        self.cfg = cfg
+        self.data_dir = os.path.join(
+            cfg.preprocess.processed_dir, cfg.dataset[0], "audio"
+        )
+        self.clean_files = sorted(glob(join(self.data_dir, subset) + "/anechoic/*.wav"))
+        self.noisy_files = sorted(glob(join(self.data_dir, subset) + "/reverb/*.wav"))
+        self.dummy = cfg.preprocess.dummy
+        self.num_frames = cfg.preprocess.num_frames
+        self.shuffle_spec = shuffle_spec
+        self.normalize = cfg.preprocess.normalize
+        self.hop_length = cfg.preprocess.hop_length
+        self.n_fft = cfg.preprocess.n_fft
+        self.window = self.get_window(self.n_fft)
+        self.windows = {}
+        self.spec_abs_exponent = cfg.preprocess.spec_abs_exponent
+        self.spec_factor = cfg.preprocess.spec_factor
+
+    def __getitem__(self, i):
+        x, _ = load(self.clean_files[i])
+        y, _ = load(self.noisy_files[i])
+
+        # formula applies for center=True
+        target_len = (self.num_frames - 1) * self.hop_length
+        current_len = x.size(-1)
+        pad = max(target_len - current_len, 0)
+        if pad == 0:
+            # extract random part of the audio file
+            if self.shuffle_spec:
+                start = int(np.random.uniform(0, current_len - target_len))
+            else:
+                start = int((current_len - target_len) / 2)
+            x = x[..., start : start + target_len]
+            y = y[..., start : start + target_len]
+        else:
+            # pad audio if the length T is smaller than num_frames
+            x = F.pad(x, (pad // 2, pad // 2 + (pad % 2)), mode="constant")
+            y = F.pad(y, (pad // 2, pad // 2 + (pad % 2)), mode="constant")
+
+        # normalize w.r.t. the noisy or the clean signal, or not at all,
+        # to ensure same clean signal power in x and y.
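+        #   "noisy" divides both signals by max(|y|), "clean" by max(|x|),
+        #   and "not" applies no scaling; a single shared factor keeps the
+        #   relative level between x and y intact.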
+        if self.normalize == "noisy":
+            normfac = y.abs().max()
+        elif self.normalize == "clean":
+            normfac = x.abs().max()
+        else:  # "not": leave the signals unnormalized
+            normfac = 1.0
+        x = x / normfac
+        y = y / normfac
+
+        X = torch.stft(x, **self.stft_kwargs())
+        Y = torch.stft(y, **self.stft_kwargs())
+        X, Y = self.spec_transform(X), self.spec_transform(Y)
+        return {"X": X, "Y": Y}
+
+    def __len__(self):
+        if self.dummy:
+            return int(len(self.clean_files) / 200)
+        else:
+            return len(self.clean_files)
+
+    def spec_transform(self, spec):
+        if self.spec_abs_exponent != 1:
+            e = self.spec_abs_exponent
+            spec = spec.abs() ** e * torch.exp(1j * spec.angle())
+        spec = spec * self.spec_factor
+        return spec
+
+    def stft_kwargs(self):
+        return {**self.istft_kwargs(), "return_complex": True}
+
+    def istft_kwargs(self):
+        return dict(
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            window=self.window,
+            center=True,
+        )
+
+    def stft(self, sig):
+        window = self._get_window(sig)
+        return torch.stft(sig, **{**self.stft_kwargs(), "window": window})
+
+    def istft(self, spec, length=None):
+        window = self._get_window(spec)
+        return torch.istft(
+            spec, **{**self.istft_kwargs(), "window": window, "length": length}
+        )
+
+    @staticmethod
+    def get_window(window_length):
+        return torch.hann_window(window_length, periodic=True)
+
+    def _get_window(self, x):
+        """
+        Retrieve an appropriate window for the given tensor x, matching the device.
+        Caches the retrieved windows so that only one window tensor will be allocated per device.
+        """
+        window = self.windows.get(x.device, None)
+        if window is None:
+            window = self.window.to(x.device)
+            self.windows[x.device] = window
+        return window
diff --git a/models/sgmse/dereverberation/dereverberation_inference.py b/models/sgmse/dereverberation/dereverberation_inference.py
new file mode 100644
index 00000000..c3c847c1
--- /dev/null
+++ b/models/sgmse/dereverberation/dereverberation_inference.py
@@ -0,0 +1,81 @@
+import glob
+from os.path import join
+
+import torch
+from torchaudio import load
+from soundfile import write
+from tqdm import tqdm
+
+from models.sgmse.dereverberation.dereverberation import ScoreModel
+from models.sgmse.dereverberation.dereverberation_dataset import Specs
+from models.sgmse.dereverberation.dereverberation_Trainer import DereverberationTrainer
+from utils.sgmse_util.other import ensure_dir, pad_spec
+
+
+class DereverberationInference:
+    def __init__(self, args, cfg):
+        self.cfg = cfg
+        self.t_eps = self.cfg.train.t_eps
+        self.args = args
+        self.test_dir = args.test_dir
+        self.target_dir = self.args.output_dir
+        ensure_dir(self.target_dir)
+        self.model = self.build_model()
+        self.load_state_dict()
+
+    def build_model(self):
+        self.model = ScoreModel(self.cfg.model.sgmse)
+        return self.model
+
+    def load_state_dict(self):
+        self.checkpoint_path = self.args.checkpoint_path
+        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
+        self.model.load_state_dict(checkpoint["model"])
+        self.model.cuda(self.args.local_rank)
+
+    def inference(self):
+        sr = 16000
+        snr = self.args.snr
+        N = self.args.N
+        corrector_steps = self.args.corrector_steps
+        self.model.eval()
+        noisy_dir = join(self.test_dir, "noisy/")
+        noisy_files = sorted(glob.glob("{}/*.wav".format(noisy_dir)))
+        for noisy_file in tqdm(noisy_files):
+            filename = noisy_file.split("/")[-1]
+
+            # Load wav
+            y, _ = load(noisy_file)
+            T_orig = y.size(1)
+
+            # Normalize
+            norm_factor = y.abs().max()
+            y = y / norm_factor
+
+            # 
Prepare DNN input + spec = Specs(self.cfg, subset="", shuffle_spec=False) + Y = torch.unsqueeze(spec.spec_transform(spec.stft(sig=y.cuda())), 0) + Y = pad_spec(Y) + + # Reverse sampling + sampler = DereverberationTrainer.get_pc_sampler( + self, + "reverse_diffusion", + "ald", + Y.cuda(), + N=N, + corrector_steps=corrector_steps, + snr=snr, + ) + sample, _ = sampler() + + # Backward transform in time domain + x_hat = spec.istft(sample.squeeze(), T_orig) + + # Renormalize + x_hat = x_hat * norm_factor + + # Write enhanced wav file + write(join(self.target_dir, filename), x_hat.cpu().numpy(), 16000) diff --git a/modules/sgmse/__init__.py b/modules/sgmse/__init__.py new file mode 100644 index 00000000..ff6c52ef --- /dev/null +++ b/modules/sgmse/__init__.py @@ -0,0 +1,5 @@ +from .shared import BackboneRegistry +from .ncsnpp import NCSNpp +from .dcunet import DCUNet + +__all__ = ["BackboneRegistry", "NCSNpp", "DCUNet"] diff --git a/modules/sgmse/dcunet.py b/modules/sgmse/dcunet.py new file mode 100644 index 00000000..9815a76e --- /dev/null +++ b/modules/sgmse/dcunet.py @@ -0,0 +1,765 @@ +from functools import partial +import numpy as np + +import torch +from torch import nn, Tensor +from torch.nn.modules.batchnorm import _BatchNorm + +from .shared import ( + BackboneRegistry, + ComplexConv2d, + ComplexConvTranspose2d, + ComplexLinear, + DiffusionStepEmbedding, + GaussianFourierProjection, + FeatureMapDense, + torch_complex_from_reim, +) + + +def get_activation(name): + if name == "silu": + return nn.SiLU + elif name == "relu": + return nn.ReLU + elif name == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError(f"Unknown activation: {name}") + + +class BatchNorm(_BatchNorm): + def _check_input_dim(self, input): + if input.dim() < 2 or input.dim() > 4: + raise ValueError( + "expected 4D or 3D input (got {}D input)".format(input.dim()) + ) + + +class OnReIm(nn.Module): + def __init__(self, module_cls, *args, **kwargs): + super().__init__() + self.re_module = module_cls(*args, **kwargs) + self.im_module = module_cls(*args, **kwargs) + + def forward(self, x): + return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag)) + + +# Code for DCUNet largely copied from Danilo's `informedenh` repo, cheers! + + +def unet_decoder_args(encoders, *, skip_connections): + """Get list of decoder arguments for upsampling (right) side of a symmetric u-net, + given the arguments used to construct the encoder. + Args: + encoders (tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding)): + List of arguments used to construct the encoders + skip_connections (bool): Whether to include skip connections in the + calculation of decoder input channels. 
+ Return: + tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding): + Arguments to be used to construct decoders + """ + decoder_args = [] + for ( + enc_in_chan, + enc_out_chan, + enc_kernel_size, + enc_stride, + enc_padding, + enc_dilation, + ) in reversed(encoders): + if skip_connections and decoder_args: + skip_in_chan = enc_out_chan + else: + skip_in_chan = 0 + decoder_args.append( + ( + enc_out_chan + skip_in_chan, + enc_in_chan, + enc_kernel_size, + enc_stride, + enc_padding, + enc_dilation, + ) + ) + return tuple(decoder_args) + + +def make_unet_encoder_decoder_args(encoder_args, decoder_args): + encoder_args = tuple( + ( + in_chan, + out_chan, + tuple(kernel_size), + tuple(stride), + ( + tuple([n // 2 for n in kernel_size]) + if padding == "auto" + else tuple(padding) + ), + tuple(dilation), + ) + for in_chan, out_chan, kernel_size, stride, padding, dilation in encoder_args + ) + + if decoder_args == "auto": + decoder_args = unet_decoder_args( + encoder_args, + skip_connections=True, + ) + else: + decoder_args = tuple( + ( + in_chan, + out_chan, + tuple(kernel_size), + tuple(stride), + tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding, + tuple(dilation), + output_padding, + ) + for in_chan, out_chan, kernel_size, stride, padding, dilation, output_padding in decoder_args + ) + + return encoder_args, decoder_args + + +DCUNET_ARCHITECTURES = { + "DCUNet-10": make_unet_encoder_decoder_args( + # Encoders: + # (in_chan, out_chan, kernel_size, stride, padding, dilation) + ( + (1, 32, (7, 5), (2, 2), "auto", (1, 1)), + (32, 64, (7, 5), (2, 2), "auto", (1, 1)), + (64, 64, (5, 3), (2, 2), "auto", (1, 1)), + (64, 64, (5, 3), (2, 2), "auto", (1, 1)), + (64, 64, (5, 3), (2, 1), "auto", (1, 1)), + ), + # Decoders: automatic inverse + "auto", + ), + "DCUNet-16": make_unet_encoder_decoder_args( + # Encoders: + # (in_chan, out_chan, kernel_size, stride, padding, dilation) + ( + (1, 32, (7, 5), (2, 2), "auto", (1, 1)), + (32, 32, (7, 5), (2, 1), "auto", (1, 1)), + (32, 64, (7, 5), (2, 2), "auto", (1, 1)), + (64, 64, (5, 3), (2, 1), "auto", (1, 1)), + (64, 64, (5, 3), (2, 2), "auto", (1, 1)), + (64, 64, (5, 3), (2, 1), "auto", (1, 1)), + (64, 64, (5, 3), (2, 2), "auto", (1, 1)), + (64, 64, (5, 3), (2, 1), "auto", (1, 1)), + ), + # Decoders: automatic inverse + "auto", + ), + "DCUNet-20": make_unet_encoder_decoder_args( + # Encoders: + # (in_chan, out_chan, kernel_size, stride, padding, dilation) + ( + (1, 32, (7, 1), (1, 1), "auto", (1, 1)), + (32, 32, (1, 7), (1, 1), "auto", (1, 1)), + (32, 64, (7, 5), (2, 2), "auto", (1, 1)), + (64, 64, (7, 5), (2, 1), "auto", (1, 1)), + (64, 64, (5, 3), (2, 2), "auto", (1, 1)), + (64, 64, (5, 3), (2, 1), "auto", (1, 1)), + (64, 64, (5, 3), (2, 2), "auto", (1, 1)), + (64, 64, (5, 3), (2, 1), "auto", (1, 1)), + (64, 64, (5, 3), (2, 2), "auto", (1, 1)), + (64, 90, (5, 3), (2, 1), "auto", (1, 1)), + ), + # Decoders: automatic inverse + "auto", + ), + "DilDCUNet-v2": make_unet_encoder_decoder_args( # architecture used in SGMSE / Interspeech paper + # Encoders: + # (in_chan, out_chan, kernel_size, stride, padding, dilation) + ( + (1, 32, (4, 4), (1, 1), "auto", (1, 1)), + (32, 32, (4, 4), (1, 1), "auto", (1, 1)), + (32, 32, (4, 4), (1, 1), "auto", (1, 1)), + (32, 64, (4, 4), (2, 1), "auto", (2, 1)), + (64, 128, (4, 4), (2, 2), "auto", (4, 1)), + (128, 256, (4, 4), (2, 2), "auto", (8, 1)), + ), + # Decoders: automatic inverse + "auto", + ), +} + + +@BackboneRegistry.register("dcunet") +class DCUNet(nn.Module): + 
@staticmethod + def add_argparse_args(parser): + parser.add_argument( + "--dcunet-architecture", + type=str, + default="DilDCUNet-v2", + choices=DCUNET_ARCHITECTURES.keys(), + help="The concrete DCUNet architecture. 'DilDCUNet-v2' by default.", + ) + parser.add_argument( + "--dcunet-time-embedding", + type=str, + choices=("gfp", "ds", "none"), + default="gfp", + help="Timestep embedding style. 'gfp' (Gaussian Fourier Projections) by default.", + ) + parser.add_argument( + "--dcunet-temb-layers-global", + type=int, + default=1, + help="Number of global linear+activation layers for the time embedding. 1 by default.", + ) + parser.add_argument( + "--dcunet-temb-layers-local", + type=int, + default=1, + help="Number of local (per-encoder/per-decoder) linear+activation layers for the time embedding. 1 by default.", + ) + parser.add_argument( + "--dcunet-temb-activation", + type=str, + default="silu", + help="The (complex) activation to use between all (global&local) time embedding layers.", + ) + parser.add_argument( + "--dcunet-time-embedding-complex", + action="store_true", + help="Use complex-valued timestep embedding. Compatible with 'gfp' and 'ds' embeddings.", + ) + parser.add_argument( + "--dcunet-fix-length", + type=str, + default="pad", + choices=("pad", "trim", "none"), + help="DCUNet strategy to 'fix' mismatched input timespan. 'pad' by default.", + ) + parser.add_argument( + "--dcunet-mask-bound", + type=str, + choices=("tanh", "sigmoid", "none"), + default="none", + help="DCUNet output bounding strategy. 'none' by default.", + ) + parser.add_argument( + "--dcunet-norm-type", + type=str, + choices=("bN", "CbN"), + default="bN", + help="The type of norm to use within each encoder and decoder layer. 'bN' (real/imaginary separate batch norm) by default.", + ) + parser.add_argument( + "--dcunet-activation", + type=str, + choices=("leaky_relu", "relu", "silu"), + default="leaky_relu", + help="The activation to use within each encoder and decoder layer. 
'leaky_relu' by default.", + ) + return parser + + def __init__( + self, + dcunet_architecture: str = "DilDCUNet-v2", + dcunet_time_embedding: str = "gfp", + dcunet_temb_layers_global: int = 2, + dcunet_temb_layers_local: int = 1, + dcunet_temb_activation: str = "silu", + dcunet_time_embedding_complex: bool = False, + dcunet_fix_length: str = "pad", + dcunet_mask_bound: str = "none", + dcunet_norm_type: str = "bN", + dcunet_activation: str = "relu", + embed_dim: int = 128, + **kwargs, + ): + super().__init__() + + self.architecture = dcunet_architecture + self.fix_length_mode = ( + dcunet_fix_length if dcunet_fix_length != "none" else None + ) + self.norm_type = dcunet_norm_type + self.activation = dcunet_activation + self.input_channels = 2 # for x_t and y -- note that this is 2 rather than 4, because we directly treat complex channels in this DNN + self.time_embedding = ( + dcunet_time_embedding if dcunet_time_embedding != "none" else None + ) + self.time_embedding_complex = dcunet_time_embedding_complex + self.temb_layers_global = dcunet_temb_layers_global + self.temb_layers_local = dcunet_temb_layers_local + self.temb_activation = dcunet_temb_activation + conf_encoders, conf_decoders = DCUNET_ARCHITECTURES[dcunet_architecture] + + # Replace `input_channels` in encoders config + _replaced_input_channels, *rest = conf_encoders[0] + encoders = ((self.input_channels, *rest), *conf_encoders[1:]) + decoders = conf_decoders + self.encoders_stride_product = np.prod( + [enc_stride for _, _, _, enc_stride, _, _ in encoders], axis=0 + ) + + # Prepare kwargs for encoder and decoder (to potentially be modified before layer instantiation) + encoder_decoder_kwargs = dict( + norm_type=self.norm_type, + activation=self.activation, + temb_layers=self.temb_layers_local, + temb_activation=self.temb_activation, + ) + + # Instantiate (global) time embedding layer + embed_ops = [] + if self.time_embedding is not None: + complex_valued = self.time_embedding_complex + if self.time_embedding == "gfp": + embed_ops += [ + GaussianFourierProjection( + embed_dim=embed_dim, complex_valued=complex_valued + ) + ] + encoder_decoder_kwargs["embed_dim"] = embed_dim + elif self.time_embedding == "ds": + embed_ops += [ + DiffusionStepEmbedding( + embed_dim=embed_dim, complex_valued=complex_valued + ) + ] + encoder_decoder_kwargs["embed_dim"] = embed_dim + + if self.time_embedding_complex: + assert self.time_embedding in ( + "gfp", + "ds", + ), "Complex timestep embedding only available for gfp and ds" + encoder_decoder_kwargs["complex_time_embedding"] = True + for _ in range(self.temb_layers_global): + embed_ops += [ + ComplexLinear(embed_dim, embed_dim, complex_valued=True), + OnReIm(get_activation(dcunet_temb_activation)), + ] + self.embed = nn.Sequential(*embed_ops) + + ### Instantiate DCUNet layers ### + output_layer = ComplexConvTranspose2d(*decoders[-1]) + encoders = [ + DCUNetComplexEncoderBlock(*args, **encoder_decoder_kwargs) + for args in encoders + ] + decoders = [ + DCUNetComplexDecoderBlock(*args, **encoder_decoder_kwargs) + for args in decoders[:-1] + ] + + self.mask_bound = dcunet_mask_bound if dcunet_mask_bound != "none" else None + if self.mask_bound is not None: + raise NotImplementedError( + "sorry, mask bounding not implemented at the moment" + ) + # TODO we can't use nn.Sequential since the ComplexConvTranspose2d needs a second `output_size` argument + # operations = (output_layer, complex_nn.BoundComplexMask(self.mask_bound)) + # output_layer = nn.Sequential(*[x for x in operations if x is not None]) 
+ + assert len(encoders) == len(decoders) + 1 + self.encoders = nn.ModuleList(encoders) + self.decoders = nn.ModuleList(decoders) + self.output_layer = output_layer or nn.Identity() + + def forward(self, spec, t) -> Tensor: + """ + Input shape is expected to be $(batch, nfreqs, time)$, with $nfreqs - 1$ divisible + by $f_0 * f_1 * ... * f_N$ where $f_k$ are the frequency strides of the encoders, + and $time - 1$ is divisible by $t_0 * t_1 * ... * t_N$ where $t_N$ are the time + strides of the encoders. + Args: + spec (Tensor): complex spectrogram tensor. 1D, 2D or 3D tensor, time last. + Returns: + Tensor, of shape (batch, time) or (time). + """ + # TF-rep shape: (batch, self.input_channels, n_fft, frames) + # Estimate mask from time-frequency representation. + x_in = self.fix_input_dims(spec) + x = x_in + t_embed = self.embed(t + 0j) if self.time_embedding is not None else None + + enc_outs = [] + for idx, enc in enumerate(self.encoders): + x = enc(x, t_embed) + # UNet skip connection + enc_outs.append(x) + for enc_out, dec in zip(reversed(enc_outs[:-1]), self.decoders): + x = dec(x, t_embed, output_size=enc_out.shape) + x = torch.cat([x, enc_out], dim=1) + + output = self.output_layer(x, output_size=x_in.shape) + # output shape: (batch, 1, n_fft, frames) + output = self.fix_output_dims(output, spec) + return output + + def fix_input_dims(self, x): + return _fix_dcu_input_dims( + self.fix_length_mode, x, torch.from_numpy(self.encoders_stride_product) + ) + + def fix_output_dims(self, out, x): + return _fix_dcu_output_dims(self.fix_length_mode, out, x) + + +def _fix_dcu_input_dims(fix_length_mode, x, encoders_stride_product): + """Pad or trim `x` to a length compatible with DCUNet.""" + freq_prod = int(encoders_stride_product[0]) + time_prod = int(encoders_stride_product[1]) + if (x.shape[2] - 1) % freq_prod: + raise TypeError( + f"Input shape must be [batch, ch, freq + 1, time + 1] with freq divisible by " + f"{freq_prod}, got {x.shape} instead" + ) + time_remainder = (x.shape[3] - 1) % time_prod + if time_remainder: + if fix_length_mode is None: + raise TypeError( + f"Input shape must be [batch, ch, freq + 1, time + 1] with time divisible by " + f"{time_prod}, got {x.shape} instead. Set the 'fix_length_mode' argument " + f"in 'DCUNet' to 'pad' or 'trim' to fix shapes automatically." 
+ ) + elif fix_length_mode == "pad": + pad_shape = [0, time_prod - time_remainder] + x = nn.functional.pad(x, pad_shape, mode="constant") + elif fix_length_mode == "trim": + pad_shape = [0, -time_remainder] + x = nn.functional.pad(x, pad_shape, mode="constant") + else: + raise ValueError(f"Unknown fix_length mode '{fix_length_mode}'") + return x + + +def _fix_dcu_output_dims(fix_length_mode, out, x): + """Fix shape of `out` to the original shape of `x` by padding/cropping.""" + inp_len = x.shape[-1] + output_len = out.shape[-1] + return nn.functional.pad(out, [0, inp_len - output_len]) + + +def _get_norm(norm_type): + if norm_type == "CbN": + return ComplexBatchNorm + elif norm_type == "bN": + return partial(OnReIm, BatchNorm) + else: + raise NotImplementedError(f"Unknown norm type: {norm_type}") + + +class DCUNetComplexEncoderBlock(nn.Module): + def __init__( + self, + in_chan, + out_chan, + kernel_size, + stride, + padding, + dilation, + norm_type="bN", + activation="leaky_relu", + embed_dim=None, + complex_time_embedding=False, + temb_layers=1, + temb_activation="silu", + ): + super().__init__() + + self.in_chan = in_chan + self.out_chan = out_chan + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.temb_layers = temb_layers + self.temb_activation = temb_activation + self.complex_time_embedding = complex_time_embedding + + self.conv = ComplexConv2d( + in_chan, + out_chan, + kernel_size, + stride, + padding, + bias=norm_type is None, + dilation=dilation, + ) + self.norm = _get_norm(norm_type)(out_chan) + self.activation = OnReIm(get_activation(activation)) + self.embed_dim = embed_dim + if self.embed_dim is not None: + ops = [] + for _ in range(max(0, self.temb_layers - 1)): + ops += [ + ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True), + OnReIm(get_activation(self.temb_activation)), + ] + ops += [ + FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True), + OnReIm(get_activation(self.temb_activation)), + ] + self.embed_layer = nn.Sequential(*ops) + + def forward(self, x, t_embed): + y = self.conv(x) + if self.embed_dim is not None: + y = y + self.embed_layer(t_embed) + return self.activation(self.norm(y)) + + +class DCUNetComplexDecoderBlock(nn.Module): + def __init__( + self, + in_chan, + out_chan, + kernel_size, + stride, + padding, + dilation, + output_padding=(0, 0), + norm_type="bN", + activation="leaky_relu", + embed_dim=None, + temb_layers=1, + temb_activation="swish", + complex_time_embedding=False, + ): + super().__init__() + + self.in_chan = in_chan + self.out_chan = out_chan + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.output_padding = output_padding + self.complex_time_embedding = complex_time_embedding + self.temb_layers = temb_layers + self.temb_activation = temb_activation + + self.deconv = ComplexConvTranspose2d( + in_chan, + out_chan, + kernel_size, + stride, + padding, + output_padding, + dilation=dilation, + bias=norm_type is None, + ) + self.norm = _get_norm(norm_type)(out_chan) + self.activation = OnReIm(get_activation(activation)) + self.embed_dim = embed_dim + if self.embed_dim is not None: + ops = [] + for _ in range(max(0, self.temb_layers - 1)): + ops += [ + ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True), + OnReIm(get_activation(self.temb_activation)), + ] + ops += [ + FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True), + 
OnReIm(get_activation(self.temb_activation)), + ] + self.embed_layer = nn.Sequential(*ops) + + def forward(self, x, t_embed, output_size=None): + y = self.deconv(x, output_size=output_size) + if self.embed_dim is not None: + y = y + self.embed_layer(t_embed) + return self.activation(self.norm(y)) + + +# From https://github.com/chanil1218/DCUnet.pytorch/blob/2dcdd30804be47a866fde6435cbb7e2f81585213/models/layers/complexnn.py +class ComplexBatchNorm(torch.nn.Module): + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=False, + ): + super(ComplexBatchNorm, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + if self.affine: + self.Wrr = torch.nn.Parameter(torch.Tensor(num_features)) + self.Wri = torch.nn.Parameter(torch.Tensor(num_features)) + self.Wii = torch.nn.Parameter(torch.Tensor(num_features)) + self.Br = torch.nn.Parameter(torch.Tensor(num_features)) + self.Bi = torch.nn.Parameter(torch.Tensor(num_features)) + else: + self.register_parameter("Wrr", None) + self.register_parameter("Wri", None) + self.register_parameter("Wii", None) + self.register_parameter("Br", None) + self.register_parameter("Bi", None) + if self.track_running_stats: + self.register_buffer("RMr", torch.zeros(num_features)) + self.register_buffer("RMi", torch.zeros(num_features)) + self.register_buffer("RVrr", torch.ones(num_features)) + self.register_buffer("RVri", torch.zeros(num_features)) + self.register_buffer("RVii", torch.ones(num_features)) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) + else: + self.register_parameter("RMr", None) + self.register_parameter("RMi", None) + self.register_parameter("RVrr", None) + self.register_parameter("RVri", None) + self.register_parameter("RVii", None) + self.register_parameter("num_batches_tracked", None) + self.reset_parameters() + + def reset_running_stats(self): + if self.track_running_stats: + self.RMr.zero_() + self.RMi.zero_() + self.RVrr.fill_(1) + self.RVri.zero_() + self.RVii.fill_(1) + self.num_batches_tracked.zero_() + + def reset_parameters(self): + self.reset_running_stats() + if self.affine: + self.Br.data.zero_() + self.Bi.data.zero_() + self.Wrr.data.fill_(1) + self.Wri.data.uniform_(-0.9, +0.9) # W will be positive-definite + self.Wii.data.fill_(1) + + def _check_input_dim(self, xr, xi): + assert xr.shape == xi.shape + assert xr.size(1) == self.num_features + + def forward(self, x): + xr, xi = x.real, x.imag + self._check_input_dim(xr, xi) + + exponential_average_factor = 0.0 + + if self.training and self.track_running_stats: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / self.num_batches_tracked.item() + else: # use exponential moving average + exponential_average_factor = self.momentum + + # + # NOTE: The precise meaning of the "training flag" is: + # True: Normalize using batch statistics, update running statistics + # if they are being collected. + # False: Normalize using running statistics, ignore batch statistics. + # + training = self.training or not self.track_running_stats + redux = [i for i in reversed(range(xr.dim())) if i != 1] + vdim = [1] * xr.dim() + vdim[1] = xr.size(1) + + # + # Mean M Computation and Centering + # + # Includes running mean update if training and running. 
+ # + if training: + Mr, Mi = xr, xi + for d in redux: + Mr = Mr.mean(d, keepdim=True) + Mi = Mi.mean(d, keepdim=True) + if self.track_running_stats: + self.RMr.lerp_(Mr.squeeze(), exponential_average_factor) + self.RMi.lerp_(Mi.squeeze(), exponential_average_factor) + else: + Mr = self.RMr.view(vdim) + Mi = self.RMi.view(vdim) + xr, xi = xr - Mr, xi - Mi + + # + # Variance Matrix V Computation + # + # Includes epsilon numerical stabilizer/Tikhonov regularizer. + # Includes running variance update if training and running. + # + if training: + Vrr = xr * xr + Vri = xr * xi + Vii = xi * xi + for d in redux: + Vrr = Vrr.mean(d, keepdim=True) + Vri = Vri.mean(d, keepdim=True) + Vii = Vii.mean(d, keepdim=True) + if self.track_running_stats: + self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor) + self.RVri.lerp_(Vri.squeeze(), exponential_average_factor) + self.RVii.lerp_(Vii.squeeze(), exponential_average_factor) + else: + Vrr = self.RVrr.view(vdim) + Vri = self.RVri.view(vdim) + Vii = self.RVii.view(vdim) + Vrr = Vrr + self.eps + Vri = Vri + Vii = Vii + self.eps + + # + # Matrix Inverse Square Root U = V^-0.5 + # + # sqrt of a 2x2 matrix, + # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix + tau = Vrr + Vii + delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1) + s = delta.sqrt() + t = (tau + 2 * s).sqrt() + + # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html + rst = (s * t).reciprocal() + Urr = (s + Vii) * rst + Uii = (s + Vrr) * rst + Uri = (-Vri) * rst + + # + # Optionally left-multiply U by affine weights W to produce combined + # weights Z, left-multiply the inputs by Z, then optionally bias them. + # + # y = Zx + B + # y = WUx + B + # y = [Wrr Wri][Urr Uri] [xr] + [Br] + # [Wir Wii][Uir Uii] [xi] [Bi] + # + if self.affine: + Wrr, Wri, Wii = ( + self.Wrr.view(vdim), + self.Wri.view(vdim), + self.Wii.view(vdim), + ) + Zrr = (Wrr * Urr) + (Wri * Uri) + Zri = (Wrr * Uri) + (Wri * Uii) + Zir = (Wri * Urr) + (Wii * Uri) + Zii = (Wri * Uri) + (Wii * Uii) + else: + Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii + + yr = (Zrr * xr) + (Zri * xi) + yi = (Zir * xr) + (Zii * xi) + + if self.affine: + yr = yr + self.Br.view(vdim) + yi = yi + self.Bi.view(vdim) + + return torch.view_as_complex(torch.stack([yr, yi], dim=-1)) + + def extra_repr(self): + return ( + "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " + "track_running_stats={track_running_stats}".format(**self.__dict__) + ) diff --git a/modules/sgmse/ncsnpp.py b/modules/sgmse/ncsnpp.py new file mode 100644 index 00000000..4302f47a --- /dev/null +++ b/modules/sgmse/ncsnpp.py @@ -0,0 +1,497 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
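+#
+# NCSN++ pairs a multi-resolution U-Net with FIR-based up/down-sampling and
+# BigGAN-style residual blocks. In this recipe it is conditioned on the noisy
+# spectrogram by stacking the real/imaginary parts of (x, y) into 4 input
+# channels (see `num_channels` below).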
+ +# pylint: skip-file + +from .ncsnpp_utils import layers, layerspp, normalization +import torch.nn as nn +import functools +import torch +import numpy as np + +from .shared import BackboneRegistry + +ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp +ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp +Combine = layerspp.Combine +conv3x3 = layerspp.conv3x3 +conv1x1 = layerspp.conv1x1 +get_act = layers.get_act +get_normalization = normalization.get_normalization +default_initializer = layers.default_init + + +@BackboneRegistry.register("ncsnpp") +class NCSNpp(nn.Module): + """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository""" + + @staticmethod + def add_argparse_args(parser): + parser.add_argument( + "--ch_mult", type=int, nargs="+", default=[1, 1, 2, 2, 2, 2, 2] + ) + parser.add_argument("--num_res_blocks", type=int, default=2) + parser.add_argument("--attn_resolutions", type=int, nargs="+", default=[16]) + parser.add_argument( + "--no-centered", + dest="centered", + action="store_false", + help="The data is not centered [-1, 1]", + ) + parser.add_argument( + "--centered", + dest="centered", + action="store_true", + help="The data is centered [-1, 1]", + ) + parser.set_defaults(centered=True) + return parser + + def __init__( + self, + scale_by_sigma=True, + nonlinearity="swish", + nf=128, + ch_mult=(1, 1, 2, 2, 2, 2, 2), + num_res_blocks=2, + attn_resolutions=(16,), + resamp_with_conv=True, + conditional=True, + fir=True, + fir_kernel=[1, 3, 3, 1], + skip_rescale=True, + resblock_type="biggan", + progressive="output_skip", + progressive_input="input_skip", + progressive_combine="sum", + init_scale=0.0, + fourier_scale=16, + image_size=256, + embedding_type="fourier", + dropout=0.0, + centered=True, + **unused_kwargs, + ): + super().__init__() + self.act = act = get_act(nonlinearity) + + self.nf = nf = nf + ch_mult = ch_mult + self.num_res_blocks = num_res_blocks = num_res_blocks + self.attn_resolutions = attn_resolutions = attn_resolutions + dropout = dropout + resamp_with_conv = resamp_with_conv + self.num_resolutions = num_resolutions = len(ch_mult) + self.all_resolutions = all_resolutions = [ + image_size // (2**i) for i in range(num_resolutions) + ] + + self.conditional = conditional = conditional # noise-conditional + self.centered = centered + self.scale_by_sigma = scale_by_sigma + + fir = fir + fir_kernel = fir_kernel + self.skip_rescale = skip_rescale = skip_rescale + self.resblock_type = resblock_type = resblock_type.lower() + self.progressive = progressive = progressive.lower() + self.progressive_input = progressive_input = progressive_input.lower() + self.embedding_type = embedding_type = embedding_type.lower() + init_scale = init_scale + assert progressive in ["none", "output_skip", "residual"] + assert progressive_input in ["none", "input_skip", "residual"] + assert embedding_type in ["fourier", "positional"] + combine_method = progressive_combine.lower() + combiner = functools.partial(Combine, method=combine_method) + + num_channels = 4 # x.real, x.imag, y.real, y.imag + self.output_layer = nn.Conv2d(num_channels, 2, 1) + + modules = [] + # timestep/noise_level embedding + if embedding_type == "fourier": + # Gaussian Fourier features embeddings. 
+ modules.append( + layerspp.GaussianFourierProjection( + embedding_size=nf, scale=fourier_scale + ) + ) + embed_dim = 2 * nf + elif embedding_type == "positional": + embed_dim = nf + else: + raise ValueError(f"embedding type {embedding_type} unknown.") + + if conditional: + modules.append(nn.Linear(embed_dim, nf * 4)) + modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) + nn.init.zeros_(modules[-1].bias) + modules.append(nn.Linear(nf * 4, nf * 4)) + modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) + nn.init.zeros_(modules[-1].bias) + + AttnBlock = functools.partial( + layerspp.AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale + ) + + Upsample = functools.partial( + layerspp.Upsample, + with_conv=resamp_with_conv, + fir=fir, + fir_kernel=fir_kernel, + ) + + if progressive == "output_skip": + self.pyramid_upsample = layerspp.Upsample( + fir=fir, fir_kernel=fir_kernel, with_conv=False + ) + elif progressive == "residual": + pyramid_upsample = functools.partial( + layerspp.Upsample, fir=fir, fir_kernel=fir_kernel, with_conv=True + ) + + Downsample = functools.partial( + layerspp.Downsample, + with_conv=resamp_with_conv, + fir=fir, + fir_kernel=fir_kernel, + ) + + if progressive_input == "input_skip": + self.pyramid_downsample = layerspp.Downsample( + fir=fir, fir_kernel=fir_kernel, with_conv=False + ) + elif progressive_input == "residual": + pyramid_downsample = functools.partial( + layerspp.Downsample, fir=fir, fir_kernel=fir_kernel, with_conv=True + ) + + if resblock_type == "ddpm": + ResnetBlock = functools.partial( + ResnetBlockDDPM, + act=act, + dropout=dropout, + init_scale=init_scale, + skip_rescale=skip_rescale, + temb_dim=nf * 4, + ) + + elif resblock_type == "biggan": + ResnetBlock = functools.partial( + ResnetBlockBigGAN, + act=act, + dropout=dropout, + fir=fir, + fir_kernel=fir_kernel, + init_scale=init_scale, + skip_rescale=skip_rescale, + temb_dim=nf * 4, + ) + + else: + raise ValueError(f"resblock type {resblock_type} unrecognized.") + + # Downsampling block + + channels = num_channels + if progressive_input != "none": + input_pyramid_ch = channels + + modules.append(conv3x3(channels, nf)) + hs_c = [nf] + + in_ch = nf + for i_level in range(num_resolutions): + # Residual blocks for this resolution + for i_block in range(num_res_blocks): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) + in_ch = out_ch + + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + hs_c.append(in_ch) + + if i_level != num_resolutions - 1: + if resblock_type == "ddpm": + modules.append(Downsample(in_ch=in_ch)) + else: + modules.append(ResnetBlock(down=True, in_ch=in_ch)) + + if progressive_input == "input_skip": + modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) + if combine_method == "cat": + in_ch *= 2 + + elif progressive_input == "residual": + modules.append( + pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch) + ) + input_pyramid_ch = in_ch + + hs_c.append(in_ch) + + in_ch = hs_c[-1] + modules.append(ResnetBlock(in_ch=in_ch)) + modules.append(AttnBlock(channels=in_ch)) + modules.append(ResnetBlock(in_ch=in_ch)) + + pyramid_ch = 0 + # Upsampling block + for i_level in reversed(range(num_resolutions)): + for i_block in range( + num_res_blocks + 1 + ): # +1 blocks in upsampling because of skip connection from combiner (after downsampling) + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) + 
in_ch = out_ch + + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + + if progressive != "none": + if i_level == num_resolutions - 1: + if progressive == "output_skip": + modules.append( + nn.GroupNorm( + num_groups=min(in_ch // 4, 32), + num_channels=in_ch, + eps=1e-6, + ) + ) + modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) + pyramid_ch = channels + elif progressive == "residual": + modules.append( + nn.GroupNorm( + num_groups=min(in_ch // 4, 32), + num_channels=in_ch, + eps=1e-6, + ) + ) + modules.append(conv3x3(in_ch, in_ch, bias=True)) + pyramid_ch = in_ch + else: + raise ValueError(f"{progressive} is not a valid name.") + else: + if progressive == "output_skip": + modules.append( + nn.GroupNorm( + num_groups=min(in_ch // 4, 32), + num_channels=in_ch, + eps=1e-6, + ) + ) + modules.append( + conv3x3(in_ch, channels, bias=True, init_scale=init_scale) + ) + pyramid_ch = channels + elif progressive == "residual": + modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) + pyramid_ch = in_ch + else: + raise ValueError(f"{progressive} is not a valid name") + + if i_level != 0: + if resblock_type == "ddpm": + modules.append(Upsample(in_ch=in_ch)) + else: + modules.append(ResnetBlock(in_ch=in_ch, up=True)) + + assert not hs_c + + if progressive != "output_skip": + modules.append( + nn.GroupNorm( + num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6 + ) + ) + modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) + + self.all_modules = nn.ModuleList(modules) + + def forward(self, x, time_cond): + # timestep/noise_level embedding; only for continuous training + modules = self.all_modules + m_idx = 0 + + # Convert real and imaginary parts of (x,y) into four channel dimensions + x = torch.cat( + ( + x[:, [0], :, :].real, + x[:, [0], :, :].imag, + x[:, [1], :, :].real, + x[:, [1], :, :].imag, + ), + dim=1, + ) + + if self.embedding_type == "fourier": + # Gaussian Fourier features embeddings. + used_sigmas = time_cond + temb = modules[m_idx](torch.log(used_sigmas)) + m_idx += 1 + + elif self.embedding_type == "positional": + # Sinusoidal positional embeddings. 
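+            # Discrete timesteps index into the precomputed sigma schedule;
+            # the embedding itself is the transformer-style sin/cos code
+            # produced by layers.get_timestep_embedding below.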
+            timesteps = time_cond
+            used_sigmas = self.sigmas[time_cond.long()]
+            temb = layers.get_timestep_embedding(timesteps, self.nf)
+
+        else:
+            raise ValueError(f"embedding type {self.embedding_type} unknown.")
+
+        if self.conditional:
+            temb = modules[m_idx](temb)
+            m_idx += 1
+            temb = modules[m_idx](self.act(temb))
+            m_idx += 1
+        else:
+            temb = None
+
+        if not self.centered:
+            # If input data is in [0, 1]
+            x = 2 * x - 1.0
+
+        # Downsampling block
+        input_pyramid = None
+        if self.progressive_input != "none":
+            input_pyramid = x
+
+        # Input layer: Conv2d: 4ch -> 128ch
+        hs = [modules[m_idx](x)]
+        m_idx += 1
+
+        # Down path in U-Net
+        for i_level in range(self.num_resolutions):
+            # Residual blocks for this resolution
+            for i_block in range(self.num_res_blocks):
+                h = modules[m_idx](hs[-1], temb)
+                m_idx += 1
+                # Attention layer (optional)
+                if (
+                    h.shape[-2] in self.attn_resolutions
+                ):  # edit: check H dim (-2) not W dim (-1)
+                    h = modules[m_idx](h)
+                    m_idx += 1
+                hs.append(h)
+
+            # Downsampling
+            if i_level != self.num_resolutions - 1:
+                if self.resblock_type == "ddpm":
+                    h = modules[m_idx](hs[-1])
+                    m_idx += 1
+                else:
+                    h = modules[m_idx](hs[-1], temb)
+                    m_idx += 1
+
+                if self.progressive_input == "input_skip":  # Combine h with x
+                    input_pyramid = self.pyramid_downsample(input_pyramid)
+                    h = modules[m_idx](input_pyramid, h)
+                    m_idx += 1
+
+                elif self.progressive_input == "residual":
+                    input_pyramid = modules[m_idx](input_pyramid)
+                    m_idx += 1
+                    if self.skip_rescale:
+                        input_pyramid = (input_pyramid + h) / np.sqrt(2.0)
+                    else:
+                        input_pyramid = input_pyramid + h
+                    h = input_pyramid
+                hs.append(h)
+
+        h = hs[-1]  # actually equal to: h = h
+        h = modules[m_idx](h, temb)  # ResNet block
+        m_idx += 1
+        h = modules[m_idx](h)  # Attention block
+        m_idx += 1
+        h = modules[m_idx](h, temb)  # ResNet block
+        m_idx += 1
+
+        pyramid = None
+
+        # Upsampling block
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
+                m_idx += 1
+
+            # edit: check H dim (-2) not W dim (-1)
+            if h.shape[-2] in self.attn_resolutions:
+                h = modules[m_idx](h)
+                m_idx += 1
+
+            if self.progressive != "none":
+                if i_level == self.num_resolutions - 1:
+                    if self.progressive == "output_skip":
+                        pyramid = self.act(modules[m_idx](h))  # GroupNorm
+                        m_idx += 1
+                        pyramid = modules[m_idx](pyramid)  # Conv2D: 256 -> 4
+                        m_idx += 1
+                    elif self.progressive == "residual":
+                        pyramid = self.act(modules[m_idx](h))
+                        m_idx += 1
+                        pyramid = modules[m_idx](pyramid)
+                        m_idx += 1
+                    else:
+                        raise ValueError(f"{self.progressive} is not a valid name.")
+                else:
+                    if self.progressive == "output_skip":
+                        pyramid = self.pyramid_upsample(pyramid)  # Upsample
+                        pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
+                        m_idx += 1
+                        pyramid_h = modules[m_idx](pyramid_h)
+                        m_idx += 1
+                        pyramid = pyramid + pyramid_h
+                    elif self.progressive == "residual":
+                        pyramid = modules[m_idx](pyramid)
+                        m_idx += 1
+                        if self.skip_rescale:
+                            pyramid = (pyramid + h) / np.sqrt(2.0)
+                        else:
+                            pyramid = pyramid + h
+                        h = pyramid
+                    else:
+                        raise ValueError(f"{self.progressive} is not a valid name.")
+
+            # Upsampling Layer
+            if i_level != 0:
+                if self.resblock_type == "ddpm":
+                    h = modules[m_idx](h)
+                    m_idx += 1
+                else:
+                    h = modules[m_idx](h, temb)  # Upsampling
+                    m_idx += 1
+
+        assert not hs
+
+        if self.progressive == "output_skip":
+            h = pyramid
+        else:
+            h = self.act(modules[m_idx](h))
+            m_idx += 1
+            h = modules[m_idx](h)
+            m_idx += 1
+
+        assert m_idx == len(modules), "Implementation error"
+        if
self.scale_by_sigma: + used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:])))) + h = h / used_sigmas + + # Convert back to complex number + h = self.output_layer(h) + h = torch.permute(h, (0, 2, 3, 1)).contiguous() + h = torch.view_as_complex(h)[:, None, :, :] + return h diff --git a/modules/sgmse/ncsnpp_utils/layers.py b/modules/sgmse/ncsnpp_utils/layers.py new file mode 100644 index 00000000..76bf8ac3 --- /dev/null +++ b/modules/sgmse/ncsnpp_utils/layers.py @@ -0,0 +1,800 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file +"""Common layers for defining score networks. +""" +import math +import string +from functools import partial +import torch.nn as nn +import torch +import torch.nn.functional as F +import numpy as np +from .normalization import ConditionalInstanceNorm2dPlus + + +def get_act(config): + """Get activation functions from the config file.""" + + if config == "elu": + return nn.ELU() + elif config == "relu": + return nn.ReLU() + elif config == "lrelu": + return nn.LeakyReLU(negative_slope=0.2) + elif config == "swish": + return nn.SiLU() + else: + raise NotImplementedError("activation function does not exist!") + + +def ncsn_conv1x1( + in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=0 +): + """1x1 convolution. 
Same as NCSNv1/v2.""" + conv = nn.Conv2d( + in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=bias, + dilation=dilation, + padding=padding, + ) + init_scale = 1e-10 if init_scale == 0 else init_scale + conv.weight.data *= init_scale + conv.bias.data *= init_scale + return conv + + +def variance_scaling( + scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu" +): + """Ported from JAX.""" + + def _compute_fans(shape, in_axis=1, out_axis=0): + receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] + fan_in = shape[in_axis] * receptive_field_size + fan_out = shape[out_axis] * receptive_field_size + return fan_in, fan_out + + def init(shape, dtype=dtype, device=device): + fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) + if mode == "fan_in": + denominator = fan_in + elif mode == "fan_out": + denominator = fan_out + elif mode == "fan_avg": + denominator = (fan_in + fan_out) / 2 + else: + raise ValueError( + "invalid mode for variance scaling initializer: {}".format(mode) + ) + variance = scale / denominator + if distribution == "normal": + return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) + elif distribution == "uniform": + return ( + torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0 + ) * np.sqrt(3 * variance) + else: + raise ValueError("invalid distribution for variance scaling initializer") + + return init + + +def default_init(scale=1.0): + """The same initialization used in DDPM.""" + scale = 1e-10 if scale == 0 else scale + return variance_scaling(scale, "fan_avg", "uniform") + + +class Dense(nn.Module): + """Linear layer with `default_init`.""" + + def __init__(self): + super().__init__() + + +def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0): + """1x1 convolution with DDPM initialization.""" + conv = nn.Conv2d( + in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias + ) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +def ncsn_conv3x3( + in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1 +): + """3x3 convolution with PyTorch initialization. 
Same as NCSNv1/NCSNv2.""" + init_scale = 1e-10 if init_scale == 0 else init_scale + conv = nn.Conv2d( + in_planes, + out_planes, + stride=stride, + bias=bias, + dilation=dilation, + padding=padding, + kernel_size=3, + ) + conv.weight.data *= init_scale + conv.bias.data *= init_scale + return conv + + +def ddpm_conv3x3( + in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1 +): + """3x3 convolution with DDPM initialization.""" + conv = nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + ########################################################################### + # Functions below are ported over from the NCSNv1/NCSNv2 codebase: + # https://github.com/ermongroup/ncsn + # https://github.com/ermongroup/ncsnv2 + ########################################################################### + + +class CRPBlock(nn.Module): + def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True): + super().__init__() + self.convs = nn.ModuleList() + for i in range(n_stages): + self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False)) + self.n_stages = n_stages + if maxpool: + self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) + else: + self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) + + self.act = act + + def forward(self, x): + x = self.act(x) + path = x + for i in range(self.n_stages): + path = self.pool(path) + path = self.convs[i](path) + x = path + x + return x + + +class CondCRPBlock(nn.Module): + def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()): + super().__init__() + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + self.normalizer = normalizer + for i in range(n_stages): + self.norms.append(normalizer(features, num_classes, bias=True)) + self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False)) + + self.n_stages = n_stages + self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) + self.act = act + + def forward(self, x, y): + x = self.act(x) + path = x + for i in range(self.n_stages): + path = self.norms[i](path, y) + path = self.pool(path) + path = self.convs[i](path) + + x = path + x + return x + + +class RCUBlock(nn.Module): + def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()): + super().__init__() + + for i in range(n_blocks): + for j in range(n_stages): + setattr( + self, + "{}_{}_conv".format(i + 1, j + 1), + ncsn_conv3x3(features, features, stride=1, bias=False), + ) + + self.stride = 1 + self.n_blocks = n_blocks + self.n_stages = n_stages + self.act = act + + def forward(self, x): + for i in range(self.n_blocks): + residual = x + for j in range(self.n_stages): + x = self.act(x) + x = getattr(self, "{}_{}_conv".format(i + 1, j + 1))(x) + + x += residual + return x + + +class CondRCUBlock(nn.Module): + def __init__( + self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU() + ): + super().__init__() + + for i in range(n_blocks): + for j in range(n_stages): + setattr( + self, + "{}_{}_norm".format(i + 1, j + 1), + normalizer(features, num_classes, bias=True), + ) + setattr( + self, + "{}_{}_conv".format(i + 1, j + 1), + ncsn_conv3x3(features, features, stride=1, bias=False), + ) + + self.stride = 1 + self.n_blocks = n_blocks + self.n_stages = n_stages + self.act = act + self.normalizer = normalizer + + def forward(self, x, y): + for i in 
range(self.n_blocks): + residual = x + for j in range(self.n_stages): + x = getattr(self, "{}_{}_norm".format(i + 1, j + 1))(x, y) + x = self.act(x) + x = getattr(self, "{}_{}_conv".format(i + 1, j + 1))(x) + + x += residual + return x + + +class MSFBlock(nn.Module): + def __init__(self, in_planes, features): + super().__init__() + assert isinstance(in_planes, list) or isinstance(in_planes, tuple) + self.convs = nn.ModuleList() + self.features = features + + for i in range(len(in_planes)): + self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) + + def forward(self, xs, shape): + sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) + for i in range(len(self.convs)): + h = self.convs[i](xs[i]) + h = F.interpolate(h, size=shape, mode="bilinear", align_corners=True) + sums += h + return sums + + +class CondMSFBlock(nn.Module): + def __init__(self, in_planes, features, num_classes, normalizer): + super().__init__() + assert isinstance(in_planes, list) or isinstance(in_planes, tuple) + + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + self.features = features + self.normalizer = normalizer + + for i in range(len(in_planes)): + self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) + self.norms.append(normalizer(in_planes[i], num_classes, bias=True)) + + def forward(self, xs, y, shape): + sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) + for i in range(len(self.convs)): + h = self.norms[i](xs[i], y) + h = self.convs[i](h) + h = F.interpolate(h, size=shape, mode="bilinear", align_corners=True) + sums += h + return sums + + +class RefineBlock(nn.Module): + def __init__( + self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True + ): + super().__init__() + + assert isinstance(in_planes, tuple) or isinstance(in_planes, list) + self.n_blocks = n_blocks = len(in_planes) + + self.adapt_convs = nn.ModuleList() + for i in range(n_blocks): + self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act)) + + self.output_convs = RCUBlock(features, 3 if end else 1, 2, act) + + if not start: + self.msf = MSFBlock(in_planes, features) + + self.crp = CRPBlock(features, 2, act, maxpool=maxpool) + + def forward(self, xs, output_shape): + assert isinstance(xs, tuple) or isinstance(xs, list) + hs = [] + for i in range(len(xs)): + h = self.adapt_convs[i](xs[i]) + hs.append(h) + + if self.n_blocks > 1: + h = self.msf(hs, output_shape) + else: + h = hs[0] + + h = self.crp(h) + h = self.output_convs(h) + + return h + + +class CondRefineBlock(nn.Module): + def __init__( + self, + in_planes, + features, + num_classes, + normalizer, + act=nn.ReLU(), + start=False, + end=False, + ): + super().__init__() + + assert isinstance(in_planes, tuple) or isinstance(in_planes, list) + self.n_blocks = n_blocks = len(in_planes) + + self.adapt_convs = nn.ModuleList() + for i in range(n_blocks): + self.adapt_convs.append( + CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act) + ) + + self.output_convs = CondRCUBlock( + features, 3 if end else 1, 2, num_classes, normalizer, act + ) + + if not start: + self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer) + + self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act) + + def forward(self, xs, y, output_shape): + assert isinstance(xs, tuple) or isinstance(xs, list) + hs = [] + for i in range(len(xs)): + h = self.adapt_convs[i](xs[i], y) + hs.append(h) + + if self.n_blocks > 1: + h = self.msf(hs, y, output_shape) + else: + h 
= hs[0] + + h = self.crp(h, y) + h = self.output_convs(h, y) + + return h + + +class ConvMeanPool(nn.Module): + def __init__( + self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False + ): + super().__init__() + if not adjust_padding: + conv = nn.Conv2d( + input_dim, + output_dim, + kernel_size, + stride=1, + padding=kernel_size // 2, + bias=biases, + ) + self.conv = conv + else: + conv = nn.Conv2d( + input_dim, + output_dim, + kernel_size, + stride=1, + padding=kernel_size // 2, + bias=biases, + ) + + self.conv = nn.Sequential(nn.ZeroPad2d((1, 0, 1, 0)), conv) + + def forward(self, inputs): + output = self.conv(inputs) + output = ( + sum( + [ + output[:, :, ::2, ::2], + output[:, :, 1::2, ::2], + output[:, :, ::2, 1::2], + output[:, :, 1::2, 1::2], + ] + ) + / 4.0 + ) + return output + + +class MeanPoolConv(nn.Module): + def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): + super().__init__() + self.conv = nn.Conv2d( + input_dim, + output_dim, + kernel_size, + stride=1, + padding=kernel_size // 2, + bias=biases, + ) + + def forward(self, inputs): + output = inputs + output = ( + sum( + [ + output[:, :, ::2, ::2], + output[:, :, 1::2, ::2], + output[:, :, ::2, 1::2], + output[:, :, 1::2, 1::2], + ] + ) + / 4.0 + ) + return self.conv(output) + + +class UpsampleConv(nn.Module): + def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): + super().__init__() + self.conv = nn.Conv2d( + input_dim, + output_dim, + kernel_size, + stride=1, + padding=kernel_size // 2, + bias=biases, + ) + self.pixelshuffle = nn.PixelShuffle(upscale_factor=2) + + def forward(self, inputs): + output = inputs + output = torch.cat([output, output, output, output], dim=1) + output = self.pixelshuffle(output) + return self.conv(output) + + +class ConditionalResidualBlock(nn.Module): + def __init__( + self, + input_dim, + output_dim, + num_classes, + resample=1, + act=nn.ELU(), + normalization=ConditionalInstanceNorm2dPlus, + adjust_padding=False, + dilation=None, + ): + super().__init__() + self.non_linearity = act + self.input_dim = input_dim + self.output_dim = output_dim + self.resample = resample + self.normalization = normalization + if resample == "down": + if dilation > 1: + self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation) + self.normalize2 = normalization(input_dim, num_classes) + self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) + conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) + else: + self.conv1 = ncsn_conv3x3(input_dim, input_dim) + self.normalize2 = normalization(input_dim, num_classes) + self.conv2 = ConvMeanPool( + input_dim, output_dim, 3, adjust_padding=adjust_padding + ) + conv_shortcut = partial( + ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding + ) + + elif resample is None: + if dilation > 1: + conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) + self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) + self.normalize2 = normalization(output_dim, num_classes) + self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation) + else: + conv_shortcut = nn.Conv2d + self.conv1 = ncsn_conv3x3(input_dim, output_dim) + self.normalize2 = normalization(output_dim, num_classes) + self.conv2 = ncsn_conv3x3(output_dim, output_dim) + else: + raise Exception("invalid resample value") + + if output_dim != input_dim or resample is not None: + self.shortcut = conv_shortcut(input_dim, output_dim) + + self.normalize1 = normalization(input_dim, num_classes) + + def forward(self, x, y): + 
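+        # Pre-activation residual layout: class-conditional norm -> act -> conv,
+        # applied twice, then summed with a (possibly projected) shortcut of x.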
+        output = self.normalize1(x, y)
+        output = self.non_linearity(output)
+        output = self.conv1(output)
+        output = self.normalize2(output, y)
+        output = self.non_linearity(output)
+        output = self.conv2(output)
+
+        if self.output_dim == self.input_dim and self.resample is None:
+            shortcut = x
+        else:
+            shortcut = self.shortcut(x)
+
+        return shortcut + output
+
+
+class ResidualBlock(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        resample=None,
+        act=nn.ELU(),
+        normalization=nn.InstanceNorm2d,
+        adjust_padding=False,
+        dilation=1,
+    ):
+        super().__init__()
+        self.non_linearity = act
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.resample = resample
+        self.normalization = normalization
+        if resample == "down":
+            if dilation > 1:
+                self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
+                self.normalize2 = normalization(input_dim)
+                self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+            else:
+                self.conv1 = ncsn_conv3x3(input_dim, input_dim)
+                self.normalize2 = normalization(input_dim)
+                self.conv2 = ConvMeanPool(
+                    input_dim, output_dim, 3, adjust_padding=adjust_padding
+                )
+                conv_shortcut = partial(
+                    ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding
+                )
+
+        elif resample is None:
+            if dilation > 1:
+                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+                self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+                self.normalize2 = normalization(output_dim)
+                self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
+            else:
+                # conv_shortcut = nn.Conv2d  ### Something weird here.
+                conv_shortcut = partial(ncsn_conv1x1)
+                self.conv1 = ncsn_conv3x3(input_dim, output_dim)
+                self.normalize2 = normalization(output_dim)
+                self.conv2 = ncsn_conv3x3(output_dim, output_dim)
+        else:
+            raise Exception("invalid resample value")
+
+        if output_dim != input_dim or resample is not None:
+            self.shortcut = conv_shortcut(input_dim, output_dim)
+
+        self.normalize1 = normalization(input_dim)
+
+    def forward(self, x):
+        output = self.normalize1(x)
+        output = self.non_linearity(output)
+        output = self.conv1(output)
+        output = self.normalize2(output)
+        output = self.non_linearity(output)
+        output = self.conv2(output)
+
+        if self.output_dim == self.input_dim and self.resample is None:
+            shortcut = x
+        else:
+            shortcut = self.shortcut(x)
+
+        return shortcut + output
+
+
+###########################################################################
+# Functions below are ported over from the DDPM codebase:
+# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
+###########################################################################
+
+
+def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
+    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
+    half_dim = embedding_dim // 2
+    # magic number 10000 is from transformers
+    emb = math.log(max_positions) / (half_dim - 1)
+    # emb = math.log(2.)
/ (half_dim - 1) + emb = torch.exp( + torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb + ) + # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] + # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = F.pad(emb, (0, 1), mode="constant") + assert emb.shape == (timesteps.shape[0], embedding_dim) + return emb + + +def _einsum(a, b, c, x, y): + einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c)) + return torch.einsum(einsum_str, x, y) + + +def contract_inner(x, y): + """tensordot(x, y, 1).""" + x_chars = list(string.ascii_lowercase[: len(x.shape)]) + y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)]) + y_chars[0] = x_chars[-1] # first axis of y and last of x get summed + out_chars = x_chars[:-1] + y_chars[1:] + return _einsum(x_chars, y_chars, out_chars, x, y) + + +class NIN(nn.Module): + def __init__(self, in_dim, num_units, init_scale=0.1): + super().__init__() + self.W = nn.Parameter( + default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True + ) + self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + y = contract_inner(x, self.W) + self.b + return y.permute(0, 3, 1, 2) + + +class AttnBlock(nn.Module): + """Channel-wise self-attention block.""" + + def __init__(self, channels): + super().__init__() + self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6) + self.NIN_0 = NIN(channels, channels) + self.NIN_1 = NIN(channels, channels) + self.NIN_2 = NIN(channels, channels) + self.NIN_3 = NIN(channels, channels, init_scale=0.0) + + def forward(self, x): + B, C, H, W = x.shape + h = self.GroupNorm_0(x) + q = self.NIN_0(h) + k = self.NIN_1(h) + v = self.NIN_2(h) + + w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) + w = torch.reshape(w, (B, H, W, H * W)) + w = F.softmax(w, dim=-1) + w = torch.reshape(w, (B, H, W, H, W)) + h = torch.einsum("bhwij,bcij->bchw", w, v) + h = self.NIN_3(h) + return x + h + + +class Upsample(nn.Module): + def __init__(self, channels, with_conv=False): + super().__init__() + if with_conv: + self.Conv_0 = ddpm_conv3x3(channels, channels) + self.with_conv = with_conv + + def forward(self, x): + B, C, H, W = x.shape + h = F.interpolate(x, (H * 2, W * 2), mode="nearest") + if self.with_conv: + h = self.Conv_0(h) + return h + + +class Downsample(nn.Module): + def __init__(self, channels, with_conv=False): + super().__init__() + if with_conv: + self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0) + self.with_conv = with_conv + + def forward(self, x): + B, C, H, W = x.shape + # Emulate 'SAME' padding + if self.with_conv: + x = F.pad(x, (0, 1, 0, 1)) + x = self.Conv_0(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0) + + assert x.shape == (B, C, H // 2, W // 2) + return x + + +class ResnetBlockDDPM(nn.Module): + """The ResNet Blocks used in DDPM.""" + + def __init__( + self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1 + ): + super().__init__() + if out_ch is None: + out_ch = in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6) + self.act = act + self.Conv_0 = ddpm_conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = 
default_init()(self.Dense_0.weight.data.shape) + nn.init.zeros_(self.Dense_0.bias) + + self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.0) + if in_ch != out_ch: + if conv_shortcut: + self.Conv_2 = ddpm_conv3x3(in_ch, out_ch) + else: + self.NIN_0 = NIN(in_ch, out_ch) + self.out_ch = out_ch + self.in_ch = in_ch + self.conv_shortcut = conv_shortcut + + def forward(self, x, temb=None): + B, C, H, W = x.shape + assert C == self.in_ch + out_ch = self.out_ch if self.out_ch else self.in_ch + h = self.act(self.GroupNorm_0(x)) + h = self.Conv_0(h) + # Add bias to each feature map conditioned on the time embedding + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + if C != out_ch: + if self.conv_shortcut: + x = self.Conv_2(x) + else: + x = self.NIN_0(x) + return x + h diff --git a/modules/sgmse/ncsnpp_utils/layerspp.py b/modules/sgmse/ncsnpp_utils/layerspp.py new file mode 100644 index 00000000..793b7e24 --- /dev/null +++ b/modules/sgmse/ncsnpp_utils/layerspp.py @@ -0,0 +1,323 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file +"""Layers for defining NCSN++. +""" +from . import layers +from . import up_or_down_sampling +import torch.nn as nn +import torch +import torch.nn.functional as F +import numpy as np + +conv1x1 = layers.ddpm_conv1x1 +conv3x3 = layers.ddpm_conv3x3 +NIN = layers.NIN +default_init = layers.default_init + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__(self, embedding_size=256, scale=1.0): + super().__init__() + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * np.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + + +class Combine(nn.Module): + """Combine information from skip connections.""" + + def __init__(self, dim1, dim2, method="cat"): + super().__init__() + self.Conv_0 = conv1x1(dim1, dim2) + self.method = method + + def forward(self, x, y): + h = self.Conv_0(x) + if self.method == "cat": + return torch.cat([h, y], dim=1) + elif self.method == "sum": + return h + y + else: + raise ValueError(f"Method {self.method} not recognized.") + + +class AttnBlockpp(nn.Module): + """Channel-wise self-attention block. 
Modified from DDPM.""" + + def __init__(self, channels, skip_rescale=False, init_scale=0.0): + super().__init__() + self.GroupNorm_0 = nn.GroupNorm( + num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6 + ) + self.NIN_0 = NIN(channels, channels) + self.NIN_1 = NIN(channels, channels) + self.NIN_2 = NIN(channels, channels) + self.NIN_3 = NIN(channels, channels, init_scale=init_scale) + self.skip_rescale = skip_rescale + + def forward(self, x): + B, C, H, W = x.shape + h = self.GroupNorm_0(x) + q = self.NIN_0(h) + k = self.NIN_1(h) + v = self.NIN_2(h) + + w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) + w = torch.reshape(w, (B, H, W, H * W)) + w = F.softmax(w, dim=-1) + w = torch.reshape(w, (B, H, W, H, W)) + h = torch.einsum("bhwij,bcij->bchw", w, v) + h = self.NIN_3(h) + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class Upsample(nn.Module): + def __init__( + self, + in_ch=None, + out_ch=None, + with_conv=False, + fir=False, + fir_kernel=(1, 3, 3, 1), + ): + super().__init__() + out_ch = out_ch if out_ch else in_ch + if not fir: + if with_conv: + self.Conv_0 = conv3x3(in_ch, out_ch) + else: + if with_conv: + self.Conv2d_0 = up_or_down_sampling.Conv2d( + in_ch, + out_ch, + kernel=3, + up=True, + resample_kernel=fir_kernel, + use_bias=True, + kernel_init=default_init(), + ) + self.fir = fir + self.with_conv = with_conv + self.fir_kernel = fir_kernel + self.out_ch = out_ch + + def forward(self, x): + B, C, H, W = x.shape + if not self.fir: + h = F.interpolate(x, (H * 2, W * 2), "nearest") + if self.with_conv: + h = self.Conv_0(h) + else: + if not self.with_conv: + h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) + else: + h = self.Conv2d_0(x) + + return h + + +class Downsample(nn.Module): + def __init__( + self, + in_ch=None, + out_ch=None, + with_conv=False, + fir=False, + fir_kernel=(1, 3, 3, 1), + ): + super().__init__() + out_ch = out_ch if out_ch else in_ch + if not fir: + if with_conv: + self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) + else: + if with_conv: + self.Conv2d_0 = up_or_down_sampling.Conv2d( + in_ch, + out_ch, + kernel=3, + down=True, + resample_kernel=fir_kernel, + use_bias=True, + kernel_init=default_init(), + ) + self.fir = fir + self.fir_kernel = fir_kernel + self.with_conv = with_conv + self.out_ch = out_ch + + def forward(self, x): + B, C, H, W = x.shape + if not self.fir: + if self.with_conv: + x = F.pad(x, (0, 1, 0, 1)) + x = self.Conv_0(x) + else: + x = F.avg_pool2d(x, 2, stride=2) + else: + if not self.with_conv: + x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) + else: + x = self.Conv2d_0(x) + + return x + + +class ResnetBlockDDPMpp(nn.Module): + """ResBlock adapted from DDPM.""" + + def __init__( + self, + act, + in_ch, + out_ch=None, + temb_dim=None, + conv_shortcut=False, + dropout=0.1, + skip_rescale=False, + init_scale=0.0, + ): + super().__init__() + out_ch = out_ch if out_ch else in_ch + self.GroupNorm_0 = nn.GroupNorm( + num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6 + ) + self.Conv_0 = conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) + nn.init.zeros_(self.Dense_0.bias) + self.GroupNorm_1 = nn.GroupNorm( + num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6 + ) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) + if in_ch != out_ch: + if 
conv_shortcut: + self.Conv_2 = conv3x3(in_ch, out_ch) + else: + self.NIN_0 = NIN(in_ch, out_ch) + + self.skip_rescale = skip_rescale + self.act = act + self.out_ch = out_ch + self.conv_shortcut = conv_shortcut + + def forward(self, x, temb=None): + h = self.act(self.GroupNorm_0(x)) + h = self.Conv_0(h) + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + if x.shape[1] != self.out_ch: + if self.conv_shortcut: + x = self.Conv_2(x) + else: + x = self.NIN_0(x) + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class ResnetBlockBigGANpp(nn.Module): + def __init__( + self, + act, + in_ch, + out_ch=None, + temb_dim=None, + up=False, + down=False, + dropout=0.1, + fir=False, + fir_kernel=(1, 3, 3, 1), + skip_rescale=True, + init_scale=0.0, + ): + super().__init__() + + out_ch = out_ch if out_ch else in_ch + self.GroupNorm_0 = nn.GroupNorm( + num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6 + ) + self.up = up + self.down = down + self.fir = fir + self.fir_kernel = fir_kernel + + self.Conv_0 = conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) + nn.init.zeros_(self.Dense_0.bias) + + self.GroupNorm_1 = nn.GroupNorm( + num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6 + ) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) + if in_ch != out_ch or up or down: + self.Conv_2 = conv1x1(in_ch, out_ch) + + self.skip_rescale = skip_rescale + self.act = act + self.in_ch = in_ch + self.out_ch = out_ch + + def forward(self, x, temb=None): + h = self.act(self.GroupNorm_0(x)) + + if self.up: + if self.fir: + h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2) + x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) + else: + h = up_or_down_sampling.naive_upsample_2d(h, factor=2) + x = up_or_down_sampling.naive_upsample_2d(x, factor=2) + elif self.down: + if self.fir: + h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2) + x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) + else: + h = up_or_down_sampling.naive_downsample_2d(h, factor=2) + x = up_or_down_sampling.naive_downsample_2d(x, factor=2) + + h = self.Conv_0(h) + # Add bias to each feature map conditioned on the time embedding + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + + if self.in_ch != self.out_ch or self.up or self.down: + x = self.Conv_2(x) + + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) diff --git a/modules/sgmse/ncsnpp_utils/normalization.py b/modules/sgmse/ncsnpp_utils/normalization.py new file mode 100644 index 00000000..fcc4707e --- /dev/null +++ b/modules/sgmse/ncsnpp_utils/normalization.py @@ -0,0 +1,243 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Normalization layers.""" +import torch.nn as nn +import torch +import functools + + +def get_normalization(config, conditional=False): + """Obtain normalization modules from the config file.""" + norm = config.model.normalization + if conditional: + if norm == "InstanceNorm++": + return functools.partial( + ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes + ) + else: + raise NotImplementedError(f"{norm} not implemented yet.") + else: + if norm == "InstanceNorm": + return nn.InstanceNorm2d + elif norm == "InstanceNorm++": + return InstanceNorm2dPlus + elif norm == "VarianceNorm": + return VarianceNorm2d + elif norm == "GroupNorm": + return nn.GroupNorm + else: + raise ValueError("Unknown normalization: %s" % norm) + + +class ConditionalBatchNorm2d(nn.Module): + def __init__(self, num_features, num_classes, bias=True): + super().__init__() + self.num_features = num_features + self.bias = bias + self.bn = nn.BatchNorm2d(num_features, affine=False) + if self.bias: + self.embed = nn.Embedding(num_classes, num_features * 2) + self.embed.weight.data[ + :, :num_features + ].uniform_() # Initialise scale at N(1, 0.02) + self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 + else: + self.embed = nn.Embedding(num_classes, num_features) + self.embed.weight.data.uniform_() + + def forward(self, x, y): + out = self.bn(x) + if self.bias: + gamma, beta = self.embed(y).chunk(2, dim=1) + out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view( + -1, self.num_features, 1, 1 + ) + else: + gamma = self.embed(y) + out = gamma.view(-1, self.num_features, 1, 1) * out + return out + + +class ConditionalInstanceNorm2d(nn.Module): + def __init__(self, num_features, num_classes, bias=True): + super().__init__() + self.num_features = num_features + self.bias = bias + self.instance_norm = nn.InstanceNorm2d( + num_features, affine=False, track_running_stats=False + ) + if bias: + self.embed = nn.Embedding(num_classes, num_features * 2) + self.embed.weight.data[ + :, :num_features + ].uniform_() # Initialise scale at N(1, 0.02) + self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 + else: + self.embed = nn.Embedding(num_classes, num_features) + self.embed.weight.data.uniform_() + + def forward(self, x, y): + h = self.instance_norm(x) + if self.bias: + gamma, beta = self.embed(y).chunk(2, dim=-1) + out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view( + -1, self.num_features, 1, 1 + ) + else: + gamma = self.embed(y) + out = gamma.view(-1, self.num_features, 1, 1) * h + return out + + +class ConditionalVarianceNorm2d(nn.Module): + def __init__(self, num_features, num_classes, bias=False): + super().__init__() + self.num_features = num_features + self.bias = bias + self.embed = nn.Embedding(num_classes, num_features) + self.embed.weight.data.normal_(1, 0.02) + + def forward(self, x, y): + vars = torch.var(x, dim=(2, 3), keepdim=True) + h = x / torch.sqrt(vars + 1e-5) + + gamma = self.embed(y) + out = gamma.view(-1, self.num_features, 1, 1) * h + return out + + +class VarianceNorm2d(nn.Module): + def __init__(self, num_features, bias=False): + super().__init__() + self.num_features = num_features + self.bias = bias + self.alpha = nn.Parameter(torch.zeros(num_features)) + self.alpha.data.normal_(1, 0.02) + + def forward(self, x): + vars = torch.var(x, dim=(2, 3), keepdim=True) + h = x / torch.sqrt(vars + 1e-5) + + out = self.alpha.view(-1, 
self.num_features, 1, 1) * h + return out + + +class ConditionalNoneNorm2d(nn.Module): + def __init__(self, num_features, num_classes, bias=True): + super().__init__() + self.num_features = num_features + self.bias = bias + if bias: + self.embed = nn.Embedding(num_classes, num_features * 2) + self.embed.weight.data[ + :, :num_features + ].uniform_() # Initialise scale at N(1, 0.02) + self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 + else: + self.embed = nn.Embedding(num_classes, num_features) + self.embed.weight.data.uniform_() + + def forward(self, x, y): + if self.bias: + gamma, beta = self.embed(y).chunk(2, dim=-1) + out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view( + -1, self.num_features, 1, 1 + ) + else: + gamma = self.embed(y) + out = gamma.view(-1, self.num_features, 1, 1) * x + return out + + +class NoneNorm2d(nn.Module): + def __init__(self, num_features, bias=True): + super().__init__() + + def forward(self, x): + return x + + +class InstanceNorm2dPlus(nn.Module): + def __init__(self, num_features, bias=True): + super().__init__() + self.num_features = num_features + self.bias = bias + self.instance_norm = nn.InstanceNorm2d( + num_features, affine=False, track_running_stats=False + ) + self.alpha = nn.Parameter(torch.zeros(num_features)) + self.gamma = nn.Parameter(torch.zeros(num_features)) + self.alpha.data.normal_(1, 0.02) + self.gamma.data.normal_(1, 0.02) + if bias: + self.beta = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + means = torch.mean(x, dim=(2, 3)) + m = torch.mean(means, dim=-1, keepdim=True) + v = torch.var(means, dim=-1, keepdim=True) + means = (means - m) / (torch.sqrt(v + 1e-5)) + h = self.instance_norm(x) + + if self.bias: + h = h + means[..., None, None] * self.alpha[..., None, None] + out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view( + -1, self.num_features, 1, 1 + ) + else: + h = h + means[..., None, None] * self.alpha[..., None, None] + out = self.gamma.view(-1, self.num_features, 1, 1) * h + return out + + +class ConditionalInstanceNorm2dPlus(nn.Module): + def __init__(self, num_features, num_classes, bias=True): + super().__init__() + self.num_features = num_features + self.bias = bias + self.instance_norm = nn.InstanceNorm2d( + num_features, affine=False, track_running_stats=False + ) + if bias: + self.embed = nn.Embedding(num_classes, num_features * 3) + self.embed.weight.data[:, : 2 * num_features].normal_( + 1, 0.02 + ) # Initialise scale at N(1, 0.02) + self.embed.weight.data[ + :, 2 * num_features : + ].zero_() # Initialise bias at 0 + else: + self.embed = nn.Embedding(num_classes, 2 * num_features) + self.embed.weight.data.normal_(1, 0.02) + + def forward(self, x, y): + means = torch.mean(x, dim=(2, 3)) + m = torch.mean(means, dim=-1, keepdim=True) + v = torch.var(means, dim=-1, keepdim=True) + means = (means - m) / (torch.sqrt(v + 1e-5)) + h = self.instance_norm(x) + + if self.bias: + gamma, alpha, beta = self.embed(y).chunk(3, dim=-1) + h = h + means[..., None, None] * alpha[..., None, None] + out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view( + -1, self.num_features, 1, 1 + ) + else: + gamma, alpha = self.embed(y).chunk(2, dim=-1) + h = h + means[..., None, None] * alpha[..., None, None] + out = gamma.view(-1, self.num_features, 1, 1) * h + return out diff --git a/modules/sgmse/ncsnpp_utils/op/__init__.py b/modules/sgmse/ncsnpp_utils/op/__init__.py new file mode 100644 index 00000000..d0918d92 --- /dev/null +++ 
b/modules/sgmse/ncsnpp_utils/op/__init__.py
@@ -0,0 +1,2 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+from .upfirdn2d import upfirdn2d
diff --git a/modules/sgmse/ncsnpp_utils/op/fused_act.py b/modules/sgmse/ncsnpp_utils/op/fused_act.py
new file mode 100644
index 00000000..9f6cd311
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/fused_act.py
@@ -0,0 +1,97 @@
+import os
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+
+module_path = os.path.dirname(__file__)
+fused = load(
+    "fused",
+    sources=[
+        os.path.join(module_path, "fused_bias_act.cpp"),
+        os.path.join(module_path, "fused_bias_act_kernel.cu"),
+    ],
+)
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+    @staticmethod
+    def forward(ctx, grad_output, out, negative_slope, scale):
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        empty = grad_output.new_empty(0)
+
+        grad_input = fused.fused_bias_act(
+            grad_output, empty, out, 3, 1, negative_slope, scale
+        )
+
+        dim = [0]
+
+        if grad_input.ndim > 2:
+            dim += list(range(2, grad_input.ndim))
+
+        grad_bias = grad_input.sum(dim).detach()
+
+        return grad_input, grad_bias
+
+    @staticmethod
+    def backward(ctx, gradgrad_input, gradgrad_bias):
+        (out,) = ctx.saved_tensors
+        gradgrad_out = fused.fused_bias_act(
+            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+        )
+
+        return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+    @staticmethod
+    def forward(ctx, input, bias, negative_slope, scale):
+        empty = input.new_empty(0)
+        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (out,) = ctx.saved_tensors
+
+        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+            grad_output, out, ctx.negative_slope, ctx.scale
+        )
+
+        return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+        super().__init__()
+
+        self.bias = nn.Parameter(torch.zeros(channel))
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+    if input.device.type == "cpu":
+        rest_dim = [1] * (input.ndim - bias.ndim - 1)
+        return (
+            F.leaky_relu(
+                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
+            )
+            * scale
+        )
+
+    else:
+        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
diff --git a/modules/sgmse/ncsnpp_utils/op/fused_bias_act.cpp b/modules/sgmse/ncsnpp_utils/op/fused_bias_act.cpp
new file mode 100644
index 00000000..a0543187
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/fused_bias_act.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(bias);
+
+    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
\ No newline at end of file
diff --git a/modules/sgmse/ncsnpp_utils/op/fused_bias_act_kernel.cu b/modules/sgmse/ncsnpp_utils/op/fused_bias_act_kernel.cu
new file mode 100644
index 00000000..8d2f03c7
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/fused_bias_act_kernel.cu
@@ -0,0 +1,99 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+    scalar_t zero = 0.0;
+
+    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+        scalar_t x = p_x[xi];
+
+        if (use_bias) {
+            x += p_b[(xi / step_b) % size_b];
+        }
+
+        scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+        scalar_t y;
+
+        switch (act * 10 + grad) {
+            default:
+            case 10: y = x; break;
+            case 11: y = x; break;
+            case 12: y = 0.0; break;
+
+            case 30: y = (x > 0.0) ? x : x * alpha; break;
+            case 31: y = (ref > 0.0) ? x : x * alpha; break;
+            case 32: y = 0.0; break;
+        }
+
+        out[xi] = y * scale;
+    }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale) {
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+    auto x = input.contiguous();
+    auto b = bias.contiguous();
+    auto ref = refer.contiguous();
+
+    int use_bias = b.numel() ? 1 : 0;
+    int use_ref = ref.numel() ? 1 : 0;
+
+    int size_x = x.numel();
+    int size_b = b.numel();
+    int step_b = 1;
+
+    for (int i = 1 + 1; i < x.dim(); i++) {
+        step_b *= x.size(i);
+    }
+
+    int loop_x = 4;
+    int block_size = 4 * 32;
+    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+    auto y = torch::empty_like(x);
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            y.data_ptr<scalar_t>(),
+            x.data_ptr<scalar_t>(),
+            b.data_ptr<scalar_t>(),
+            ref.data_ptr<scalar_t>(),
+            act,
+            grad,
+            alpha,
+            scale,
+            loop_x,
+            size_x,
+            step_b,
+            size_b,
+            use_bias,
+            use_ref
+        );
+    });
+
+    return y;
+}
\ No newline at end of file
diff --git a/modules/sgmse/ncsnpp_utils/op/upfirdn2d.cpp b/modules/sgmse/ncsnpp_utils/op/upfirdn2d.cpp
new file mode 100644
index 00000000..b07aa205
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/upfirdn2d.cpp
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+    int up_x, int up_y, int down_x, int down_y,
+    int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+    int up_x, int up_y, int down_x, int down_y,
+    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(kernel);
+
+    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
\ No newline at end of file
diff --git a/modules/sgmse/ncsnpp_utils/op/upfirdn2d.py b/modules/sgmse/ncsnpp_utils/op/upfirdn2d.py
new file mode 100644
index 00000000..e18039c9
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/upfirdn2d.py
@@ -0,0 +1,200 @@
+import os
+
+import torch
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+
+module_path = os.path.dirname(__file__)
+upfirdn2d_op = load(
+    "upfirdn2d",
+    sources=[
+        os.path.join(module_path, "upfirdn2d.cpp"),
+        os.path.join(module_path, "upfirdn2d_kernel.cu"),
+    ],
+)
+
+
+class UpFirDn2dBackward(Function):
+    @staticmethod
+    def forward(
+        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
+    ):
+
+        up_x, up_y = up
+        down_x, down_y = down
+        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+        grad_input = upfirdn2d_op.upfirdn2d(
+            grad_output,
+            grad_kernel,
+            down_x,
+            down_y,
+            up_x,
+            up_y,
+            g_pad_x0,
+            g_pad_x1,
+            g_pad_y0,
+            g_pad_y1,
+        )
+        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+        ctx.save_for_backward(kernel)
+
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        ctx.up_x = up_x
+        ctx.up_y = up_y
+        ctx.down_x = down_x
+        ctx.down_y = down_y
+        ctx.pad_x0 = pad_x0
+        ctx.pad_x1 = pad_x1
+        ctx.pad_y0 = pad_y0
+        ctx.pad_y1 = pad_y1
+        ctx.in_size = in_size
+        ctx.out_size = out_size
+
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_input):
+        (kernel,) = ctx.saved_tensors
+
+        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+        gradgrad_out = upfirdn2d_op.upfirdn2d(
+            gradgrad_input,
+            kernel,
+            ctx.up_x,
+            ctx.up_y,
+            ctx.down_x,
+            ctx.down_y,
+            ctx.pad_x0,
+            ctx.pad_x1,
+            ctx.pad_y0,
+            ctx.pad_y1,
+        )
+        #
gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == "cpu": + out = upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + else: + out = UpFirDn2d.apply( + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/modules/sgmse/ncsnpp_utils/op/upfirdn2d_kernel.cu b/modules/sgmse/ncsnpp_utils/op/upfirdn2d_kernel.cu new file mode 100644 index 00000000..ed3eea30 --- /dev/null +++ b/modules/sgmse/ncsnpp_utils/op/upfirdn2d_kernel.cu @@ -0,0 +1,369 @@ 
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+  int c = a / b;
+
+  if (c * b > a) {
+    c--;
+  }
+
+  return c;
+}
+
+struct UpFirDn2DKernelParams {
+  int up_x;
+  int up_y;
+  int down_x;
+  int down_y;
+  int pad_x0;
+  int pad_x1;
+  int pad_y0;
+  int pad_y1;
+
+  int major_dim;
+  int in_h;
+  int in_w;
+  int minor_dim;
+  int kernel_h;
+  int kernel_w;
+  int out_h;
+  int out_w;
+  int loop_major;
+  int loop_x;
+};
+
+template <typename scalar_t>
+__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+                                       const scalar_t *kernel,
+                                       const UpFirDn2DKernelParams p) {
+  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int out_y = minor_idx / p.minor_dim;
+  minor_idx -= out_y * p.minor_dim;
+  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+  int major_idx_base = blockIdx.z * p.loop_major;
+
+  if (out_x_base >= p.out_w || out_y >= p.out_h ||
+      major_idx_base >= p.major_dim) {
+    return;
+  }
+
+  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+
+  for (int loop_major = 0, major_idx = major_idx_base;
+       loop_major < p.loop_major && major_idx < p.major_dim;
+       loop_major++, major_idx++) {
+    for (int loop_x = 0, out_x = out_x_base;
+         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+
+      const scalar_t *x_p =
+          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+                 minor_idx];
+      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+      int x_px = p.minor_dim;
+      int k_px = -p.up_x;
+      int x_py = p.in_w * p.minor_dim;
+      int k_py = -p.up_y * p.kernel_w;
+
+      scalar_t v = 0.0f;
+
+      for (int y = 0; y < h; y++) {
+        for (int x = 0; x < w; x++) {
+          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
+          x_p += x_px;
+          k_p += k_px;
+        }
+
+        x_p += x_py - w * x_px;
+        k_p += k_py - w * k_px;
+      }
+
+      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+          minor_idx] = v;
+    }
+  }
+}
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
+          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+                                 const scalar_t *kernel,
+                                 const UpFirDn2DKernelParams p) {
+  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+  __shared__ volatile float sk[kernel_h][kernel_w];
+  __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+  int minor_idx = blockIdx.x;
+  int tile_out_y = minor_idx / p.minor_dim;
+  minor_idx -= tile_out_y * p.minor_dim;
+  tile_out_y *= tile_out_h;
+  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+  int major_idx_base = blockIdx.z * p.loop_major;
+
+  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
+      major_idx_base >= p.major_dim) {
+    return;
+  }
+
+  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
+       tap_idx += blockDim.x) {
+    int ky = tap_idx / kernel_w;
+    int kx = tap_idx - ky * kernel_w;
+    scalar_t v = 0.0;
+
+    if (kx < p.kernel_w & ky < p.kernel_h) {
+      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+    }
+
+    sk[ky][kx] = v;
+  }
+
+  for (int loop_major = 0, major_idx = major_idx_base;
+       loop_major < p.loop_major & major_idx < p.major_dim;
+       loop_major++, major_idx++) {
+    for (int loop_x = 0, tile_out_x = tile_out_x_base;
+         loop_x < p.loop_x & tile_out_x < p.out_w;
+         loop_x++, tile_out_x += tile_out_w) {
+      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+      int tile_in_x = floor_div(tile_mid_x, up_x);
+      int tile_in_y = floor_div(tile_mid_y, up_y);
+
+      __syncthreads();
+
+      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
+           in_idx += blockDim.x) {
+        int rel_in_y = in_idx / tile_in_w;
+        int rel_in_x = in_idx - rel_in_y * tile_in_w;
+        int in_x = rel_in_x + tile_in_x;
+        int in_y = rel_in_y + tile_in_y;
+
+        scalar_t v = 0.0;
+
+        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
+                        p.minor_dim +
+                    minor_idx];
+        }
+
+        sx[rel_in_y][rel_in_x] = v;
+      }
+
+      __syncthreads();
+      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
+           out_idx += blockDim.x) {
+        int rel_out_y = out_idx / tile_out_w;
+        int rel_out_x = out_idx - rel_out_y * tile_out_w;
+        int out_x = rel_out_x + tile_out_x;
+        int out_y = rel_out_y + tile_out_y;
+
+        int mid_x = tile_mid_x + rel_out_x * down_x;
+        int mid_y = tile_mid_y + rel_out_y * down_y;
+        int in_x = floor_div(mid_x, up_x);
+        int in_y = floor_div(mid_y, up_y);
+        int rel_in_x = in_x - tile_in_x;
+        int rel_in_y = in_y - tile_in_y;
+        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+        int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+        scalar_t v = 0.0;
+
+#pragma unroll
+        for (int y = 0; y < kernel_h / up_y; y++)
+#pragma unroll
+          for (int x = 0; x < kernel_w / up_x; x++)
+            v += sx[rel_in_y + y][rel_in_x + x] *
+                 sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+        if (out_x < p.out_w & out_y < p.out_h) {
+          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+              minor_idx] = v;
+        }
+      }
+    }
+  }
+}
+
+torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+                           const torch::Tensor &kernel, int up_x, int up_y,
+                           int down_x, int down_y, int pad_x0, int pad_x1,
+                           int pad_y0, int pad_y1) {
+  int curDevice = -1;
+  cudaGetDevice(&curDevice);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+  UpFirDn2DKernelParams p;
+
+  auto x = input.contiguous();
+  auto k = kernel.contiguous();
+
+  p.major_dim = x.size(0);
+  p.in_h = x.size(1);
+  p.in_w = x.size(2);
+  p.minor_dim = x.size(3);
+  p.kernel_h = k.size(0);
+  p.kernel_w = k.size(1);
+  p.up_x = up_x;
+  p.up_y = up_y;
+  p.down_x = down_x;
+  p.down_y = down_y;
+  p.pad_x0 = pad_x0;
+  p.pad_x1 = pad_x1;
+  p.pad_y0 = pad_y0;
+  p.pad_y1 = pad_y1;
+
+  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
+            p.down_y;
+  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
+            p.down_x;
+
+  auto out =
+      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+  int mode = -1;
+
+  int tile_out_h = -1;
+  int tile_out_w = -1;
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 1;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 3 && p.kernel_w <= 3) {
+    mode = 2;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 3;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+      p.kernel_h <= 2 && p.kernel_w <= 2) {
+    mode = 4;
+    tile_out_h = 16;
+    tile_out_w = 64;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+      p.kernel_h <= 4 && p.kernel_w <= 4) {
+    mode = 5;
+    tile_out_h = 8;
+    tile_out_w = 32;
+  }
+
+  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+      p.kernel_h <= 2 && p.kernel_w <= 2) {
+    mode = 6;
+    tile_out_h = 8;
+    tile_out_w = 32;
+  }
+
+  dim3 block_size;
+  dim3 grid_size;
+
+  if (tile_out_h > 0 && tile_out_w > 0) {
+    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+    p.loop_x = 1;
+    block_size = dim3(32 * 8, 1, 1);
+    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+                     (p.major_dim - 1) / p.loop_major + 1);
+  } else {
+    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+    p.loop_x = 4;
+    block_size = dim3(4, 32, 1);
+    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+                     (p.major_dim - 1) / p.loop_major + 1);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+    switch (mode) {
+    case 1:
+      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 2:
+      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 3:
+      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 4:
+      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 5:
+      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    case 6:
+      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
+          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                                 x.data_ptr<scalar_t>(),
+                                                 k.data_ptr<scalar_t>(), p);
+
+      break;
+
+    default:
+      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+          k.data_ptr<scalar_t>(), p);
+    }
+  });
+
+  return out;
+}
\ No newline at end of file
diff --git a/modules/sgmse/ncsnpp_utils/up_or_down_sampling.py b/modules/sgmse/ncsnpp_utils/up_or_down_sampling.py
new file mode 100644
index 00000000..5d59071f
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/up_or_down_sampling.py
@@ -0,0 +1,273 @@
+"""Layers used for up-sampling or down-sampling images.
+
+Many functions are ported from https://github.com/NVlabs/stylegan2.
+""" + +import torch.nn as nn +import torch +import torch.nn.functional as F +import numpy as np +from .op import upfirdn2d + + +# Function ported from StyleGAN2 +def get_weight(module, shape, weight_var="weight", kernel_init=None): + """Get/create weight tensor for a convolution or fully-connected layer.""" + + return module.param(weight_var, kernel_init, shape) + + +class Conv2d(nn.Module): + """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" + + def __init__( + self, + in_ch, + out_ch, + kernel, + up=False, + down=False, + resample_kernel=(1, 3, 3, 1), + use_bias=True, + kernel_init=None, + ): + super().__init__() + assert not (up and down) + assert kernel >= 1 and kernel % 2 == 1 + self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) + if kernel_init is not None: + self.weight.data = kernel_init(self.weight.data.shape) + if use_bias: + self.bias = nn.Parameter(torch.zeros(out_ch)) + + self.up = up + self.down = down + self.resample_kernel = resample_kernel + self.kernel = kernel + self.use_bias = use_bias + + def forward(self, x): + if self.up: + x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) + elif self.down: + x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) + else: + x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) + + if self.use_bias: + x = x + self.bias.reshape(1, -1, 1, 1) + + return x + + +def naive_upsample_2d(x, factor=2): + _N, C, H, W = x.shape + x = torch.reshape(x, (-1, C, H, 1, W, 1)) + x = x.repeat(1, 1, 1, factor, 1, factor) + return torch.reshape(x, (-1, C, H * factor, W * factor)) + + +def naive_downsample_2d(x, factor=2): + _N, C, H, W = x.shape + x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) + return torch.mean(x, dim=(3, 5)) + + +def upsample_conv_2d(x, w, k=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. + + Padding is performed only once at the beginning, not between the + operations. + The fused op is considerably more efficient than performing the same + calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = + x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or + `[N, H * factor, W * factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Check weight shape. + assert len(w.shape) == 4 + convH = w.shape[2] + convW = w.shape[3] + inC = w.shape[1] + outC = w.shape[0] + + assert convW == convH + + # Setup filter kernel. + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor**2)) + p = (k.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + + # Determine data dimensions. 
+ stride = [1, 1, factor, factor] + output_shape = ( + (_shape(x, 2) - 1) * factor + convH, + (_shape(x, 3) - 1) * factor + convW, + ) + output_padding = ( + output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, + output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = _shape(x, 1) // inC + + # Transpose weights. + w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) + w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) + w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) + + x = F.conv_transpose2d( + x, w, stride=stride, output_padding=output_padding, padding=0 + ) + ## Original TF code. + # x = tf.nn.conv2d_transpose( + # x, + # w, + # output_shape=output_shape, + # strides=stride, + # padding='VALID', + # data_format=data_format) + ## JAX equivalent + + return upfirdn2d( + x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1) + ) + + +def conv_downsample_2d(x, w, k=None, factor=2, gain=1): + """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. + + Padding is performed only once at the beginning, not between the operations. + The fused op is considerably more efficient than performing the same + calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = + x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or + `[N, H // factor, W // factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + _outC, _inC, convH, convW = w.shape + assert convW == convH + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = (k.shape[0] - factor) + (convW - 1) + s = [factor, factor] + x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2)) + return F.conv2d(x, w, stride=s, padding=0) + + +def _setup_kernel(k): + k = np.asarray(k, dtype=np.float32) + if k.ndim == 1: + k = np.outer(k, k) + k /= np.sum(k) + assert k.ndim == 2 + assert k.shape[0] == k.shape[1] + return k + + +def _shape(x, dim): + return x.shape[dim] + + +def upsample_2d(x, k=None, factor=2, gain=1): + r"""Upsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and upsamples each image with the given filter. The filter is normalized so + that + if the input pixels are constant, they will be scaled by the specified + `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded + with + zeros so that its shape is a multiple of the upsampling factor. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor**2)) + p = k.shape[0] - factor + return upfirdn2d( + x, + torch.tensor(k, device=x.device), + up=factor, + pad=((p + 1) // 2 + factor - 1, p // 2), + ) + + +def downsample_2d(x, k=None, factor=2, gain=1): + r"""Downsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and downsamples each image with the given filter. The filter is normalized + so that + if the input pixels are constant, they will be scaled by the specified + `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded + with + zeros so that its shape is a multiple of the downsampling factor. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = k.shape[0] - factor + return upfirdn2d( + x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2) + ) diff --git a/modules/sgmse/ncsnpp_utils/utils.py b/modules/sgmse/ncsnpp_utils/utils.py new file mode 100644 index 00000000..38333da6 --- /dev/null +++ b/modules/sgmse/ncsnpp_utils/utils.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""All functions and modules related to model definition. +""" + +import torch + +import numpy as np +from ...sdes import OUVESDE, OUVPSDE + + +_MODELS = {} + + +def register_model(cls=None, *, name=None): + """A decorator for registering model classes.""" + + def _register(cls): + if name is None: + local_name = cls.__name__ + else: + local_name = name + if local_name in _MODELS: + raise ValueError(f"Already registered model with name: {local_name}") + _MODELS[local_name] = cls + return cls + + if cls is None: + return _register + else: + return _register(cls) + + +def get_model(name): + return _MODELS[name] + + +def get_sigmas(sigma_min, sigma_max, num_scales): + """Get sigmas --- the set of noise levels for SMLD from config files. 
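+
+    The schedule is geometric:
+        sigma_i = sigma_max * (sigma_min / sigma_max) ** (i / (num_scales - 1)),  i = 0, ..., num_scales - 1,
+    which is what np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales)) computes below.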
+ Args: + config: A ConfigDict object parsed from the config file + Returns: + sigmas: a jax numpy arrary of noise levels + """ + sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales)) + + return sigmas + + +def get_ddpm_params(config): + """Get betas and alphas --- parameters used in the original DDPM paper.""" + num_diffusion_timesteps = 1000 + # parameters need to be adapted if number of time steps differs from 1000 + beta_start = config.model.beta_min / config.model.num_scales + beta_end = config.model.beta_max / config.model.num_scales + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) + sqrt_1m_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) + + return { + "betas": betas, + "alphas": alphas, + "alphas_cumprod": alphas_cumprod, + "sqrt_alphas_cumprod": sqrt_alphas_cumprod, + "sqrt_1m_alphas_cumprod": sqrt_1m_alphas_cumprod, + "beta_min": beta_start * (num_diffusion_timesteps - 1), + "beta_max": beta_end * (num_diffusion_timesteps - 1), + "num_diffusion_timesteps": num_diffusion_timesteps, + } + + +def create_model(config): + """Create the score model.""" + model_name = config.model.name + score_model = get_model(model_name)(config) + score_model = score_model.to(config.device) + score_model = torch.nn.DataParallel(score_model) + return score_model + + +def get_model_fn(model, train=False): + """Create a function to give the output of the score-based model. + + Args: + model: The score model. + train: `True` for training and `False` for evaluation. + + Returns: + A model function. + """ + + def model_fn(x, labels): + """Compute the output of the score-based model. + + Args: + x: A mini-batch of input data. + labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently + for different models. + + Returns: + A tuple of (model output, new mutable states) + """ + if not train: + model.eval() + return model(x, labels) + else: + model.train() + return model(x, labels) + + return model_fn + + +def get_score_fn(sde, model, train=False, continuous=False): + """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. + + Args: + sde: An `sde_lib.SDE` object that represents the forward SDE. + model: A score model. + train: `True` for training and `False` for evaluation. + continuous: If `True`, the score-based model is expected to directly take continuous time steps. + + Returns: + A score function. + """ + model_fn = get_model_fn(model, train=train) + + if isinstance(sde, OUVPSDE): + + def score_fn(x, t): + # Scale neural network output by standard deviation and flip sign + if continuous: + # For VP-trained models, t=0 corresponds to the lowest noise level + # The maximum value of time embedding is assumed to 999 for + # continuously-trained models. 
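+                # Editorial note: continuous t in [0, 1] is rescaled to the
+                # discrete embedding range [0, 999] (e.g. t = 0.5 -> 499.5),
+                # matching models trained with a 1000-step discretization.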
+ labels = t * 999 + score = model_fn(x, labels) + std = sde.marginal_prob(torch.zeros_like(x), t)[1] + else: + # For VP-trained models, t=0 corresponds to the lowest noise level + labels = t * (sde.N - 1) + score = model_fn(x, labels) + std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] + + score = -score / std[:, None, None, None] + return score + + elif isinstance(sde, OUVESDE): + + def score_fn(x, t): + if continuous: + labels = sde.marginal_prob(torch.zeros_like(x), t)[1] + else: + # For VE-trained models, t=0 corresponds to the highest noise level + labels = sde.T - t + labels *= sde.N - 1 + labels = torch.round(labels).long() + + score = model_fn(x, labels) + return score + + else: + raise NotImplementedError( + f"SDE class {sde.__class__.__name__} not yet supported." + ) + + return score_fn + + +def to_flattened_numpy(x): + """Flatten a torch tensor `x` and convert it to numpy.""" + return x.detach().cpu().numpy().reshape((-1,)) + + +def from_flattened_numpy(x, shape): + """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" + return torch.from_numpy(x.reshape(shape)) diff --git a/modules/sgmse/sampling/__init__.py b/modules/sgmse/sampling/__init__.py new file mode 100644 index 00000000..3248dc62 --- /dev/null +++ b/modules/sgmse/sampling/__init__.py @@ -0,0 +1,169 @@ +# Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py +"""Various sampling methods.""" +from scipy import integrate +import torch + +from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor +from .correctors import Corrector, CorrectorRegistry + + +__all__ = [ + "PredictorRegistry", + "CorrectorRegistry", + "Predictor", + "Corrector", + "get_sampler", +] + + +def to_flattened_numpy(x): + """Flatten a torch tensor `x` and convert it to numpy.""" + return x.detach().cpu().numpy().reshape((-1,)) + + +def from_flattened_numpy(x, shape): + """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" + return torch.from_numpy(x.reshape(shape)) + + +def get_pc_sampler( + predictor_name, + corrector_name, + sde, + score_fn, + y, + denoise=True, + eps=3e-2, + snr=0.1, + corrector_steps=1, + probability_flow: bool = False, + intermediate=False, + **kwargs +): + """Create a Predictor-Corrector (PC) sampler. + + Args: + predictor_name: The name of a registered `sampling.Predictor`. + corrector_name: The name of a registered `sampling.Corrector`. + sde: An `sdes.SDE` object representing the forward SDE. + score_fn: A function (typically learned model) that predicts the score. + y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on. + denoise: If `True`, add one-step denoising to the final samples. + eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues. + snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`. + N: The number of reverse sampling steps. If `None`, uses the SDE's `N` property by default. + + Returns: + A sampling function that returns samples and the number of function evaluations during sampling. 
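+
+    Example (illustrative sketch; assumes a constructed SDE, a trained score
+    model wrapped as `score_fn`, and a conditioning batch `y`):
+        >>> sampler = get_pc_sampler("reverse_diffusion", "ald", sde,
+        ...                          score_fn, y, snr=0.33, corrector_steps=1)
+        >>> sample, nfe = sampler()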
+ """ + predictor_cls = PredictorRegistry.get_by_name(predictor_name) + corrector_cls = CorrectorRegistry.get_by_name(corrector_name) + predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow) + corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps) + + def pc_sampler(): + """The PC sampler function.""" + with torch.no_grad(): + xt = sde.prior_sampling(y.shape, y).to(y.device) + timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device) + for i in range(sde.N): + t = timesteps[i] + vec_t = torch.ones(y.shape[0], device=y.device) * t + xt, xt_mean = corrector.update_fn(xt, vec_t, y) + xt, xt_mean = predictor.update_fn(xt, vec_t, y) + x_result = xt_mean if denoise else xt + ns = sde.N * (corrector.n_steps + 1) + return x_result, ns + + return pc_sampler + + +def get_ode_sampler( + sde, + score_fn, + y, + inverse_scaler=None, + denoise=True, + rtol=1e-5, + atol=1e-5, + method="RK45", + eps=3e-2, + device="cuda", + **kwargs +): + """Probability flow ODE sampler with the black-box ODE solver. + + Args: + sde: An `sdes.SDE` object representing the forward SDE. + score_fn: A function (typically learned model) that predicts the score. + y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on. + inverse_scaler: The inverse data normalizer. + denoise: If `True`, add one-step denoising to final samples. + rtol: A `float` number. The relative tolerance level of the ODE solver. + atol: A `float` number. The absolute tolerance level of the ODE solver. + method: A `str`. The algorithm used for the black-box ODE solver. + See the documentation of `scipy.integrate.solve_ivp`. + eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability. + device: PyTorch device. + + Returns: + A sampling function that returns samples and the number of function evaluations during sampling. + """ + predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False) + rsde = sde.reverse(score_fn, probability_flow=True) + + def denoise_update_fn(x): + vec_eps = torch.ones(x.shape[0], device=x.device) * eps + _, x = predictor.update_fn(x, vec_eps, y) + return x + + def drift_fn(x, t, y): + """Get the drift function of the reverse-time SDE.""" + return rsde.sde(x, t, y)[0] + + def ode_sampler(z=None, **kwargs): + """The probability flow ODE sampler with black-box ODE solver. + + Args: + model: A score model. + z: If present, generate samples from latent code `z`. + Returns: + samples, number of function evaluations. + """ + with torch.no_grad(): + # If not represent, sample the latent code from the prior distibution of the SDE. 
+ x = sde.prior_sampling(y.shape, y).to(device) + + def ode_func(t, x): + x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64) + vec_t = torch.ones(y.shape[0], device=x.device) * t + drift = drift_fn(x, vec_t, y) + return to_flattened_numpy(drift) + + # Black-box ODE solver for the probability flow ODE + solution = integrate.solve_ivp( + ode_func, + (sde.T, eps), + to_flattened_numpy(x), + rtol=rtol, + atol=atol, + method=method, + **kwargs + ) + nfe = solution.nfev + x = ( + torch.tensor(solution.y[:, -1]) + .reshape(y.shape) + .to(device) + .type(torch.complex64) + ) + + # Denoising is equivalent to running one predictor step without adding noise + if denoise: + x = denoise_update_fn(x) + + if inverse_scaler is not None: + x = inverse_scaler(x) + return x, nfe + + return ode_sampler diff --git a/modules/sgmse/sampling/correctors.py b/modules/sgmse/sampling/correctors.py new file mode 100644 index 00000000..4b995f4b --- /dev/null +++ b/modules/sgmse/sampling/correctors.py @@ -0,0 +1,99 @@ +import abc +import torch + +from modules.sgmse import sdes +from utils.sgmse_util.registry import Registry + + +CorrectorRegistry = Registry("Corrector") + + +class Corrector(abc.ABC): + """The abstract class for a corrector algorithm.""" + + def __init__(self, sde, score_fn, snr, n_steps): + super().__init__() + self.rsde = sde.reverse(score_fn) + self.score_fn = score_fn + self.snr = snr + self.n_steps = n_steps + + @abc.abstractmethod + def update_fn(self, x, t, *args): + """One update of the corrector. + + Args: + x: A PyTorch tensor representing the current state + t: A PyTorch tensor representing the current time step. + *args: Possibly additional arguments, in particular `y` for OU processes + + Returns: + x: A PyTorch tensor of the next state. + x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. + """ + pass + + +@CorrectorRegistry.register(name="langevin") +class LangevinCorrector(Corrector): + def __init__(self, sde, score_fn, snr, n_steps): + super().__init__(sde, score_fn, snr, n_steps) + self.score_fn = score_fn + self.n_steps = n_steps + self.snr = snr + + def update_fn(self, x, t, *args): + target_snr = self.snr + for _ in range(self.n_steps): + grad = self.score_fn(x, t, *args) + noise = torch.randn_like(x) + grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() + step_size = ((target_snr * noise_norm / grad_norm) ** 2 * 2).unsqueeze(0) + x_mean = x + step_size[:, None, None, None] * grad + x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None] + + return x, x_mean + + +@CorrectorRegistry.register(name="ald") +class AnnealedLangevinDynamics(Corrector): + """The original annealed Langevin dynamics predictor in NCSN/NCSNv2.""" + + def __init__(self, sde, score_fn, snr, n_steps): + super().__init__(sde, score_fn, snr, n_steps) + if not isinstance(sde, (sdes.OUVESDE,)): + raise NotImplementedError( + f"SDE class {sde.__class__.__name__} not yet supported." 
+ ) + self.sde = sde + self.score_fn = score_fn + self.snr = snr + self.n_steps = n_steps + + def update_fn(self, x, t, *args): + n_steps = self.n_steps + target_snr = self.snr + std = self.sde.marginal_prob(x, t, *args)[1] + + for _ in range(n_steps): + grad = self.score_fn(x, t, *args) + noise = torch.randn_like(x) + step_size = (target_snr * std) ** 2 * 2 + x_mean = x + step_size[:, None, None, None] * grad + x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None] + + return x, x_mean + + +@CorrectorRegistry.register(name="none") +class NoneCorrector(Corrector): + """An empty corrector that does nothing.""" + + def __init__(self, *args, **kwargs): + self.snr = 0 + self.n_steps = 0 + pass + + def update_fn(self, x, t, *args): + return x, x diff --git a/modules/sgmse/sampling/predictors.py b/modules/sgmse/sampling/predictors.py new file mode 100644 index 00000000..84723f18 --- /dev/null +++ b/modules/sgmse/sampling/predictors.py @@ -0,0 +1,78 @@ +import abc + +import torch +import numpy as np + +from utils.sgmse_util.registry import Registry + + +PredictorRegistry = Registry("Predictor") + + +class Predictor(abc.ABC): + """The abstract class for a predictor algorithm.""" + + def __init__(self, sde, score_fn, probability_flow=False): + super().__init__() + self.sde = sde + self.rsde = sde.reverse(score_fn) + self.score_fn = score_fn + self.probability_flow = probability_flow + + @abc.abstractmethod + def update_fn(self, x, t, *args): + """One update of the predictor. + + Args: + x: A PyTorch tensor representing the current state + t: A Pytorch tensor representing the current time step. + *args: Possibly additional arguments, in particular `y` for OU processes + + Returns: + x: A PyTorch tensor of the next state. + x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. + """ + pass + + def debug_update_fn(self, x, t, *args): + raise NotImplementedError( + f"Debug update function not implemented for predictor {self}." + ) + + +@PredictorRegistry.register("euler_maruyama") +class EulerMaruyamaPredictor(Predictor): + def __init__(self, sde, score_fn, probability_flow=False): + super().__init__(sde, score_fn, probability_flow=probability_flow) + + def update_fn(self, x, t, *args): + dt = -1.0 / self.rsde.N + z = torch.randn_like(x) + f, g = self.rsde.sde(x, t, *args) + x_mean = x + f * dt + x = x_mean + g[:, None, None, None] * np.sqrt(-dt) * z + return x, x_mean + + +@PredictorRegistry.register("reverse_diffusion") +class ReverseDiffusionPredictor(Predictor): + def __init__(self, sde, score_fn, probability_flow=False): + super().__init__(sde, score_fn, probability_flow=probability_flow) + + def update_fn(self, x, t, *args): + f, g = self.rsde.discretize(x, t, *args) + z = torch.randn_like(x) + x_mean = x - f + x = x_mean + g[:, None, None, None] * z + return x, x_mean + + +@PredictorRegistry.register("none") +class NonePredictor(Predictor): + """An empty predictor that does nothing.""" + + def __init__(self, *args, **kwargs): + pass + + def update_fn(self, x, t, *args): + return x, x diff --git a/modules/sgmse/sdes.py b/modules/sgmse/sdes.py new file mode 100644 index 00000000..441311bb --- /dev/null +++ b/modules/sgmse/sdes.py @@ -0,0 +1,360 @@ +""" +Abstract SDE classes, Reverse SDE, and VE/VP SDEs. 
+ +Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py +""" + +import abc +import warnings + +import numpy as np +from utils.sgmse_util.tensors import batch_broadcast +import torch + +from utils.sgmse_util.registry import Registry + + +SDERegistry = Registry("SDE") + + +class SDE(abc.ABC): + """SDE abstract class. Functions are designed for a mini-batch of inputs.""" + + def __init__(self, N): + """Construct an SDE. + + Args: + N: number of discretization time steps. + """ + super().__init__() + self.N = N + + @property + @abc.abstractmethod + def T(self): + """End time of the SDE.""" + pass + + @abc.abstractmethod + def sde(self, x, t, *args): + pass + + @abc.abstractmethod + def marginal_prob(self, x, t, *args): + """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$.""" + pass + + @abc.abstractmethod + def prior_sampling(self, shape, *args): + """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`.""" + pass + + @abc.abstractmethod + def prior_logp(self, z): + """Compute log-density of the prior distribution. + + Useful for computing the log-likelihood via probability flow ODE. + + Args: + z: latent code + Returns: + log probability density + """ + pass + + @staticmethod + @abc.abstractmethod + def add_argparse_args(parent_parser): + """ + Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser. + """ + pass + + def discretize(self, x, t, *args): + """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i. + + Useful for reverse diffusion sampling and probabiliy flow sampling. + Defaults to Euler-Maruyama discretization. + + Args: + x: a torch tensor + t: a torch float representing the time step (from 0 to `self.T`) + + Returns: + f, G + """ + dt = 1 / self.N + drift, diffusion = self.sde(x, t, *args) + f = drift * dt + G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device)) + return f, G + + def reverse(oself, score_model, probability_flow=False): + """Create the reverse-time SDE/ODE. + + Args: + score_model: A function that takes x, t and y and returns the score. + probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling. + """ + N = oself.N + T = oself.T + sde_fn = oself.sde + discretize_fn = oself.discretize + + # Build the class for reverse-time SDE. 
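+        # Editorial note: RSDE below implements the reverse-time dynamics
+        #     dx = [f(x, t) - g(t)^2 * score(x, t, y) * (0.5 if probability_flow else 1.0)] dt + g(t) dw
+        # where the diffusion term g(t) dw is dropped in the probability-flow
+        # (ODE) case; see rsde_parts().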
+ class RSDE(oself.__class__): + def __init__(self): + self.N = N + self.probability_flow = probability_flow + + @property + def T(self): + return T + + def sde(self, x, t, *args): + """Create the drift and diffusion functions for the reverse SDE/ODE.""" + rsde_parts = self.rsde_parts(x, t, *args) + total_drift, diffusion = ( + rsde_parts["total_drift"], + rsde_parts["diffusion"], + ) + return total_drift, diffusion + + def rsde_parts(self, x, t, *args): + sde_drift, sde_diffusion = sde_fn(x, t, *args) + score = score_model(x, t, *args) + score_drift = ( + -sde_diffusion[:, None, None, None] ** 2 + * score + * (0.5 if self.probability_flow else 1.0) + ) + diffusion = ( + torch.zeros_like(sde_diffusion) + if self.probability_flow + else sde_diffusion + ) + total_drift = sde_drift + score_drift + return { + "total_drift": total_drift, + "diffusion": diffusion, + "sde_drift": sde_drift, + "sde_diffusion": sde_diffusion, + "score_drift": score_drift, + "score": score, + } + + def discretize(self, x, t, *args): + """Create discretized iteration rules for the reverse diffusion sampler.""" + f, G = discretize_fn(x, t, *args) + rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, *args) * ( + 0.5 if self.probability_flow else 1.0 + ) + rev_G = torch.zeros_like(G) if self.probability_flow else G + return rev_f, rev_G + + return RSDE() + + @abc.abstractmethod + def copy(self): + pass + + +@SDERegistry.register("ouve") +class OUVESDE(SDE): + @staticmethod + def add_argparse_args(parser): + parser.add_argument( + "--sde-n", + type=int, + default=1000, + help="The number of timesteps in the SDE discretization. 30 by default", + ) + parser.add_argument( + "--theta", + type=float, + default=1.5, + help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.", + ) + parser.add_argument( + "--sigma-min", + type=float, + default=0.05, + help="The minimum sigma to use. 0.05 by default.", + ) + parser.add_argument( + "--sigma-max", + type=float, + default=0.5, + help="The maximum sigma to use. 0.5 by default.", + ) + return parser + + def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs): + """Construct an Ornstein-Uhlenbeck Variance Exploding SDE. + + Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument + to the methods which require it (e.g., `sde` or `marginal_prob`). + + dx = -theta (y-x) dt + sigma(t) dw + + with + + sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min)) + + Args: + theta: stiffness parameter. + sigma_min: smallest sigma. + sigma_max: largest sigma. + N: number of discretization steps + """ + super().__init__(N) + self.theta = theta + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.logsig = np.log(self.sigma_max / self.sigma_min) + self.N = N + + def copy(self): + return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N) + + @property + def T(self): + return 1 + + def sde(self, x, t, y): + drift = self.theta * (y - x) + # the sqrt(2*logsig) factor is required here so that logsig does not in the end affect the perturbation kernel + # standard deviation. this can be understood from solving the integral of [exp(2s) * g(s)^2] from s=0 to t + # with g(t) = sigma(t) as defined here, and seeing that `logsig` remains in the integral solution + # unless this sqrt(2*logsig) factor is included. 
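+        # Editorial note: concretely, sigma(t) = sigma_min * (sigma_max / sigma_min) ** t
+        # and g(t) = sigma(t) * sqrt(2 * log(sigma_max / sigma_min)), as computed below.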
+ sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t + diffusion = sigma * np.sqrt(2 * self.logsig) + return drift, diffusion + + def _mean(self, x0, t, y): + theta = self.theta + exp_interp = torch.exp(-theta * t)[:, None, None, None] + return exp_interp * x0 + (1 - exp_interp) * y + + def _std(self, t): + # This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde() + sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig + # could maybe replace the two torch.exp(... * t) terms here by cached values **t + return torch.sqrt( + ( + sigma_min**2 + * torch.exp(-2 * theta * t) + * (torch.exp(2 * (theta + logsig) * t) - 1) + * logsig + ) + / (theta + logsig) + ) + + def marginal_prob(self, x0, t, y): + return self._mean(x0, t, y), self._std(t) + + def prior_sampling(self, shape, y): + if shape != y.shape: + warnings.warn( + f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape." + ) + std = self._std(torch.ones((y.shape[0],), device=y.device)) + x_T = y + torch.randn_like(y) * std[:, None, None, None] + return x_T + + def prior_logp(self, z): + raise NotImplementedError("prior_logp for OU SDE not yet implemented!") + + +@SDERegistry.register("ouvp") +class OUVPSDE(SDE): + # !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!! + @staticmethod + def add_argparse_args(parser): + parser.add_argument( + "--sde-n", + type=int, + default=1000, + help="The number of timesteps in the SDE discretization. 1000 by default", + ) + parser.add_argument( + "--beta-min", type=float, required=True, help="The minimum beta to use." + ) + parser.add_argument( + "--beta-max", type=float, required=True, help="The maximum beta to use." + ) + parser.add_argument( + "--stiffness", + type=float, + default=1, + help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.", + ) + return parser + + def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs): + """ + !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!! + + Construct an Ornstein-Uhlenbeck Variance Preserving SDE: + + dx = -1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw + + with + + beta(t) = beta_min + t(beta_max - beta_min) + + Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument + to the methods which require it (e.g., `sde` or `marginal_prob`). + + Args: + beta_min: smallest sigma. + beta_max: largest sigma. + stiffness: stiffness factor of the drift. 1 by default. 
+ N: number of discretization steps + """ + super().__init__(N) + self.beta_min = beta_min + self.beta_max = beta_max + self.stiffness = stiffness + self.N = N + + def copy(self): + return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N) + + @property + def T(self): + return 1 + + def _beta(self, t): + return self.beta_min + t * (self.beta_max - self.beta_min) + + def sde(self, x, t, y): + drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x) + diffusion = torch.sqrt(self._beta(t)) + return drift, diffusion + + def _mean(self, x0, t, y): + b0, b1, s = self.beta_min, self.beta_max, self.stiffness + x0y_fac = torch.exp(-0.25 * s * t * (t * (b1 - b0) + 2 * b0))[ + :, None, None, None + ] + return y + x0y_fac * (x0 - y) + + def _std(self, t): + b0, b1, s = self.beta_min, self.beta_max, self.stiffness + return (1 - torch.exp(-0.5 * s * t * (t * (b1 - b0) + 2 * b0))) / s + + def marginal_prob(self, x0, t, y): + return self._mean(x0, t, y), self._std(t) + + def prior_sampling(self, shape, y): + if shape != y.shape: + warnings.warn( + f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape." + ) + std = self._std(torch.ones((y.shape[0],), device=y.device)) + x_T = y + torch.randn_like(y) * std[:, None, None, None] + return x_T + + def prior_logp(self, z): + raise NotImplementedError("prior_logp for OU SDE not yet implemented!") diff --git a/modules/sgmse/shared.py b/modules/sgmse/shared.py new file mode 100644 index 00000000..8165069f --- /dev/null +++ b/modules/sgmse/shared.py @@ -0,0 +1,132 @@ +import functools +import numpy as np + +import torch +import torch.nn as nn + +from utils.sgmse_util.registry import Registry + + +BackboneRegistry = Registry("Backbone") + + +class GaussianFourierProjection(nn.Module): + """Gaussian random features for encoding time steps.""" + + def __init__(self, embed_dim, scale=16, complex_valued=False): + super().__init__() + self.complex_valued = complex_valued + if not complex_valued: + # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities. + # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case, + # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly, + # and this halving is not necessary. + embed_dim = embed_dim // 2 + # Randomly sample weights during initialization. These weights are fixed + # during optimization and are not trainable. + self.W = nn.Parameter(torch.randn(embed_dim) * scale, requires_grad=False) + + def forward(self, t): + t_proj = t[:, None] * self.W[None, :] * 2 * np.pi + if self.complex_valued: + return torch.exp(1j * t_proj) + else: + return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1) + + +class DiffusionStepEmbedding(nn.Module): + """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017.""" + + def __init__(self, embed_dim, complex_valued=False): + super().__init__() + self.complex_valued = complex_valued + if not complex_valued: + # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities. + # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case, + # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly, + # and this halving is not necessary. 
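+        # Editorial note: forward() below uses geometrically spaced frequencies
+        # fac_j = 10 ** (4 * j / (embed_dim - 1)), j = 0, ..., embed_dim - 1,
+        # spanning 1 to 10^4 as in the DiffWave embedding.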
+ embed_dim = embed_dim // 2 + self.embed_dim = embed_dim + + def forward(self, t): + fac = 10 ** ( + 4 * torch.arange(self.embed_dim, device=t.device) / (self.embed_dim - 1) + ) + inner = t[:, None] * fac[None, :] + if self.complex_valued: + return torch.exp(1j * inner) + else: + return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1) + + +class ComplexLinear(nn.Module): + """A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`.""" + + def __init__(self, input_dim, output_dim, complex_valued): + super().__init__() + self.complex_valued = complex_valued + if self.complex_valued: + self.re = nn.Linear(input_dim, output_dim) + self.im = nn.Linear(input_dim, output_dim) + else: + self.lin = nn.Linear(input_dim, output_dim) + + def forward(self, x): + if self.complex_valued: + return (self.re(x.real) - self.im(x.imag)) + 1j * ( + self.re(x.imag) + self.im(x.real) + ) + else: + return self.lin(x) + + +class FeatureMapDense(nn.Module): + """A fully connected layer that reshapes outputs to feature maps.""" + + def __init__(self, input_dim, output_dim, complex_valued=False): + super().__init__() + self.complex_valued = complex_valued + self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued) + + def forward(self, x): + return self.dense(x)[..., None, None] + + +def torch_complex_from_reim(re, im): + return torch.view_as_complex(torch.stack([re, im], dim=-1)) + + +class ArgsComplexMultiplicationWrapper(nn.Module): + """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward(). + + Make a complex-valued module `F` from a real-valued module `f` by applying + complex multiplication rules: + + F(a + i b) = f1(a) - f1(b) + i (f2(b) + f2(a)) + + where `f1`, `f2` are instances of `f` that do *not* share weights. + + Args: + module_cls (callable): A class or function that returns a Torch module/functional. + Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`, + to construct the real and imaginary component modules. 
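+
+    Example (illustrative):
+        >>> conv = ArgsComplexMultiplicationWrapper(nn.Conv2d, 4, 8, kernel_size=3, padding=1)
+        >>> x = torch.randn(1, 4, 16, 16, dtype=torch.complex64)
+        >>> conv(x).shape
+        torch.Size([1, 8, 16, 16])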
+ """ + + def __init__(self, module_cls, *args, **kwargs): + super().__init__() + self.re_module = module_cls(*args, **kwargs) + self.im_module = module_cls(*args, **kwargs) + + def forward(self, x, *args, **kwargs): + return torch_complex_from_reim( + self.re_module(x.real, *args, **kwargs) + - self.im_module(x.imag, *args, **kwargs), + self.re_module(x.imag, *args, **kwargs) + + self.im_module(x.real, *args, **kwargs), + ) + + +ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d) +ComplexConvTranspose2d = functools.partial( + ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d +) diff --git a/preprocessors/wsj0reverb.py b/preprocessors/wsj0reverb.py new file mode 100644 index 00000000..84b71c01 --- /dev/null +++ b/preprocessors/wsj0reverb.py @@ -0,0 +1,187 @@ +import json +from tqdm import tqdm +import os +import torchaudio +from utils import audio +import csv +import random +from text import _clean_text +import librosa +import soundfile as sf +import pyroomacoustics as pra +from scipy.io import wavfile +from glob import glob +from pathlib import Path +import numpy as np + +SEED = 100 +np.random.seed(SEED) + +T60_RANGE = [0.4, 1.0] +SNR_RANGE = [0, 20] +DIM_RANGE = [5, 15, 5, 15, 2, 6] +MIN_DISTANCE_TO_WALL = 1 +MIC_ARRAY_RADIUS = 0.16 +TARGET_T60_SHAPE = {"CI": 0.08, "HA": 0.2} +TARGET_T60_SHAPE = {"CI": 0.10, "HA": 0.2} +TARGETS_CROP = {"CI": 16e-3, "HA": 40e-3} +NB_SAMPLES_PER_ROOM = 1 +CHANNELS = 1 + + +def obtain_clean_file(speech_list, i_sample, sample_rate=16000): + speech, speech_sr = sf.read(speech_list[i_sample]) + speech_basename = os.path.basename(speech_list[i_sample]) + assert ( + speech_sr == sample_rate + ), f"wrong speech sampling rate here: expected {sample_rate} got {speech_sr}" + return speech.squeeze(), speech_sr, speech_basename[:-4] + + +def main(output_path, dataset_path): + print("-" * 10) + print("Dataset splits for {}...\n".format("wsj0reverb")) + dataset = "wsj0reverb" + sample_rate = 16000 + save_dir = os.path.join(output_path, dataset) + os.makedirs(save_dir, exist_ok=True) + wsj0reverb_path = dataset_path + splits = ["valid", "train", "test"] + dic_split = {"valid": "si_dt_05", "train": "si_tr_s", "test": "si_et_05"} + speech_lists = { + split: sorted( + glob(f"{os.path.join(wsj0reverb_path, dic_split[split])}/**/*.wav") + ) + for split in splits + } + + for i_split, split in enumerate(splits): + print("Processing split n° {}: {}...".format(i_split + 1, split)) + + reverberant_output_dir = os.path.join(save_dir, "audio", split, "reverb") + dry_output_dir = os.path.join(save_dir, "audio", split, "anechoic") + noisy_reverberant_output_dir = os.path.join( + save_dir, "audio", split, "noisy_reverb" + ) + if split == "test": + unauralized_output_dir = os.path.join( + save_dir, "audio", split, "unauralized" + ) + + os.makedirs(reverberant_output_dir, exist_ok=True) + os.makedirs(dry_output_dir, exist_ok=True) + if split == "test": + os.makedirs(unauralized_output_dir, exist_ok=True) + + speech_list = speech_lists[split] + real_nb_samples = len(speech_list) + for i_sample in tqdm(range(real_nb_samples)): + if not i_sample % NB_SAMPLES_PER_ROOM: # Generate new room + t60 = np.random.uniform(T60_RANGE[0], T60_RANGE[1]) # Draw T60 + room_dim = np.array( + [ + np.random.uniform(DIM_RANGE[2 * n], DIM_RANGE[2 * n + 1]) + for n in range(3) + ] + ) # Draw Dimensions + center_mic_position = np.array( + [ + np.random.uniform( + MIN_DISTANCE_TO_WALL, room_dim[n] - MIN_DISTANCE_TO_WALL + ) + for n in range(3) + ] + ) # draw source position + 
source_position = np.array( + [ + np.random.uniform( + MIN_DISTANCE_TO_WALL, room_dim[n] - MIN_DISTANCE_TO_WALL + ) + for n in range(3) + ] + ) # draw source position + mic_array_2d = pra.beamforming.circular_2D_array( + center_mic_position[:-1], CHANNELS, phi0=0, radius=MIC_ARRAY_RADIUS + ) # Compute microphone array + mic_array = np.pad( + mic_array_2d, + ((0, 1), (0, 0)), + mode="constant", + constant_values=center_mic_position[-1], + ) + + ### Reverberant Room + e_absorption, max_order = pra.inverse_sabine( + t60, room_dim + ) # Compute absorption coeff + reverberant_room = pra.ShoeBox( + room_dim, + fs=16000, + materials=pra.Material(e_absorption), + max_order=min(3, max_order), + ) # Create room + reverberant_room.set_ray_tracing() + reverberant_room.add_microphone_array(mic_array) # Add microphone array + + # Pick unauralized files + speech, speech_sr, speech_basename = obtain_clean_file( + speech_list, i_sample, sample_rate=sample_rate + ) + + # Generate reverberant room + reverberant_room.add_source(source_position, signal=speech) + reverberant_room.compute_rir() + reverberant_room.simulate() + t60_real = np.mean(reverberant_room.measure_rt60()).squeeze() + reverberant = np.stack(reverberant_room.mic_array.signals).swapaxes(0, 1) + + e_absorption_dry = 0.99 # For Neural Networks OK but clearly not for WPE + dry_room = pra.ShoeBox( + room_dim, + fs=16000, + materials=pra.Material(e_absorption_dry), + max_order=0, + ) # Create room + dry_room.add_microphone_array(mic_array) # Add microphone array + + # Generate dry room + dry_room.add_source(source_position, signal=speech) + dry_room.compute_rir() + dry_room.simulate() + t60_real_dry = np.mean(dry_room.measure_rt60()).squeeze() + rir_dry = dry_room.rir + dry = np.stack(dry_room.mic_array.signals).swapaxes(0, 1) + dry = np.pad( + dry, + ((0, int(0.5 * sample_rate)), (0, 0)), + mode="constant", + constant_values=0, + ) # Add 1 second of silence after dry (because very dry) so that the reverb is not cut, and all samples have same length + + min_len_sample = min(reverberant.shape[0], dry.shape[0]) + dry = dry[:min_len_sample] + reverberant = reverberant[:min_len_sample] + output_scaling = np.max(reverberant) / 0.9 + + drr = 10 * np.log10( + np.mean(dry**2) / (np.mean(reverberant**2) + 1e-8) + 1e-8 + ) + output_filename = f"{speech_basename}_{i_sample // NB_SAMPLES_PER_ROOM}_{t60_real:.2f}_{drr:.1f}.wav" + + sf.write( + os.path.join(dry_output_dir, output_filename), + 1 / output_scaling * dry, + samplerate=sample_rate, + ) + sf.write( + os.path.join(reverberant_output_dir, output_filename), + 1 / output_scaling * reverberant, + samplerate=sample_rate, + ) + + if split == "test": + sf.write( + os.path.join(unauralized_output_dir, output_filename), + speech, + samplerate=sample_rate, + )
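A note on the generated data layout: each written file encodes the clean-utterance id, the room index, the measured T60 and the DRR in its name, as `{utterance}_{room}_{t60:.2f}_{drr:.1f}.wav`. A minimal sketch of a parser for this naming scheme follows (hypothetical helper, not part of the patch):

import os


def parse_wsj0reverb_filename(path):
    """Recover (utterance_id, room_index, t60_s, drr_db) from a filename
    written by preprocessors/wsj0reverb.py. Illustrative only."""
    stem = os.path.basename(path)
    stem = stem[:-4] if stem.endswith(".wav") else stem
    utterance_id, room_index, t60, drr = stem.rsplit("_", 3)
    return utterance_id, int(room_index), float(t60), float(drr)


# Example: parse_wsj0reverb_filename("401a010a_0_0.62_-3.4.wav")
# -> ("401a010a", 0, 0.62, -3.4)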