diff --git a/bins/sgmse/inference.py b/bins/sgmse/inference.py
new file mode 100644
index 00000000..d3bdf2ad
--- /dev/null
+++ b/bins/sgmse/inference.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+
+from models.sgmse.dereverberation.dereverberation_inference import (
+    DereverberationInference,
+)
+from utils.util import load_config
+import torch
+
+
+def build_inference(args, cfg):
+ supported_inference = {
+ "dereverberation": DereverberationInference,
+ }
+
+ inference_class = supported_inference[cfg.model_type]
+ inference = inference_class(args, cfg)
+ return inference
+
+
+def build_parser():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--config",
+ type=str,
+ required=True,
+ help="JSON/YAML file for configurations.",
+ )
+ parser.add_argument(
+ "--checkpoint_path",
+ type=str,
+ )
+ parser.add_argument(
+ "--test_dir",
+ type=str,
+ required=True,
+ help="Directory containing the test data (must have subdirectory noisy/)",
+ )
+ parser.add_argument(
+ "--corrector_steps", type=int, default=1, help="Number of corrector steps"
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default=None,
+ help="Output dir for saving generated results",
+ )
+ parser.add_argument(
+ "--snr",
+ type=float,
+ default=0.33,
+ help="SNR value for (annealed) Langevin dynmaics.",
+ )
+ parser.add_argument("--N", type=int, default=50, help="Number of reverse steps")
+ parser.add_argument("--local_rank", default=0, type=int)
+ return parser
+
+
+def main():
+ # Parse arguments
+ args = build_parser().parse_args()
+
+ # Parse config
+ cfg = load_config(args.config)
+    # Reuse args.local_rank to carry the inference device
+    args.local_rank = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("args: ", args)
+
+ # Build inference
+ inferencer = build_inference(args, cfg)
+
+ # Run inference
+ inferencer.inference()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/bins/sgmse/preprocess.py b/bins/sgmse/preprocess.py
new file mode 100644
index 00000000..22e1e860
--- /dev/null
+++ b/bins/sgmse/preprocess.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import faulthandler
+
+faulthandler.enable()
+import os
+import argparse
+import json
+from multiprocessing import cpu_count
+from utils.util import load_config
+from preprocessors.processor import preprocess_dataset
+
+
+def preprocess(cfg):
+ """Proprocess raw data of single or multiple datasets (in cfg.dataset)
+
+ Args:
+ cfg (dict): dictionary that stores configurations
+ """
+ # Specify the output root path to save the processed data
+ output_path = cfg.preprocess.processed_dir
+ os.makedirs(output_path, exist_ok=True)
+
+ ## Split train and test sets
+ for dataset in cfg.dataset:
+ print("Preprocess {}...".format(dataset))
+
+ preprocess_dataset(
+ dataset,
+ cfg.dataset_path[dataset],
+ output_path,
+ cfg.preprocess,
+ cfg.task_type,
+ is_custom_dataset=dataset in cfg.use_custom_dataset,
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--config", default="config.json", help="json files for configurations."
+ )
+ parser.add_argument("--num_workers", type=int, default=int(cpu_count()))
+ args = parser.parse_args()
+ cfg = load_config(args.config)
+ preprocess(cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/bins/sgmse/train_sgmse.py b/bins/sgmse/train_sgmse.py
new file mode 100644
index 00000000..11a7a004
--- /dev/null
+++ b/bins/sgmse/train_sgmse.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+import torch
+from models.sgmse.dereverberation.dereverberation_Trainer import DereverberationTrainer
+
+from utils.util import load_config
+
+
+def build_trainer(args, cfg):
+ supported_trainer = {
+ "dereverberation": DereverberationTrainer,
+ }
+
+ trainer_class = supported_trainer[cfg.model_type]
+ trainer = trainer_class(args, cfg)
+ return trainer
+
+
+def cuda_relevant(deterministic=False):
+ torch.cuda.empty_cache()
+ # TF32 on Ampere and above
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.enabled = True
+ torch.backends.cudnn.allow_tf32 = True
+ # Deterministic
+ torch.backends.cudnn.deterministic = deterministic
+ torch.backends.cudnn.benchmark = not deterministic
+ torch.use_deterministic_algorithms(deterministic)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--config",
+ default="config.json",
+ help="json files for configurations.",
+ required=True,
+ )
+ parser.add_argument(
+ "--num_workers", type=int, default=4, help="Number of dataloader workers."
+ )
+ parser.add_argument(
+ "--exp_name",
+ type=str,
+ default="exp_name",
+ help="A specific name to note the experiment",
+ required=True,
+ )
+ parser.add_argument(
+ "--log_level", default="warning", help="logging level (debug, info, warning)"
+ )
+ parser.add_argument("--stdout_interval", default=5, type=int)
+ parser.add_argument("--local_rank", default=0, type=int)
+ args = parser.parse_args()
+ cfg = load_config(args.config)
+ cfg.exp_name = args.exp_name
+ args.log_dir = os.path.join(cfg.log_dir, args.exp_name)
+ os.makedirs(args.log_dir, exist_ok=True)
+ # Data Augmentation
+ if cfg.preprocess.data_augment:
+ new_datasets_list = []
+ for dataset in cfg.preprocess.data_augment:
+ new_datasets = [
+ # f"{dataset}_pitch_shift",
+ # f"{dataset}_formant_shift",
+ f"{dataset}_equalizer",
+ f"{dataset}_time_stretch",
+ ]
+ new_datasets_list.extend(new_datasets)
+ cfg.dataset.extend(new_datasets_list)
+
+ # CUDA settings
+ cuda_relevant()
+
+ # Build trainer
+ trainer = build_trainer(args, cfg)
+
+ trainer.train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/config/sgmse.json b/config/sgmse.json
new file mode 100644
index 00000000..000eb3ab
--- /dev/null
+++ b/config/sgmse.json
@@ -0,0 +1,42 @@
+{
+ "base_config": "config/base.json",
+ "dataset": [
+ "wsj0reverb"
+ ],
+ "task_type": "sgmse",
+ "preprocess": {
+ "dummy": false,
+ "num_frames":256,
+ "normalize": "noisy",
+ "hop_length": 128,
+ "n_fft": 510,
+ "spec_abs_exponent": 0.5,
+ "spec_factor": 0.15,
+ "use_spkid": false,
+ "use_uv": false,
+ "use_frame_pitch": false,
+ "use_phone_pitch": false,
+ "use_frame_energy": false,
+ "use_phone_energy": false,
+ "use_mel": false,
+ "use_audio": false,
+ "use_label": false,
+ "use_one_hot": false
+ },
+ "model": {
+ "sgmse": {
+ "backbone": "ncsnpp",
+ "sde": "ouve",
+
+ "gpus": 1
+ }
+ },
+ "train": {
+ "batch_size": 8,
+ "lr": 1e-4,
+ "ema_decay": 0.999,
+ "t_eps": 3e-2,
+ "num_eval_files": 20
+ }
+
+}
\ No newline at end of file
diff --git a/egs/sgmse/README.md b/egs/sgmse/README.md
new file mode 100644
index 00000000..4765f4a1
--- /dev/null
+++ b/egs/sgmse/README.md
@@ -0,0 +1,98 @@
+# Amphion Speech Enhancement and Dereverberation with Diffusion-based Generative Models Recipe
+
+This recipe contains the PyTorch implementation of the following 2023 paper, adapted from the official [sgmse](https://github.com/sp-uhh/sgmse) repository:
+- Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, Timo Gerkmann. [*"Speech Enhancement and Dereverberation with Diffusion-Based Generative Models"*](https://ieeexplore.ieee.org/abstract/document/10149431), IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023.
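+
+In brief, the model learns the score of a forward diffusion that pulls the clean spectrogram $x_0$ toward the corrupted (e.g., reverberant) one $y$. As a rough sketch of the forward SDE used in this recipe (the `ouve` SDE, whose `theta`, `sigma_min`, and `sigma_max` parameters appear in `exp_config.json`, following the paper):
+
+```math
+\mathrm{d}x_t = \theta\,(y - x_t)\,\mathrm{d}t + \sigma_{\min}\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)^{t}\sqrt{2\ln\frac{\sigma_{\max}}{\sigma_{\min}}}\,\mathrm{d}w
+```
+
+Inference runs the learned reverse SDE, starting from the corrupted spectrogram, to recover a clean estimate.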
+
+
+You can use any sgmse architecture with any dataset you want. There are three steps in total:
+
+1. Data preparation
+2. Training
+3. Inference
+
+
+> **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
+> ```bash
+> cd Amphion
+> ```
+
+## 1. Data Preparation
+
+You can train the model with any dataset. Amphion's supported open-source datasets are detailed [here](../../datasets/README.md).
+
+### Configuration
+
+Specify the dataset path in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.
+
+```json
+"dataset": [
+ "wsj0reverb"
+ ],
+ "dataset_path": {
+ // TODO: Fill in your dataset path
+ "wsj0reverb": ""
+ },
+"preprocess": {
+ "processed_dir": "",
+ "sample_rate": 16000
+ },
+```
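+
+After preprocessing (stage 1), the dataset class in `models/sgmse/dereverberation/dereverberation_dataset.py` reads paired clean/reverberant audio from the processed directory. Based on its glob patterns, the expected layout is:
+
+```
+[processed_dir]/wsj0reverb/audio/
+    train/anechoic/*.wav
+    train/reverb/*.wav
+    valid/anechoic/*.wav
+    valid/reverb/*.wav
+```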
+
+## 2. Training
+
+### Configuration
+
+We provide the default hyperparameters in `exp_config.json`. They work on a single NVIDIA 24GB GPU. You can adjust them based on your GPU machines.
+
+```json
+ "train": {
+ // TODO: Fill in your checkpoint path
+ "checkpoint": "",
+ "adam": {
+ "lr": 1e-4
+ },
+ "ddp": false,
+ "batch_size": 8,
+ "epochs": 200000,
+ "save_checkpoints_steps": 800,
+ "save_summary_steps": 1000,
+ "max_steps": 1000000,
+ "ema_decay": 0.999,
+ "valid_interval": 800,
+ "t_eps": 3e-2,
+ "num_eval_files": 20
+
+ }
+}
+```
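+
+A few of the less common knobs: `t_eps` is the minimum diffusion time sampled during training (it keeps the process away from the singularity at `t = 0`), `ema_decay` sets the exponential moving average kept over the model weights (the EMA weights are used for evaluation), and `num_eval_files` is the number of validation files on which PESQ, SI-SDR, and ESTOI are computed.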
+
+### Run
+
+Run `run.sh` with the training stage specified (set `--stage 2`).
+
+```bash
+sh egs/sgmse/dereverberation/run.sh --stage 2
+```
+
+> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh`, e.g., `--gpu "0,1,2,3"`.
+
+## 3. Inference
+
+### Run
+
+Run `run.sh` with the inference stage specified (set `--stage 3`).
+
+```bash
+sh egs/sgmse/dereverberation/run.sh --stage 3 \
+    --checkpoint_path [your path] \
+    --test_dir [your path] \
+    --output_dir [your path]
+```
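+
+Note that `--test_dir` must contain a `noisy/` subdirectory holding the `.wav` files to enhance; the enhanced files are written to `--output_dir` at 16 kHz.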
+
diff --git a/egs/sgmse/dereverberation/exp_config.json b/egs/sgmse/dereverberation/exp_config.json
new file mode 100644
index 00000000..cf5c4872
--- /dev/null
+++ b/egs/sgmse/dereverberation/exp_config.json
@@ -0,0 +1,69 @@
+{
+ "base_config": "config/sgmse.json",
+ "model_type": "dereverberation",
+ "dataset": [
+ "wsj0reverb"
+ ],
+ "dataset_path": {
+ // TODO: Fill in your dataset path
+ "wsj0reverb": ""
+ },
+ "log_dir": "",
+ "preprocess": {
+ "processed_dir": "",
+ "sample_rate": 16000
+ },
+ "model": {
+ "sgmse": {
+ "backbone": "ncsnpp",
+ "sde": "ouve",
+ "ncsnpp": {
+ "scale_by_sigma": true,
+ "nonlinearity": "swish",
+ "nf": 128,
+ "ch_mult": [1, 1, 2, 2, 2, 2, 2],
+ "num_res_blocks": 2,
+ "attn_resolutions": [16],
+ "resamp_with_conv": true,
+ "conditional": true,
+ "fir": true,
+ "fir_kernel": [1, 3, 3, 1],
+ "skip_rescale": true,
+ "resblock_type": "biggan",
+ "progressive": "output_skip",
+ "progressive_input": "input_skip",
+ "progressive_combine": "sum",
+ "init_scale": 0.0,
+ "fourier_scale": 16,
+ "image_size": 256,
+ "embedding_type": "fourier",
+ "dropout": 0.0,
+ "centered": true
+ },
+ "ouve": {
+ "theta": 1.5,
+ "sigma_min": 0.05,
+ "sigma_max": 0.5,
+ "N": 1000
+ },
+ "gpus": 1
+ }
+ },
+ "train": {
+ "checkpoint": "",
+ "adam": {
+ "lr": 1e-4
+ },
+ "ddp": false,
+ "batch_size": 8,
+ "epochs": 200000,
+ "save_checkpoints_steps": 800,
+ "save_summary_steps": 1000,
+ "max_steps": 1000000,
+ "ema_decay": 0.999,
+ "valid_interval": 800,
+ "t_eps": 3e-2,
+ "num_eval_files": 20
+
+ }
+}
diff --git a/egs/sgmse/dereverberation/run.sh b/egs/sgmse/dereverberation/run.sh
new file mode 100644
index 00000000..b1aa00f4
--- /dev/null
+++ b/egs/sgmse/dereverberation/run.sh
@@ -0,0 +1,97 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+######## Build Experiment Environment ###########
+exp_dir=$(cd `dirname $0`; pwd)
+work_dir=$(dirname $(dirname $(dirname $exp_dir)))
+
+
+export WORK_DIR=$work_dir
+export PYTHONPATH=$work_dir
+export PYTHONIOENCODING=UTF-8
+export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:100"
+
+######## Parse the Given Parameters from the Command ###########
+options=$(getopt -o c:n:s: --long gpu:,config:,name:,stage:,checkpoint:,resume_type:,main_process_port:,infer_mode:,infer_datasets:,infer_feature_dir:,infer_audio_dir:,infer_expt_dir:,infer_output_dir:,checkpoint_path:,test_dir:,output_dir: -- "$@")
+eval set -- "$options"
+export CUDA_VISIBLE_DEVICES="0"
+while true; do
+ case $1 in
+ # Experimental Configuration File
+ -c | --config) shift; exp_config=$1 ; shift ;;
+ # Experimental Name
+ -n | --name) shift; exp_name=$1 ; shift ;;
+ # Running Stage
+ -s | --stage) shift; running_stage=$1 ; shift ;;
+ # Visible GPU machines. The default value is "0".
+ --gpu) shift; gpu=$1 ; shift ;;
+ --checkpoint_path) shift; checkpoint_path=$1 ; shift ;;
+ --test_dir) shift; test_dir=$1 ; shift ;;
+ --output_dir) shift; output_dir=$1 ; shift ;;
+ # [Only for Training] The specific checkpoint path that you want to resume from.
+ --checkpoint) shift; checkpoint=$1 ; shift ;;
+    # [Only for Training] `main_process_port` for multi gpu training
+ --main_process_port) shift; main_process_port=$1 ; shift ;;
+
+ # [Only for Inference] The inference mode
+ --infer_mode) shift; infer_mode=$1 ; shift ;;
+ # [Only for Inference] The inferenced datasets
+ --infer_datasets) shift; infer_datasets=$1 ; shift ;;
+ # [Only for Inference] The feature dir for inference
+ --infer_feature_dir) shift; infer_feature_dir=$1 ; shift ;;
+ # [Only for Inference] The audio dir for inference
+ --infer_audio_dir) shift; infer_audio_dir=$1 ; shift ;;
+ # [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
+ --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
+ # [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
+ --infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
+
+ --) shift ; break ;;
+        *) echo "Invalid option: $1"; exit 1 ;;
+ esac
+done
+
+
+### Value check ###
+if [ -z "$running_stage" ]; then
+ echo "[Error] Please specify the running stage"
+ exit 1
+fi
+
+if [ -z "$exp_config" ]; then
+ exp_config="${exp_dir}"/exp_config.json
+fi
+echo "Exprimental Configuration File: $exp_config"
+
+if [ -z "$gpu" ]; then
+ gpu="0"
+fi
+
+if [ -z "$main_process_port" ]; then
+ main_process_port=29500
+fi
+echo "Main Process Port: $main_process_port"
+
+######## Features Extraction ###########
+if [ $running_stage -eq 1 ]; then
+ CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/sgmse/preprocess.py \
+ --config $exp_config \
+ --num_workers 8
+fi
+######## Training ###########
+if [ $running_stage -eq 2 ]; then
+ CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/sgmse/train_sgmse.py \
+ --config "$exp_config" \
+ --exp_name "$exp_name" \
+ --log_level info
+fi
+
+if [ $running_stage -eq 3 ]; then
+ CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/sgmse/inference.py \
+ --config=$exp_config \
+ --checkpoint_path=$checkpoint_path \
+ --test_dir="$test_dir" \
+ --output_dir=$output_dir
+fi
\ No newline at end of file
diff --git a/env.sh b/env.sh
index 10ef7ff1..1722c4ae 100644
--- a/env.sh
+++ b/env.sh
@@ -28,5 +28,6 @@ pip install phonemizer==3.2.1 pypinyin==0.48.0
pip install black==24.1.1
+pip install torch-ema ninja
# Uninstall nvidia-cublas-cu11 if there exist some bugs about CUDA version
# pip uninstall nvidia-cublas-cu11
diff --git a/imgs/sgmse/diffusion_process.png b/imgs/sgmse/diffusion_process.png
new file mode 100644
index 00000000..6f8a6db0
Binary files /dev/null and b/imgs/sgmse/diffusion_process.png differ
diff --git a/models/base/base_trainer.py b/models/base/base_trainer.py
index 8782216d..5c18274c 100644
--- a/models/base/base_trainer.py
+++ b/models/base/base_trainer.py
@@ -78,9 +78,11 @@ def __init__(self, args, cfg):
self.criterion = self.build_criterion()
if isinstance(self.criterion, dict):
for key, value in self.criterion.items():
- self.criterion[key].cuda(args.local_rank)
+ if not callable(value):
+ self.criterion[key].cuda(args.local_rank)
else:
- self.criterion.cuda(self.args.local_rank)
+ if not callable(self.criterion):
+ self.criterion.cuda(self.args.local_rank)
# optimizer
self.optimizer = self.build_optimizer()
diff --git a/models/sgmse/dereverberation/__init__.py b/models/sgmse/dereverberation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/models/sgmse/dereverberation/dereverberation.py b/models/sgmse/dereverberation/dereverberation.py
new file mode 100644
index 00000000..4d607096
--- /dev/null
+++ b/models/sgmse/dereverberation/dereverberation.py
@@ -0,0 +1,26 @@
+import torch
+from torch import nn
+
+from modules.sgmse.sdes import SDERegistry
+from modules.sgmse.shared import BackboneRegistry
+
+
+class ScoreModel(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ # Initialize Backbone DNN
+ dnn_cls = BackboneRegistry.get_by_name(cfg.backbone)
+ dnn_cfg = cfg[cfg.backbone]
+ self.dnn = dnn_cls(**dnn_cfg)
+ # Initialize SDE
+ sde_cls = SDERegistry.get_by_name(cfg.sde)
+ sde_cfg = cfg[cfg.sde]
+ self.sde = sde_cls(**sde_cfg)
+
+ def forward(self, x, t, y):
+ # Concatenate y as an extra channel
+ dnn_input = torch.cat([x, y], dim=1)
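+        # The DNN output is negated so that it can be read as the score
+        # (the gradient of the log-density), following the original sgmse code.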
+ score = -self.dnn(dnn_input, t)
+ return score
diff --git a/models/sgmse/dereverberation/dereverberation_Trainer.py b/models/sgmse/dereverberation/dereverberation_Trainer.py
new file mode 100644
index 00000000..ccf97263
--- /dev/null
+++ b/models/sgmse/dereverberation/dereverberation_Trainer.py
@@ -0,0 +1,214 @@
+import os
+import warnings
+from math import ceil
+
+import torch
+from torch.utils.data import DataLoader
+from torch_ema import ExponentialMovingAverage
+
+from models.base.base_trainer import BaseTrainer
+from models.sgmse.dereverberation.dereverberation import ScoreModel
+from models.sgmse.dereverberation.dereverberation_dataset import Specs
+from modules.sgmse import sampling
+from utils.sgmse_util.inference import evaluate_model
+
+
+class DereverberationTrainer(BaseTrainer):
+ def __init__(self, args, cfg):
+ BaseTrainer.__init__(self, args, cfg)
+ self.cfg = cfg
+ self.save_config_file()
+ self.ema = ExponentialMovingAverage(
+ self.model.parameters(), decay=self.cfg.train.ema_decay
+ )
+ self._error_loading_ema = False
+ self.t_eps = self.cfg.train.t_eps
+ self.num_eval_files = self.cfg.train.num_eval_files
+        self.data_loader = self.build_data_loader()
+
+ checkpoint = self.load_checkpoint()
+ if checkpoint:
+ self.load_model(checkpoint)
+
+ def build_dataset(self):
+ return Specs
+
+ def load_checkpoint(self):
+ model_path = self.cfg.train.checkpoint
+ if not model_path or not os.path.exists(model_path):
+ self.logger.info("No checkpoint to load or checkpoint path does not exist.")
+ return None
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
+ self.logger.info(f"Re(store) from {model_path}")
+ checkpoint = torch.load(model_path, map_location="cpu")
+ if "ema" in checkpoint:
+            try:
+                self.ema.load_state_dict(checkpoint["ema"])
+            except Exception:
+                self._error_loading_ema = True
+                warnings.warn("EMA state_dict could not be loaded from the checkpoint!")
+ return checkpoint
+
+ def build_data_loader(self):
+ Dataset = self.build_dataset()
+ train_set = Dataset(self.cfg, subset="train", shuffle_spec=True)
+ train_loader = DataLoader(
+ train_set,
+ batch_size=self.cfg.train.batch_size,
+ num_workers=self.args.num_workers,
+ pin_memory=False,
+ shuffle=True,
+ )
+ self.valid_set = Dataset(self.cfg, subset="valid", shuffle_spec=False)
+ valid_loader = DataLoader(
+ self.valid_set,
+ batch_size=self.cfg.train.batch_size,
+ num_workers=self.args.num_workers,
+ pin_memory=False,
+ shuffle=False,
+ )
+ data_loader = {"train": train_loader, "valid": valid_loader}
+ return data_loader
+
+ def build_optimizer(self):
+ optimizer = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam)
+ return optimizer
+
+ def build_scheduler(self):
+ return None
+ # return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau)
+
+ def build_singers_lut(self):
+ return None
+
+ def write_summary(self, losses, stats):
+ for key, value in losses.items():
+ self.sw.add_scalar(key, value, self.step)
+
+ def write_valid_summary(self, losses, stats):
+ for key, value in losses.items():
+ self.sw.add_scalar(key, value, self.step)
+
+ def _loss(self, err):
+ losses = torch.square(err.abs())
+ loss = torch.mean(0.5 * torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
+ return loss
+
+ def build_criterion(self):
+ return self._loss
+
+ def get_state_dict(self):
+ state_dict = {
+ "model": self.model.state_dict(),
+ "optimizer": self.optimizer.state_dict(),
+ "step": self.step,
+ "epoch": self.epoch,
+ "batch_size": self.cfg.train.batch_size,
+ "ema": self.ema.state_dict(),
+ }
+ if self.scheduler is not None:
+ state_dict["scheduler"] = self.scheduler.state_dict()
+ return state_dict
+
+ def load_model(self, checkpoint):
+ self.step = checkpoint["step"]
+ self.epoch = checkpoint["epoch"]
+
+ self.model.load_state_dict(checkpoint["model"])
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ if "scheduler" in checkpoint and self.scheduler is not None:
+ self.scheduler.load_state_dict(checkpoint["scheduler"])
+ if "ema" in checkpoint:
+ self.ema.load_state_dict(checkpoint["ema"])
+
+ def build_model(self):
+ self.model = ScoreModel(self.cfg.model.sgmse)
+ return self.model
+
+ def get_pc_sampler(
+ self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs
+ ):
+ N = self.model.sde.N if N is None else N
+ sde = self.model.sde.copy()
+ sde.N = N
+
+ kwargs = {"eps": self.t_eps, **kwargs}
+ if minibatch is None:
+ return sampling.get_pc_sampler(
+ predictor_name,
+ corrector_name,
+ sde=sde,
+ score_fn=self.model,
+ y=y,
+ **kwargs,
+ )
+ else:
+ M = y.shape[0]
+
+ def batched_sampling_fn():
+ samples, ns = [], []
+ for i in range(int(ceil(M / minibatch))):
+ y_mini = y[i * minibatch : (i + 1) * minibatch]
+ sampler = sampling.get_pc_sampler(
+ predictor_name,
+ corrector_name,
+ sde=sde,
+ score_fn=self.model,
+ y=y_mini,
+ **kwargs,
+ )
+ sample, n = sampler()
+ samples.append(sample)
+ ns.append(n)
+ samples = torch.cat(samples, dim=0)
+ return samples, ns
+
+ return batched_sampling_fn
+
+ def _step(self, batch):
+ x = batch["X"]
+ y = batch["Y"]
+
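+        # Sample a diffusion time t uniformly in [t_eps, sde.T]; t_eps keeps t
+        # away from the singularity at t = 0.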
+ t = (
+ torch.rand(x.shape[0], device=x.device) * (self.model.sde.T - self.t_eps)
+ + self.t_eps
+ )
+ mean, std = self.model.sde.marginal_prob(x, t, y)
+
+ z = torch.randn_like(x)
+ sigmas = std[:, None, None, None]
+ perturbed_data = mean + sigmas * z
+ score = self.model(perturbed_data, t, y)
+
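+        # Denoising score matching: at the optimum the predicted score
+        # approximates -z / sigma, so err tends to zero.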
+ err = score * sigmas + z
+ loss = self.criterion(err)
+ return loss
+
+ def train_step(self, batch):
+ loss = self._step(batch)
+
+ # Backward pass and optimization
+ self.optimizer.zero_grad() # reset gradient
+ loss.backward()
+ self.optimizer.step()
+
+ # Update the EMA of the model parameters
+ self.ema.update(self.model.parameters())
+
+ self.write_summary({"train_loss": loss.item()}, {})
+ return {"train_loss": loss.item()}, {}, loss.item()
+
+ def eval_step(self, batch, batch_idx):
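+        # Evaluate with EMA weights: stash the current parameters, copy the EMA
+        # parameters in, and restore the originals after evaluation.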
+ self.ema.store(self.model.parameters())
+ self.ema.copy_to(self.model.parameters())
+ loss = self._step(batch)
+ self.write_valid_summary({"valid_loss": loss.item()}, {})
+ if batch_idx == 0 and self.num_eval_files != 0:
+ pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files)
+ self.write_valid_summary(
+ {"pesq": pesq, "si_sdr": si_sdr, "estoi": estoi}, {}
+ )
+ print(f" pesq={pesq}, si_sdr={si_sdr}, estoi={estoi}")
+ if self.ema.collected_params is not None:
+ self.ema.restore(self.model.parameters())
+ return {"valid_loss": loss.item()}, {}, loss.item()
diff --git a/models/sgmse/dereverberation/dereverberation_dataset.py b/models/sgmse/dereverberation/dereverberation_dataset.py
new file mode 100644
index 00000000..00b39262
--- /dev/null
+++ b/models/sgmse/dereverberation/dereverberation_dataset.py
@@ -0,0 +1,114 @@
+import torch
+from glob import glob
+from torchaudio import load
+import numpy as np
+import torch.nn.functional as F
+import os
+from os.path import join
+
+
+class Specs:
+ def __init__(self, cfg, subset, shuffle_spec):
+ self.cfg = cfg
+ self.data_dir = os.path.join(
+ cfg.preprocess.processed_dir, cfg.dataset[0], "audio"
+ )
+ self.clean_files = sorted(glob(join(self.data_dir, subset) + "/anechoic/*.wav"))
+ self.noisy_files = sorted(glob(join(self.data_dir, subset) + "/reverb/*.wav"))
+ self.dummy = cfg.preprocess.dummy
+ self.num_frames = cfg.preprocess.num_frames
+ self.shuffle_spec = shuffle_spec
+ self.normalize = cfg.preprocess.normalize
+ self.hop_length = cfg.preprocess.hop_length
+ self.n_fft = cfg.preprocess.n_fft
+ self.window = self.get_window(self.n_fft)
+ self.windows = {}
+ self.spec_abs_exponent = cfg.preprocess.spec_abs_exponent
+ self.spec_factor = cfg.preprocess.spec_factor
+
+ def __getitem__(self, i):
+ x, _ = load(self.clean_files[i])
+ y, _ = load(self.noisy_files[i])
+
+ # formula applies for center=True
+ target_len = (self.num_frames - 1) * self.hop_length
+ current_len = x.size(-1)
+ pad = max(target_len - current_len, 0)
+ if pad == 0:
+ # extract random part of the audio file
+ if self.shuffle_spec:
+ start = int(np.random.uniform(0, current_len - target_len))
+ else:
+ start = int((current_len - target_len) / 2)
+ x = x[..., start : start + target_len]
+ y = y[..., start : start + target_len]
+ else:
+ # pad audio if the length T is smaller than num_frames
+ x = F.pad(x, (pad // 2, pad // 2 + (pad % 2)), mode="constant")
+ y = F.pad(y, (pad // 2, pad // 2 + (pad % 2)), mode="constant")
+
+        # normalize w.r.t. the noisy signal, the clean signal, or not at all,
+        # to ensure the same clean signal power in x and y.
+        if self.normalize == "noisy":
+            normfac = y.abs().max()
+        elif self.normalize == "clean":
+            normfac = x.abs().max()
+        elif self.normalize == "not":
+            normfac = 1.0
+        else:
+            raise ValueError(f"Unknown normalize mode: {self.normalize}")
+ x = x / normfac
+ y = y / normfac
+
+ X = torch.stft(x, **self.stft_kwargs())
+ Y = torch.stft(y, **self.stft_kwargs())
+ X, Y = self.spec_transform(X), self.spec_transform(Y)
+ return {"X": X, "Y": Y}
+
+ def __len__(self):
+ if self.dummy:
+ return int(len(self.clean_files) / 200)
+ else:
+ return len(self.clean_files)
+
+ def spec_transform(self, spec):
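+        # Amplitude compression: raise the magnitude to spec_abs_exponent while
+        # keeping the phase, then scale by spec_factor (as in the original sgmse repo).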
+ if self.spec_abs_exponent != 1:
+ e = self.spec_abs_exponent
+ spec = spec.abs() ** e * torch.exp(1j * spec.angle())
+ spec = spec * self.spec_factor
+
+ return spec
+
+ def stft_kwargs(self):
+ return {**self.istft_kwargs(), "return_complex": True}
+
+ def istft_kwargs(self):
+ return dict(
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=self.window,
+ center=True,
+ )
+
+ def stft(self, sig):
+ window = self._get_window(sig)
+ return torch.stft(sig, **{**self.stft_kwargs(), "window": window})
+
+ def istft(self, spec, length=None):
+ window = self._get_window(spec)
+ return torch.istft(
+ spec, **{**self.istft_kwargs(), "window": window, "length": length}
+ )
+
+ @staticmethod
+ def get_window(window_length):
+ return torch.hann_window(window_length, periodic=True)
+
+ def _get_window(self, x):
+ """
+ Retrieve an appropriate window for the given tensor x, matching the device.
+ Caches the retrieved windows so that only one window tensor will be allocated per device.
+ """
+ window = self.windows.get(x.device, None)
+ if window is None:
+ window = self.window.to(x.device)
+ self.windows[x.device] = window
+ return window
diff --git a/models/sgmse/dereverberation/dereverberation_inference.py b/models/sgmse/dereverberation/dereverberation_inference.py
new file mode 100644
index 00000000..c3c847c1
--- /dev/null
+++ b/models/sgmse/dereverberation/dereverberation_inference.py
@@ -0,0 +1,81 @@
+import glob
+from os.path import join
+
+import torch
+from tqdm import tqdm
+from torchaudio import load
+from soundfile import write
+
+from models.sgmse.dereverberation.dereverberation import ScoreModel
+from models.sgmse.dereverberation.dereverberation_dataset import Specs
+from models.sgmse.dereverberation.dereverberation_Trainer import DereverberationTrainer
+from utils.sgmse_util.other import ensure_dir, pad_spec
+
+
+class DereverberationInference:
+ def __init__(self, args, cfg):
+ self.cfg = cfg
+ self.t_eps = self.cfg.train.t_eps
+ self.args = args
+ self.test_dir = args.test_dir
+        self.target_dir = self.args.output_dir
+        ensure_dir(self.target_dir)  # make sure the output directory exists
+ self.model = self.build_model()
+ self.load_state_dict()
+
+ def build_model(self):
+ self.model = ScoreModel(self.cfg.model.sgmse)
+ return self.model
+
+ def load_state_dict(self):
+ self.checkpoint_path = self.args.checkpoint_path
+ checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
+ self.model.load_state_dict(checkpoint["model"])
+ self.model.cuda(self.args.local_rank)
+
+ def inference(self):
+ sr = 16000
+ snr = self.args.snr
+ N = self.args.N
+ corrector_steps = self.args.corrector_steps
+ self.model.eval()
+ noisy_dir = join(self.test_dir, "noisy/")
+ noisy_files = sorted(glob.glob("{}/*.wav".format(noisy_dir)))
+ for noisy_file in tqdm(noisy_files):
+ filename = noisy_file.split("/")[-1]
+
+ # Load wav
+ y, _ = load(noisy_file)
+ T_orig = y.size(1)
+
+ # Normalize
+ norm_factor = y.abs().max()
+ y = y / norm_factor
+
+ # Prepare DNN input
+ spec = Specs(self.cfg, subset="", shuffle_spec=False)
+ Y = torch.unsqueeze(spec.spec_transform(spec.stft(sig=y.cuda())), 0)
+ Y = pad_spec(Y)
+
+ # Reverse sampling
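+            # Reuse the trainer's predictor-corrector sampler factory; it only
+            # relies on self.model and self.t_eps, which this class also defines.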
+ sampler = DereverberationTrainer.get_pc_sampler(
+ self,
+ "reverse_diffusion",
+ "ald",
+ Y.cuda(),
+ N=N,
+ corrector_steps=corrector_steps,
+ snr=snr,
+ )
+ sample, _ = sampler()
+
+ # Backward transform in time domain
+ x_hat = spec.istft(sample.squeeze(), T_orig)
+
+ # Renormalize
+ x_hat = x_hat * norm_factor
+
+ # Write enhanced wav file
+            write(join(self.target_dir, filename), x_hat.cpu().numpy(), sr)
diff --git a/modules/sgmse/__init__.py b/modules/sgmse/__init__.py
new file mode 100644
index 00000000..ff6c52ef
--- /dev/null
+++ b/modules/sgmse/__init__.py
@@ -0,0 +1,5 @@
+from .shared import BackboneRegistry
+from .ncsnpp import NCSNpp
+from .dcunet import DCUNet
+
+__all__ = ["BackboneRegistry", "NCSNpp", "DCUNet"]
diff --git a/modules/sgmse/dcunet.py b/modules/sgmse/dcunet.py
new file mode 100644
index 00000000..9815a76e
--- /dev/null
+++ b/modules/sgmse/dcunet.py
@@ -0,0 +1,765 @@
+from functools import partial
+import numpy as np
+
+import torch
+from torch import nn, Tensor
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from .shared import (
+ BackboneRegistry,
+ ComplexConv2d,
+ ComplexConvTranspose2d,
+ ComplexLinear,
+ DiffusionStepEmbedding,
+ GaussianFourierProjection,
+ FeatureMapDense,
+ torch_complex_from_reim,
+)
+
+
+def get_activation(name):
+ if name == "silu":
+ return nn.SiLU
+ elif name == "relu":
+ return nn.ReLU
+ elif name == "leaky_relu":
+ return nn.LeakyReLU
+ else:
+ raise NotImplementedError(f"Unknown activation: {name}")
+
+
+class BatchNorm(_BatchNorm):
+ def _check_input_dim(self, input):
+        if input.dim() < 2 or input.dim() > 4:
+            raise ValueError(
+                "expected 2D to 4D input (got {}D input)".format(input.dim())
+            )
+
+
+class OnReIm(nn.Module):
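+    """Apply two separate real-valued module instances to the real and imaginary
+    parts of a complex input, recombining them into a complex output."""
+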
+ def __init__(self, module_cls, *args, **kwargs):
+ super().__init__()
+ self.re_module = module_cls(*args, **kwargs)
+ self.im_module = module_cls(*args, **kwargs)
+
+ def forward(self, x):
+ return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag))
+
+
+# Code for DCUNet largely copied from Danilo's `informedenh` repo, cheers!
+
+
+def unet_decoder_args(encoders, *, skip_connections):
+ """Get list of decoder arguments for upsampling (right) side of a symmetric u-net,
+ given the arguments used to construct the encoder.
+ Args:
+ encoders (tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding)):
+ List of arguments used to construct the encoders
+ skip_connections (bool): Whether to include skip connections in the
+ calculation of decoder input channels.
+ Return:
+ tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding):
+ Arguments to be used to construct decoders
+ """
+ decoder_args = []
+ for (
+ enc_in_chan,
+ enc_out_chan,
+ enc_kernel_size,
+ enc_stride,
+ enc_padding,
+ enc_dilation,
+ ) in reversed(encoders):
+ if skip_connections and decoder_args:
+ skip_in_chan = enc_out_chan
+ else:
+ skip_in_chan = 0
+ decoder_args.append(
+ (
+ enc_out_chan + skip_in_chan,
+ enc_in_chan,
+ enc_kernel_size,
+ enc_stride,
+ enc_padding,
+ enc_dilation,
+ )
+ )
+ return tuple(decoder_args)
+
+
+def make_unet_encoder_decoder_args(encoder_args, decoder_args):
+ encoder_args = tuple(
+ (
+ in_chan,
+ out_chan,
+ tuple(kernel_size),
+ tuple(stride),
+ (
+ tuple([n // 2 for n in kernel_size])
+ if padding == "auto"
+ else tuple(padding)
+ ),
+ tuple(dilation),
+ )
+ for in_chan, out_chan, kernel_size, stride, padding, dilation in encoder_args
+ )
+
+ if decoder_args == "auto":
+ decoder_args = unet_decoder_args(
+ encoder_args,
+ skip_connections=True,
+ )
+ else:
+ decoder_args = tuple(
+ (
+ in_chan,
+ out_chan,
+ tuple(kernel_size),
+ tuple(stride),
+ tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding,
+ tuple(dilation),
+ output_padding,
+ )
+ for in_chan, out_chan, kernel_size, stride, padding, dilation, output_padding in decoder_args
+ )
+
+ return encoder_args, decoder_args
+
+
+DCUNET_ARCHITECTURES = {
+ "DCUNet-10": make_unet_encoder_decoder_args(
+ # Encoders:
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+ (
+ (1, 32, (7, 5), (2, 2), "auto", (1, 1)),
+ (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+ ),
+ # Decoders: automatic inverse
+ "auto",
+ ),
+ "DCUNet-16": make_unet_encoder_decoder_args(
+ # Encoders:
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+ (
+ (1, 32, (7, 5), (2, 2), "auto", (1, 1)),
+ (32, 32, (7, 5), (2, 1), "auto", (1, 1)),
+ (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+ ),
+ # Decoders: automatic inverse
+ "auto",
+ ),
+ "DCUNet-20": make_unet_encoder_decoder_args(
+ # Encoders:
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+ (
+ (1, 32, (7, 1), (1, 1), "auto", (1, 1)),
+ (32, 32, (1, 7), (1, 1), "auto", (1, 1)),
+ (32, 64, (7, 5), (2, 2), "auto", (1, 1)),
+ (64, 64, (7, 5), (2, 1), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 1), "auto", (1, 1)),
+ (64, 64, (5, 3), (2, 2), "auto", (1, 1)),
+ (64, 90, (5, 3), (2, 1), "auto", (1, 1)),
+ ),
+ # Decoders: automatic inverse
+ "auto",
+ ),
+ "DilDCUNet-v2": make_unet_encoder_decoder_args( # architecture used in SGMSE / Interspeech paper
+ # Encoders:
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
+ (
+ (1, 32, (4, 4), (1, 1), "auto", (1, 1)),
+ (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
+ (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
+ (32, 64, (4, 4), (2, 1), "auto", (2, 1)),
+ (64, 128, (4, 4), (2, 2), "auto", (4, 1)),
+ (128, 256, (4, 4), (2, 2), "auto", (8, 1)),
+ ),
+ # Decoders: automatic inverse
+ "auto",
+ ),
+}
+
+
+@BackboneRegistry.register("dcunet")
+class DCUNet(nn.Module):
+ @staticmethod
+ def add_argparse_args(parser):
+ parser.add_argument(
+ "--dcunet-architecture",
+ type=str,
+ default="DilDCUNet-v2",
+ choices=DCUNET_ARCHITECTURES.keys(),
+ help="The concrete DCUNet architecture. 'DilDCUNet-v2' by default.",
+ )
+ parser.add_argument(
+ "--dcunet-time-embedding",
+ type=str,
+ choices=("gfp", "ds", "none"),
+ default="gfp",
+ help="Timestep embedding style. 'gfp' (Gaussian Fourier Projections) by default.",
+ )
+ parser.add_argument(
+ "--dcunet-temb-layers-global",
+ type=int,
+ default=1,
+ help="Number of global linear+activation layers for the time embedding. 1 by default.",
+ )
+ parser.add_argument(
+ "--dcunet-temb-layers-local",
+ type=int,
+ default=1,
+ help="Number of local (per-encoder/per-decoder) linear+activation layers for the time embedding. 1 by default.",
+ )
+ parser.add_argument(
+ "--dcunet-temb-activation",
+ type=str,
+ default="silu",
+ help="The (complex) activation to use between all (global&local) time embedding layers.",
+ )
+ parser.add_argument(
+ "--dcunet-time-embedding-complex",
+ action="store_true",
+ help="Use complex-valued timestep embedding. Compatible with 'gfp' and 'ds' embeddings.",
+ )
+ parser.add_argument(
+ "--dcunet-fix-length",
+ type=str,
+ default="pad",
+ choices=("pad", "trim", "none"),
+ help="DCUNet strategy to 'fix' mismatched input timespan. 'pad' by default.",
+ )
+ parser.add_argument(
+ "--dcunet-mask-bound",
+ type=str,
+ choices=("tanh", "sigmoid", "none"),
+ default="none",
+ help="DCUNet output bounding strategy. 'none' by default.",
+ )
+ parser.add_argument(
+ "--dcunet-norm-type",
+ type=str,
+ choices=("bN", "CbN"),
+ default="bN",
+ help="The type of norm to use within each encoder and decoder layer. 'bN' (real/imaginary separate batch norm) by default.",
+ )
+ parser.add_argument(
+ "--dcunet-activation",
+ type=str,
+ choices=("leaky_relu", "relu", "silu"),
+ default="leaky_relu",
+ help="The activation to use within each encoder and decoder layer. 'leaky_relu' by default.",
+ )
+ return parser
+
+ def __init__(
+ self,
+ dcunet_architecture: str = "DilDCUNet-v2",
+ dcunet_time_embedding: str = "gfp",
+ dcunet_temb_layers_global: int = 2,
+ dcunet_temb_layers_local: int = 1,
+ dcunet_temb_activation: str = "silu",
+ dcunet_time_embedding_complex: bool = False,
+ dcunet_fix_length: str = "pad",
+ dcunet_mask_bound: str = "none",
+ dcunet_norm_type: str = "bN",
+ dcunet_activation: str = "relu",
+ embed_dim: int = 128,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.architecture = dcunet_architecture
+ self.fix_length_mode = (
+ dcunet_fix_length if dcunet_fix_length != "none" else None
+ )
+ self.norm_type = dcunet_norm_type
+ self.activation = dcunet_activation
+ self.input_channels = 2 # for x_t and y -- note that this is 2 rather than 4, because we directly treat complex channels in this DNN
+ self.time_embedding = (
+ dcunet_time_embedding if dcunet_time_embedding != "none" else None
+ )
+ self.time_embedding_complex = dcunet_time_embedding_complex
+ self.temb_layers_global = dcunet_temb_layers_global
+ self.temb_layers_local = dcunet_temb_layers_local
+ self.temb_activation = dcunet_temb_activation
+ conf_encoders, conf_decoders = DCUNET_ARCHITECTURES[dcunet_architecture]
+
+ # Replace `input_channels` in encoders config
+ _replaced_input_channels, *rest = conf_encoders[0]
+ encoders = ((self.input_channels, *rest), *conf_encoders[1:])
+ decoders = conf_decoders
+ self.encoders_stride_product = np.prod(
+ [enc_stride for _, _, _, enc_stride, _, _ in encoders], axis=0
+ )
+
+ # Prepare kwargs for encoder and decoder (to potentially be modified before layer instantiation)
+ encoder_decoder_kwargs = dict(
+ norm_type=self.norm_type,
+ activation=self.activation,
+ temb_layers=self.temb_layers_local,
+ temb_activation=self.temb_activation,
+ )
+
+ # Instantiate (global) time embedding layer
+ embed_ops = []
+ if self.time_embedding is not None:
+ complex_valued = self.time_embedding_complex
+ if self.time_embedding == "gfp":
+ embed_ops += [
+ GaussianFourierProjection(
+ embed_dim=embed_dim, complex_valued=complex_valued
+ )
+ ]
+ encoder_decoder_kwargs["embed_dim"] = embed_dim
+ elif self.time_embedding == "ds":
+ embed_ops += [
+ DiffusionStepEmbedding(
+ embed_dim=embed_dim, complex_valued=complex_valued
+ )
+ ]
+ encoder_decoder_kwargs["embed_dim"] = embed_dim
+
+ if self.time_embedding_complex:
+ assert self.time_embedding in (
+ "gfp",
+ "ds",
+ ), "Complex timestep embedding only available for gfp and ds"
+ encoder_decoder_kwargs["complex_time_embedding"] = True
+ for _ in range(self.temb_layers_global):
+ embed_ops += [
+ ComplexLinear(embed_dim, embed_dim, complex_valued=True),
+ OnReIm(get_activation(dcunet_temb_activation)),
+ ]
+ self.embed = nn.Sequential(*embed_ops)
+
+ ### Instantiate DCUNet layers ###
+ output_layer = ComplexConvTranspose2d(*decoders[-1])
+ encoders = [
+ DCUNetComplexEncoderBlock(*args, **encoder_decoder_kwargs)
+ for args in encoders
+ ]
+ decoders = [
+ DCUNetComplexDecoderBlock(*args, **encoder_decoder_kwargs)
+ for args in decoders[:-1]
+ ]
+
+ self.mask_bound = dcunet_mask_bound if dcunet_mask_bound != "none" else None
+ if self.mask_bound is not None:
+ raise NotImplementedError(
+ "sorry, mask bounding not implemented at the moment"
+ )
+ # TODO we can't use nn.Sequential since the ComplexConvTranspose2d needs a second `output_size` argument
+ # operations = (output_layer, complex_nn.BoundComplexMask(self.mask_bound))
+ # output_layer = nn.Sequential(*[x for x in operations if x is not None])
+
+ assert len(encoders) == len(decoders) + 1
+ self.encoders = nn.ModuleList(encoders)
+ self.decoders = nn.ModuleList(decoders)
+ self.output_layer = output_layer or nn.Identity()
+
+ def forward(self, spec, t) -> Tensor:
+ """
+ Input shape is expected to be $(batch, nfreqs, time)$, with $nfreqs - 1$ divisible
+ by $f_0 * f_1 * ... * f_N$ where $f_k$ are the frequency strides of the encoders,
+ and $time - 1$ is divisible by $t_0 * t_1 * ... * t_N$ where $t_N$ are the time
+ strides of the encoders.
+ Args:
+ spec (Tensor): complex spectrogram tensor. 1D, 2D or 3D tensor, time last.
+ Returns:
+ Tensor, of shape (batch, time) or (time).
+ """
+ # TF-rep shape: (batch, self.input_channels, n_fft, frames)
+ # Estimate mask from time-frequency representation.
+ x_in = self.fix_input_dims(spec)
+ x = x_in
+ t_embed = self.embed(t + 0j) if self.time_embedding is not None else None
+
+ enc_outs = []
+ for idx, enc in enumerate(self.encoders):
+ x = enc(x, t_embed)
+ # UNet skip connection
+ enc_outs.append(x)
+ for enc_out, dec in zip(reversed(enc_outs[:-1]), self.decoders):
+ x = dec(x, t_embed, output_size=enc_out.shape)
+ x = torch.cat([x, enc_out], dim=1)
+
+ output = self.output_layer(x, output_size=x_in.shape)
+ # output shape: (batch, 1, n_fft, frames)
+ output = self.fix_output_dims(output, spec)
+ return output
+
+ def fix_input_dims(self, x):
+ return _fix_dcu_input_dims(
+ self.fix_length_mode, x, torch.from_numpy(self.encoders_stride_product)
+ )
+
+ def fix_output_dims(self, out, x):
+ return _fix_dcu_output_dims(self.fix_length_mode, out, x)
+
+
+def _fix_dcu_input_dims(fix_length_mode, x, encoders_stride_product):
+ """Pad or trim `x` to a length compatible with DCUNet."""
+ freq_prod = int(encoders_stride_product[0])
+ time_prod = int(encoders_stride_product[1])
+ if (x.shape[2] - 1) % freq_prod:
+ raise TypeError(
+ f"Input shape must be [batch, ch, freq + 1, time + 1] with freq divisible by "
+ f"{freq_prod}, got {x.shape} instead"
+ )
+ time_remainder = (x.shape[3] - 1) % time_prod
+ if time_remainder:
+ if fix_length_mode is None:
+ raise TypeError(
+ f"Input shape must be [batch, ch, freq + 1, time + 1] with time divisible by "
+ f"{time_prod}, got {x.shape} instead. Set the 'fix_length_mode' argument "
+ f"in 'DCUNet' to 'pad' or 'trim' to fix shapes automatically."
+ )
+ elif fix_length_mode == "pad":
+ pad_shape = [0, time_prod - time_remainder]
+ x = nn.functional.pad(x, pad_shape, mode="constant")
+ elif fix_length_mode == "trim":
+ pad_shape = [0, -time_remainder]
+ x = nn.functional.pad(x, pad_shape, mode="constant")
+ else:
+ raise ValueError(f"Unknown fix_length mode '{fix_length_mode}'")
+ return x
+
+
+def _fix_dcu_output_dims(fix_length_mode, out, x):
+ """Fix shape of `out` to the original shape of `x` by padding/cropping."""
+ inp_len = x.shape[-1]
+ output_len = out.shape[-1]
+ return nn.functional.pad(out, [0, inp_len - output_len])
+
+
+def _get_norm(norm_type):
+ if norm_type == "CbN":
+ return ComplexBatchNorm
+ elif norm_type == "bN":
+ return partial(OnReIm, BatchNorm)
+ else:
+ raise NotImplementedError(f"Unknown norm type: {norm_type}")
+
+
+class DCUNetComplexEncoderBlock(nn.Module):
+ def __init__(
+ self,
+ in_chan,
+ out_chan,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ norm_type="bN",
+ activation="leaky_relu",
+ embed_dim=None,
+ complex_time_embedding=False,
+ temb_layers=1,
+ temb_activation="silu",
+ ):
+ super().__init__()
+
+ self.in_chan = in_chan
+ self.out_chan = out_chan
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.temb_layers = temb_layers
+ self.temb_activation = temb_activation
+ self.complex_time_embedding = complex_time_embedding
+
+ self.conv = ComplexConv2d(
+ in_chan,
+ out_chan,
+ kernel_size,
+ stride,
+ padding,
+ bias=norm_type is None,
+ dilation=dilation,
+ )
+ self.norm = _get_norm(norm_type)(out_chan)
+ self.activation = OnReIm(get_activation(activation))
+ self.embed_dim = embed_dim
+ if self.embed_dim is not None:
+ ops = []
+ for _ in range(max(0, self.temb_layers - 1)):
+ ops += [
+ ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
+ OnReIm(get_activation(self.temb_activation)),
+ ]
+ ops += [
+ FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
+ OnReIm(get_activation(self.temb_activation)),
+ ]
+ self.embed_layer = nn.Sequential(*ops)
+
+ def forward(self, x, t_embed):
+ y = self.conv(x)
+ if self.embed_dim is not None:
+ y = y + self.embed_layer(t_embed)
+ return self.activation(self.norm(y))
+
+
+class DCUNetComplexDecoderBlock(nn.Module):
+ def __init__(
+ self,
+ in_chan,
+ out_chan,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ output_padding=(0, 0),
+ norm_type="bN",
+ activation="leaky_relu",
+ embed_dim=None,
+ temb_layers=1,
+ temb_activation="swish",
+ complex_time_embedding=False,
+ ):
+ super().__init__()
+
+ self.in_chan = in_chan
+ self.out_chan = out_chan
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.output_padding = output_padding
+ self.complex_time_embedding = complex_time_embedding
+ self.temb_layers = temb_layers
+ self.temb_activation = temb_activation
+
+ self.deconv = ComplexConvTranspose2d(
+ in_chan,
+ out_chan,
+ kernel_size,
+ stride,
+ padding,
+ output_padding,
+ dilation=dilation,
+ bias=norm_type is None,
+ )
+ self.norm = _get_norm(norm_type)(out_chan)
+ self.activation = OnReIm(get_activation(activation))
+ self.embed_dim = embed_dim
+ if self.embed_dim is not None:
+ ops = []
+ for _ in range(max(0, self.temb_layers - 1)):
+ ops += [
+ ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
+ OnReIm(get_activation(self.temb_activation)),
+ ]
+ ops += [
+ FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
+ OnReIm(get_activation(self.temb_activation)),
+ ]
+ self.embed_layer = nn.Sequential(*ops)
+
+ def forward(self, x, t_embed, output_size=None):
+ y = self.deconv(x, output_size=output_size)
+ if self.embed_dim is not None:
+ y = y + self.embed_layer(t_embed)
+ return self.activation(self.norm(y))
+
+
+# From https://github.com/chanil1218/DCUnet.pytorch/blob/2dcdd30804be47a866fde6435cbb7e2f81585213/models/layers/complexnn.py
+class ComplexBatchNorm(torch.nn.Module):
+ def __init__(
+ self,
+ num_features,
+ eps=1e-5,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=False,
+ ):
+ super(ComplexBatchNorm, self).__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.momentum = momentum
+ self.affine = affine
+ self.track_running_stats = track_running_stats
+ if self.affine:
+ self.Wrr = torch.nn.Parameter(torch.Tensor(num_features))
+ self.Wri = torch.nn.Parameter(torch.Tensor(num_features))
+ self.Wii = torch.nn.Parameter(torch.Tensor(num_features))
+ self.Br = torch.nn.Parameter(torch.Tensor(num_features))
+ self.Bi = torch.nn.Parameter(torch.Tensor(num_features))
+ else:
+ self.register_parameter("Wrr", None)
+ self.register_parameter("Wri", None)
+ self.register_parameter("Wii", None)
+ self.register_parameter("Br", None)
+ self.register_parameter("Bi", None)
+ if self.track_running_stats:
+ self.register_buffer("RMr", torch.zeros(num_features))
+ self.register_buffer("RMi", torch.zeros(num_features))
+ self.register_buffer("RVrr", torch.ones(num_features))
+ self.register_buffer("RVri", torch.zeros(num_features))
+ self.register_buffer("RVii", torch.ones(num_features))
+ self.register_buffer(
+ "num_batches_tracked", torch.tensor(0, dtype=torch.long)
+ )
+ else:
+ self.register_parameter("RMr", None)
+ self.register_parameter("RMi", None)
+ self.register_parameter("RVrr", None)
+ self.register_parameter("RVri", None)
+ self.register_parameter("RVii", None)
+ self.register_parameter("num_batches_tracked", None)
+ self.reset_parameters()
+
+ def reset_running_stats(self):
+ if self.track_running_stats:
+ self.RMr.zero_()
+ self.RMi.zero_()
+ self.RVrr.fill_(1)
+ self.RVri.zero_()
+ self.RVii.fill_(1)
+ self.num_batches_tracked.zero_()
+
+ def reset_parameters(self):
+ self.reset_running_stats()
+ if self.affine:
+ self.Br.data.zero_()
+ self.Bi.data.zero_()
+ self.Wrr.data.fill_(1)
+ self.Wri.data.uniform_(-0.9, +0.9) # W will be positive-definite
+ self.Wii.data.fill_(1)
+
+ def _check_input_dim(self, xr, xi):
+ assert xr.shape == xi.shape
+ assert xr.size(1) == self.num_features
+
+ def forward(self, x):
+ xr, xi = x.real, x.imag
+ self._check_input_dim(xr, xi)
+
+ exponential_average_factor = 0.0
+
+ if self.training and self.track_running_stats:
+ self.num_batches_tracked += 1
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / self.num_batches_tracked.item()
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
+
+ #
+ # NOTE: The precise meaning of the "training flag" is:
+ # True: Normalize using batch statistics, update running statistics
+ # if they are being collected.
+ # False: Normalize using running statistics, ignore batch statistics.
+ #
+ training = self.training or not self.track_running_stats
+ redux = [i for i in reversed(range(xr.dim())) if i != 1]
+ vdim = [1] * xr.dim()
+ vdim[1] = xr.size(1)
+
+ #
+ # Mean M Computation and Centering
+ #
+ # Includes running mean update if training and running.
+ #
+ if training:
+ Mr, Mi = xr, xi
+ for d in redux:
+ Mr = Mr.mean(d, keepdim=True)
+ Mi = Mi.mean(d, keepdim=True)
+ if self.track_running_stats:
+ self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
+ self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
+ else:
+ Mr = self.RMr.view(vdim)
+ Mi = self.RMi.view(vdim)
+ xr, xi = xr - Mr, xi - Mi
+
+ #
+ # Variance Matrix V Computation
+ #
+ # Includes epsilon numerical stabilizer/Tikhonov regularizer.
+ # Includes running variance update if training and running.
+ #
+ if training:
+ Vrr = xr * xr
+ Vri = xr * xi
+ Vii = xi * xi
+ for d in redux:
+ Vrr = Vrr.mean(d, keepdim=True)
+ Vri = Vri.mean(d, keepdim=True)
+ Vii = Vii.mean(d, keepdim=True)
+ if self.track_running_stats:
+ self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
+ self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
+ self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
+ else:
+ Vrr = self.RVrr.view(vdim)
+ Vri = self.RVri.view(vdim)
+ Vii = self.RVii.view(vdim)
+ Vrr = Vrr + self.eps
+ Vri = Vri
+ Vii = Vii + self.eps
+
+ #
+ # Matrix Inverse Square Root U = V^-0.5
+ #
+ # sqrt of a 2x2 matrix,
+ # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
+ tau = Vrr + Vii
+ delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)
+ s = delta.sqrt()
+ t = (tau + 2 * s).sqrt()
+
+ # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
+ rst = (s * t).reciprocal()
+ Urr = (s + Vii) * rst
+ Uii = (s + Vrr) * rst
+ Uri = (-Vri) * rst
+
+ #
+ # Optionally left-multiply U by affine weights W to produce combined
+ # weights Z, left-multiply the inputs by Z, then optionally bias them.
+ #
+ # y = Zx + B
+ # y = WUx + B
+ # y = [Wrr Wri][Urr Uri] [xr] + [Br]
+ # [Wir Wii][Uir Uii] [xi] [Bi]
+ #
+ if self.affine:
+ Wrr, Wri, Wii = (
+ self.Wrr.view(vdim),
+ self.Wri.view(vdim),
+ self.Wii.view(vdim),
+ )
+ Zrr = (Wrr * Urr) + (Wri * Uri)
+ Zri = (Wrr * Uri) + (Wri * Uii)
+ Zir = (Wri * Urr) + (Wii * Uri)
+ Zii = (Wri * Uri) + (Wii * Uii)
+ else:
+ Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
+
+ yr = (Zrr * xr) + (Zri * xi)
+ yi = (Zir * xr) + (Zii * xi)
+
+ if self.affine:
+ yr = yr + self.Br.view(vdim)
+ yi = yi + self.Bi.view(vdim)
+
+ return torch.view_as_complex(torch.stack([yr, yi], dim=-1))
+
+ def extra_repr(self):
+ return (
+ "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
+ "track_running_stats={track_running_stats}".format(**self.__dict__)
+ )
diff --git a/modules/sgmse/ncsnpp.py b/modules/sgmse/ncsnpp.py
new file mode 100644
index 00000000..4302f47a
--- /dev/null
+++ b/modules/sgmse/ncsnpp.py
@@ -0,0 +1,497 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+
+from .ncsnpp_utils import layers, layerspp, normalization
+import torch.nn as nn
+import functools
+import torch
+import numpy as np
+
+from .shared import BackboneRegistry
+
+ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
+ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
+Combine = layerspp.Combine
+conv3x3 = layerspp.conv3x3
+conv1x1 = layerspp.conv1x1
+get_act = layers.get_act
+get_normalization = normalization.get_normalization
+default_initializer = layers.default_init
+
+
+@BackboneRegistry.register("ncsnpp")
+class NCSNpp(nn.Module):
+ """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
+
+ @staticmethod
+ def add_argparse_args(parser):
+ parser.add_argument(
+ "--ch_mult", type=int, nargs="+", default=[1, 1, 2, 2, 2, 2, 2]
+ )
+ parser.add_argument("--num_res_blocks", type=int, default=2)
+ parser.add_argument("--attn_resolutions", type=int, nargs="+", default=[16])
+ parser.add_argument(
+ "--no-centered",
+ dest="centered",
+ action="store_false",
+ help="The data is not centered [-1, 1]",
+ )
+ parser.add_argument(
+ "--centered",
+ dest="centered",
+ action="store_true",
+ help="The data is centered [-1, 1]",
+ )
+ parser.set_defaults(centered=True)
+ return parser
+
+ def __init__(
+ self,
+ scale_by_sigma=True,
+ nonlinearity="swish",
+ nf=128,
+ ch_mult=(1, 1, 2, 2, 2, 2, 2),
+ num_res_blocks=2,
+ attn_resolutions=(16,),
+ resamp_with_conv=True,
+ conditional=True,
+ fir=True,
+ fir_kernel=[1, 3, 3, 1],
+ skip_rescale=True,
+ resblock_type="biggan",
+ progressive="output_skip",
+ progressive_input="input_skip",
+ progressive_combine="sum",
+ init_scale=0.0,
+ fourier_scale=16,
+ image_size=256,
+ embedding_type="fourier",
+ dropout=0.0,
+ centered=True,
+ **unused_kwargs,
+ ):
+ super().__init__()
+ self.act = act = get_act(nonlinearity)
+
+        self.nf = nf
+        self.num_res_blocks = num_res_blocks
+        self.attn_resolutions = attn_resolutions
+ self.num_resolutions = num_resolutions = len(ch_mult)
+ self.all_resolutions = all_resolutions = [
+ image_size // (2**i) for i in range(num_resolutions)
+ ]
+
+        self.conditional = conditional  # noise-conditional
+ self.centered = centered
+ self.scale_by_sigma = scale_by_sigma
+
+        self.skip_rescale = skip_rescale
+        self.resblock_type = resblock_type = resblock_type.lower()
+        self.progressive = progressive = progressive.lower()
+        self.progressive_input = progressive_input = progressive_input.lower()
+        self.embedding_type = embedding_type = embedding_type.lower()
+ assert progressive in ["none", "output_skip", "residual"]
+ assert progressive_input in ["none", "input_skip", "residual"]
+ assert embedding_type in ["fourier", "positional"]
+ combine_method = progressive_combine.lower()
+ combiner = functools.partial(Combine, method=combine_method)
+
+ num_channels = 4 # x.real, x.imag, y.real, y.imag
+ self.output_layer = nn.Conv2d(num_channels, 2, 1)
+
+ modules = []
+ # timestep/noise_level embedding
+ if embedding_type == "fourier":
+ # Gaussian Fourier features embeddings.
+ modules.append(
+ layerspp.GaussianFourierProjection(
+ embedding_size=nf, scale=fourier_scale
+ )
+ )
+ embed_dim = 2 * nf
+ elif embedding_type == "positional":
+ embed_dim = nf
+ else:
+ raise ValueError(f"embedding type {embedding_type} unknown.")
+
+ if conditional:
+ modules.append(nn.Linear(embed_dim, nf * 4))
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+ nn.init.zeros_(modules[-1].bias)
+ modules.append(nn.Linear(nf * 4, nf * 4))
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+ nn.init.zeros_(modules[-1].bias)
+
+ AttnBlock = functools.partial(
+ layerspp.AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale
+ )
+
+ Upsample = functools.partial(
+ layerspp.Upsample,
+ with_conv=resamp_with_conv,
+ fir=fir,
+ fir_kernel=fir_kernel,
+ )
+
+ if progressive == "output_skip":
+ self.pyramid_upsample = layerspp.Upsample(
+ fir=fir, fir_kernel=fir_kernel, with_conv=False
+ )
+ elif progressive == "residual":
+ pyramid_upsample = functools.partial(
+ layerspp.Upsample, fir=fir, fir_kernel=fir_kernel, with_conv=True
+ )
+
+ Downsample = functools.partial(
+ layerspp.Downsample,
+ with_conv=resamp_with_conv,
+ fir=fir,
+ fir_kernel=fir_kernel,
+ )
+
+ if progressive_input == "input_skip":
+ self.pyramid_downsample = layerspp.Downsample(
+ fir=fir, fir_kernel=fir_kernel, with_conv=False
+ )
+ elif progressive_input == "residual":
+ pyramid_downsample = functools.partial(
+ layerspp.Downsample, fir=fir, fir_kernel=fir_kernel, with_conv=True
+ )
+
+ if resblock_type == "ddpm":
+ ResnetBlock = functools.partial(
+ ResnetBlockDDPM,
+ act=act,
+ dropout=dropout,
+ init_scale=init_scale,
+ skip_rescale=skip_rescale,
+ temb_dim=nf * 4,
+ )
+
+ elif resblock_type == "biggan":
+ ResnetBlock = functools.partial(
+ ResnetBlockBigGAN,
+ act=act,
+ dropout=dropout,
+ fir=fir,
+ fir_kernel=fir_kernel,
+ init_scale=init_scale,
+ skip_rescale=skip_rescale,
+ temb_dim=nf * 4,
+ )
+
+ else:
+ raise ValueError(f"resblock type {resblock_type} unrecognized.")
+
+ # Downsampling block
+
+ channels = num_channels
+ if progressive_input != "none":
+ input_pyramid_ch = channels
+
+ modules.append(conv3x3(channels, nf))
+ hs_c = [nf]
+
+ in_ch = nf
+ for i_level in range(num_resolutions):
+ # Residual blocks for this resolution
+ for i_block in range(num_res_blocks):
+ out_ch = nf * ch_mult[i_level]
+ modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+ in_ch = out_ch
+
+ if all_resolutions[i_level] in attn_resolutions:
+ modules.append(AttnBlock(channels=in_ch))
+ hs_c.append(in_ch)
+
+ if i_level != num_resolutions - 1:
+ if resblock_type == "ddpm":
+ modules.append(Downsample(in_ch=in_ch))
+ else:
+ modules.append(ResnetBlock(down=True, in_ch=in_ch))
+
+ if progressive_input == "input_skip":
+ modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
+ if combine_method == "cat":
+ in_ch *= 2
+
+ elif progressive_input == "residual":
+ modules.append(
+ pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)
+ )
+ input_pyramid_ch = in_ch
+
+ hs_c.append(in_ch)
+
+ in_ch = hs_c[-1]
+ modules.append(ResnetBlock(in_ch=in_ch))
+ modules.append(AttnBlock(channels=in_ch))
+ modules.append(ResnetBlock(in_ch=in_ch))
+
+ pyramid_ch = 0
+ # Upsampling block
+ for i_level in reversed(range(num_resolutions)):
+ for i_block in range(
+ num_res_blocks + 1
+ ): # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
+ out_ch = nf * ch_mult[i_level]
+ modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+ in_ch = out_ch
+
+ if all_resolutions[i_level] in attn_resolutions:
+ modules.append(AttnBlock(channels=in_ch))
+
+ if progressive != "none":
+ if i_level == num_resolutions - 1:
+ if progressive == "output_skip":
+ modules.append(
+ nn.GroupNorm(
+ num_groups=min(in_ch // 4, 32),
+ num_channels=in_ch,
+ eps=1e-6,
+ )
+ )
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+ pyramid_ch = channels
+ elif progressive == "residual":
+ modules.append(
+ nn.GroupNorm(
+ num_groups=min(in_ch // 4, 32),
+ num_channels=in_ch,
+ eps=1e-6,
+ )
+ )
+ modules.append(conv3x3(in_ch, in_ch, bias=True))
+ pyramid_ch = in_ch
+ else:
+ raise ValueError(f"{progressive} is not a valid name.")
+ else:
+ if progressive == "output_skip":
+ modules.append(
+ nn.GroupNorm(
+ num_groups=min(in_ch // 4, 32),
+ num_channels=in_ch,
+ eps=1e-6,
+ )
+ )
+ modules.append(
+ conv3x3(in_ch, channels, bias=True, init_scale=init_scale)
+ )
+ pyramid_ch = channels
+ elif progressive == "residual":
+ modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+ pyramid_ch = in_ch
+ else:
+ raise ValueError(f"{progressive} is not a valid name")
+
+ if i_level != 0:
+ if resblock_type == "ddpm":
+ modules.append(Upsample(in_ch=in_ch))
+ else:
+ modules.append(ResnetBlock(in_ch=in_ch, up=True))
+
+ assert not hs_c
+
+ if progressive != "output_skip":
+ modules.append(
+ nn.GroupNorm(
+ num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6
+ )
+ )
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+
+ self.all_modules = nn.ModuleList(modules)
+
+ def forward(self, x, time_cond):
+ # timestep/noise_level embedding; only for continuous training
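+        # Shape sketch: x is complex-valued with shape [B, 2, H, W], stacking the
+        # current state and the conditioning signal y along the channel dim; it
+        # is split into 4 real channels just below and re-packed into a complex
+        # [B, 1, H, W] output at the end of this method.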
+ modules = self.all_modules
+ m_idx = 0
+
+ # Convert real and imaginary parts of (x,y) into four channel dimensions
+ x = torch.cat(
+ (
+ x[:, [0], :, :].real,
+ x[:, [0], :, :].imag,
+ x[:, [1], :, :].real,
+ x[:, [1], :, :].imag,
+ ),
+ dim=1,
+ )
+
+ if self.embedding_type == "fourier":
+ # Gaussian Fourier features embeddings.
+ used_sigmas = time_cond
+ temb = modules[m_idx](torch.log(used_sigmas))
+ m_idx += 1
+
+ elif self.embedding_type == "positional":
+ # Sinusoidal positional embeddings.
+ timesteps = time_cond
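+            # NOTE: this branch assumes a `self.sigmas` buffer (a discrete noise
+            # schedule), which __init__ above does not register; only the
+            # "fourier" embedding path is usable as the class stands.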
+ used_sigmas = self.sigmas[time_cond.long()]
+ temb = layers.get_timestep_embedding(timesteps, self.nf)
+
+ else:
+ raise ValueError(f"embedding type {self.embedding_type} unknown.")
+
+ if self.conditional:
+ temb = modules[m_idx](temb)
+ m_idx += 1
+ temb = modules[m_idx](self.act(temb))
+ m_idx += 1
+ else:
+ temb = None
+
+ if not self.centered:
+ # If input data is in [0, 1]
+ x = 2 * x - 1.0
+
+ # Downsampling block
+ input_pyramid = None
+ if self.progressive_input != "none":
+ input_pyramid = x
+
+ # Input layer: Conv2d: 4ch -> 128ch
+ hs = [modules[m_idx](x)]
+ m_idx += 1
+
+ # Down path in U-Net
+ for i_level in range(self.num_resolutions):
+ # Residual blocks for this resolution
+ for i_block in range(self.num_res_blocks):
+ h = modules[m_idx](hs[-1], temb)
+ m_idx += 1
+ # Attention layer (optional)
+ if (
+ h.shape[-2] in self.attn_resolutions
+ ): # edit: check H dim (-2) not W dim (-1)
+ h = modules[m_idx](h)
+ m_idx += 1
+ hs.append(h)
+
+ # Downsampling
+ if i_level != self.num_resolutions - 1:
+ if self.resblock_type == "ddpm":
+ h = modules[m_idx](hs[-1])
+ m_idx += 1
+ else:
+ h = modules[m_idx](hs[-1], temb)
+ m_idx += 1
+
+ if self.progressive_input == "input_skip": # Combine h with x
+ input_pyramid = self.pyramid_downsample(input_pyramid)
+ h = modules[m_idx](input_pyramid, h)
+ m_idx += 1
+
+ elif self.progressive_input == "residual":
+ input_pyramid = modules[m_idx](input_pyramid)
+ m_idx += 1
+ if self.skip_rescale:
+ input_pyramid = (input_pyramid + h) / np.sqrt(2.0)
+ else:
+ input_pyramid = input_pyramid + h
+ h = input_pyramid
+ hs.append(h)
+
+        h = hs[-1]  # actually equivalent to h = h, since h was just appended
+ h = modules[m_idx](h, temb) # ResNet block
+ m_idx += 1
+ h = modules[m_idx](h) # Attention block
+ m_idx += 1
+ h = modules[m_idx](h, temb) # ResNet block
+ m_idx += 1
+
+ pyramid = None
+
+ # Upsampling block
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
+ m_idx += 1
+
+ # edit: from -1 to -2
+ if h.shape[-2] in self.attn_resolutions:
+ h = modules[m_idx](h)
+ m_idx += 1
+
+ if self.progressive != "none":
+ if i_level == self.num_resolutions - 1:
+ if self.progressive == "output_skip":
+ pyramid = self.act(modules[m_idx](h)) # GroupNorm
+ m_idx += 1
+ pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
+ m_idx += 1
+ elif self.progressive == "residual":
+ pyramid = self.act(modules[m_idx](h))
+ m_idx += 1
+ pyramid = modules[m_idx](pyramid)
+ m_idx += 1
+ else:
+ raise ValueError(f"{self.progressive} is not a valid name.")
+ else:
+ if self.progressive == "output_skip":
+ pyramid = self.pyramid_upsample(pyramid) # Upsample
+ pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
+ m_idx += 1
+ pyramid_h = modules[m_idx](pyramid_h)
+ m_idx += 1
+ pyramid = pyramid + pyramid_h
+ elif self.progressive == "residual":
+ pyramid = modules[m_idx](pyramid)
+ m_idx += 1
+ if self.skip_rescale:
+ pyramid = (pyramid + h) / np.sqrt(2.0)
+ else:
+ pyramid = pyramid + h
+ h = pyramid
+ else:
+ raise ValueError(f"{self.progressive} is not a valid name")
+
+ # Upsampling Layer
+ if i_level != 0:
+ if self.resblock_type == "ddpm":
+ h = modules[m_idx](h)
+ m_idx += 1
+ else:
+                    h = modules[m_idx](h, temb)  # Upsampling
+ m_idx += 1
+
+ assert not hs
+
+ if self.progressive == "output_skip":
+ h = pyramid
+ else:
+ h = self.act(modules[m_idx](h))
+ m_idx += 1
+ h = modules[m_idx](h)
+ m_idx += 1
+
+ assert m_idx == len(modules), "Implementation error"
+ if self.scale_by_sigma:
+ used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
+ h = h / used_sigmas
+
+ # Convert back to complex number
+ h = self.output_layer(h)
+ h = torch.permute(h, (0, 2, 3, 1)).contiguous()
+ h = torch.view_as_complex(h)[:, None, :, :]
+ return h
diff --git a/modules/sgmse/ncsnpp_utils/layers.py b/modules/sgmse/ncsnpp_utils/layers.py
new file mode 100644
index 00000000..76bf8ac3
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/layers.py
@@ -0,0 +1,800 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+"""Common layers for defining score networks.
+"""
+import math
+import string
+from functools import partial
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+import numpy as np
+from .normalization import ConditionalInstanceNorm2dPlus
+
+
+def get_act(config):
+ """Get activation functions from the config file."""
+
+ if config == "elu":
+ return nn.ELU()
+ elif config == "relu":
+ return nn.ReLU()
+ elif config == "lrelu":
+ return nn.LeakyReLU(negative_slope=0.2)
+ elif config == "swish":
+ return nn.SiLU()
+ else:
+ raise NotImplementedError("activation function does not exist!")
+
+
+def ncsn_conv1x1(
+ in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=0
+):
+ """1x1 convolution. Same as NCSNv1/v2."""
+ conv = nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=bias,
+ dilation=dilation,
+ padding=padding,
+ )
+ init_scale = 1e-10 if init_scale == 0 else init_scale
+ conv.weight.data *= init_scale
+ conv.bias.data *= init_scale
+ return conv
+
+
+def variance_scaling(
+ scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"
+):
+ """Ported from JAX."""
+
+ def _compute_fans(shape, in_axis=1, out_axis=0):
+ receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
+ fan_in = shape[in_axis] * receptive_field_size
+ fan_out = shape[out_axis] * receptive_field_size
+ return fan_in, fan_out
+
+ def init(shape, dtype=dtype, device=device):
+ fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
+ if mode == "fan_in":
+ denominator = fan_in
+ elif mode == "fan_out":
+ denominator = fan_out
+ elif mode == "fan_avg":
+ denominator = (fan_in + fan_out) / 2
+ else:
+ raise ValueError(
+ "invalid mode for variance scaling initializer: {}".format(mode)
+ )
+ variance = scale / denominator
+ if distribution == "normal":
+ return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
+ elif distribution == "uniform":
+ return (
+ torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0
+ ) * np.sqrt(3 * variance)
+ else:
+ raise ValueError("invalid distribution for variance scaling initializer")
+
+ return init
+
+
+def default_init(scale=1.0):
+ """The same initialization used in DDPM."""
+ scale = 1e-10 if scale == 0 else scale
+ return variance_scaling(scale, "fan_avg", "uniform")
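+
+# Shape-only usage sketch: `default_init(0.02)((64, 32, 3, 3))` draws a fan-avg,
+# uniform-distributed tensor suitable as a 3x3 conv weight; ddpm_conv1x1 and
+# ddpm_conv3x3 below initialize their weights exactly this way.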
+
+
+class Dense(nn.Module):
+ """Linear layer with `default_init`."""
+
+ def __init__(self):
+ super().__init__()
+
+
+def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
+ """1x1 convolution with DDPM initialization."""
+ conv = nn.Conv2d(
+ in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias
+ )
+ conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+ nn.init.zeros_(conv.bias)
+ return conv
+
+
+def ncsn_conv3x3(
+ in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1
+):
+ """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
+ init_scale = 1e-10 if init_scale == 0 else init_scale
+ conv = nn.Conv2d(
+ in_planes,
+ out_planes,
+ stride=stride,
+ bias=bias,
+ dilation=dilation,
+ padding=padding,
+ kernel_size=3,
+ )
+ conv.weight.data *= init_scale
+ conv.bias.data *= init_scale
+ return conv
+
+
+def ddpm_conv3x3(
+ in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1
+):
+ """3x3 convolution with DDPM initialization."""
+ conv = nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+ conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+ nn.init.zeros_(conv.bias)
+ return conv
+
+
+###########################################################################
+# Functions below are ported over from the NCSNv1/NCSNv2 codebase:
+# https://github.com/ermongroup/ncsn
+# https://github.com/ermongroup/ncsnv2
+###########################################################################
+
+
+class CRPBlock(nn.Module):
+ def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
+ super().__init__()
+ self.convs = nn.ModuleList()
+ for i in range(n_stages):
+ self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
+ self.n_stages = n_stages
+ if maxpool:
+ self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
+ else:
+ self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
+
+ self.act = act
+
+ def forward(self, x):
+ x = self.act(x)
+ path = x
+ for i in range(self.n_stages):
+ path = self.pool(path)
+ path = self.convs[i](path)
+ x = path + x
+ return x
+
+
+class CondCRPBlock(nn.Module):
+ def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
+ super().__init__()
+ self.convs = nn.ModuleList()
+ self.norms = nn.ModuleList()
+ self.normalizer = normalizer
+ for i in range(n_stages):
+ self.norms.append(normalizer(features, num_classes, bias=True))
+ self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
+
+ self.n_stages = n_stages
+ self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
+ self.act = act
+
+ def forward(self, x, y):
+ x = self.act(x)
+ path = x
+ for i in range(self.n_stages):
+ path = self.norms[i](path, y)
+ path = self.pool(path)
+ path = self.convs[i](path)
+
+ x = path + x
+ return x
+
+
+class RCUBlock(nn.Module):
+ def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
+ super().__init__()
+
+ for i in range(n_blocks):
+ for j in range(n_stages):
+ setattr(
+ self,
+ "{}_{}_conv".format(i + 1, j + 1),
+ ncsn_conv3x3(features, features, stride=1, bias=False),
+ )
+
+ self.stride = 1
+ self.n_blocks = n_blocks
+ self.n_stages = n_stages
+ self.act = act
+
+ def forward(self, x):
+ for i in range(self.n_blocks):
+ residual = x
+ for j in range(self.n_stages):
+ x = self.act(x)
+ x = getattr(self, "{}_{}_conv".format(i + 1, j + 1))(x)
+
+ x += residual
+ return x
+
+
+class CondRCUBlock(nn.Module):
+ def __init__(
+ self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()
+ ):
+ super().__init__()
+
+ for i in range(n_blocks):
+ for j in range(n_stages):
+ setattr(
+ self,
+ "{}_{}_norm".format(i + 1, j + 1),
+ normalizer(features, num_classes, bias=True),
+ )
+ setattr(
+ self,
+ "{}_{}_conv".format(i + 1, j + 1),
+ ncsn_conv3x3(features, features, stride=1, bias=False),
+ )
+
+ self.stride = 1
+ self.n_blocks = n_blocks
+ self.n_stages = n_stages
+ self.act = act
+ self.normalizer = normalizer
+
+ def forward(self, x, y):
+ for i in range(self.n_blocks):
+ residual = x
+ for j in range(self.n_stages):
+ x = getattr(self, "{}_{}_norm".format(i + 1, j + 1))(x, y)
+ x = self.act(x)
+ x = getattr(self, "{}_{}_conv".format(i + 1, j + 1))(x)
+
+ x += residual
+ return x
+
+
+class MSFBlock(nn.Module):
+ def __init__(self, in_planes, features):
+ super().__init__()
+ assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
+ self.convs = nn.ModuleList()
+ self.features = features
+
+ for i in range(len(in_planes)):
+ self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
+
+ def forward(self, xs, shape):
+ sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
+ for i in range(len(self.convs)):
+ h = self.convs[i](xs[i])
+ h = F.interpolate(h, size=shape, mode="bilinear", align_corners=True)
+ sums += h
+ return sums
+
+
+class CondMSFBlock(nn.Module):
+ def __init__(self, in_planes, features, num_classes, normalizer):
+ super().__init__()
+ assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
+
+ self.convs = nn.ModuleList()
+ self.norms = nn.ModuleList()
+ self.features = features
+ self.normalizer = normalizer
+
+ for i in range(len(in_planes)):
+ self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
+ self.norms.append(normalizer(in_planes[i], num_classes, bias=True))
+
+ def forward(self, xs, y, shape):
+ sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
+ for i in range(len(self.convs)):
+ h = self.norms[i](xs[i], y)
+ h = self.convs[i](h)
+ h = F.interpolate(h, size=shape, mode="bilinear", align_corners=True)
+ sums += h
+ return sums
+
+
+class RefineBlock(nn.Module):
+ def __init__(
+ self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True
+ ):
+ super().__init__()
+
+ assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
+ self.n_blocks = n_blocks = len(in_planes)
+
+ self.adapt_convs = nn.ModuleList()
+ for i in range(n_blocks):
+ self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))
+
+ self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)
+
+ if not start:
+ self.msf = MSFBlock(in_planes, features)
+
+ self.crp = CRPBlock(features, 2, act, maxpool=maxpool)
+
+ def forward(self, xs, output_shape):
+ assert isinstance(xs, tuple) or isinstance(xs, list)
+ hs = []
+ for i in range(len(xs)):
+ h = self.adapt_convs[i](xs[i])
+ hs.append(h)
+
+ if self.n_blocks > 1:
+ h = self.msf(hs, output_shape)
+ else:
+ h = hs[0]
+
+ h = self.crp(h)
+ h = self.output_convs(h)
+
+ return h
+
+
+class CondRefineBlock(nn.Module):
+ def __init__(
+ self,
+ in_planes,
+ features,
+ num_classes,
+ normalizer,
+ act=nn.ReLU(),
+ start=False,
+ end=False,
+ ):
+ super().__init__()
+
+ assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
+ self.n_blocks = n_blocks = len(in_planes)
+
+ self.adapt_convs = nn.ModuleList()
+ for i in range(n_blocks):
+ self.adapt_convs.append(
+ CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
+ )
+
+ self.output_convs = CondRCUBlock(
+ features, 3 if end else 1, 2, num_classes, normalizer, act
+ )
+
+ if not start:
+ self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)
+
+ self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)
+
+ def forward(self, xs, y, output_shape):
+ assert isinstance(xs, tuple) or isinstance(xs, list)
+ hs = []
+ for i in range(len(xs)):
+ h = self.adapt_convs[i](xs[i], y)
+ hs.append(h)
+
+ if self.n_blocks > 1:
+ h = self.msf(hs, y, output_shape)
+ else:
+ h = hs[0]
+
+ h = self.crp(h, y)
+ h = self.output_convs(h, y)
+
+ return h
+
+
+class ConvMeanPool(nn.Module):
+ def __init__(
+ self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False
+ ):
+ super().__init__()
+        conv = nn.Conv2d(
+            input_dim,
+            output_dim,
+            kernel_size,
+            stride=1,
+            padding=kernel_size // 2,
+            bias=biases,
+        )
+        if adjust_padding:
+            self.conv = nn.Sequential(nn.ZeroPad2d((1, 0, 1, 0)), conv)
+        else:
+            self.conv = conv
+
+ def forward(self, inputs):
+ output = self.conv(inputs)
+ output = (
+ sum(
+ [
+ output[:, :, ::2, ::2],
+ output[:, :, 1::2, ::2],
+ output[:, :, ::2, 1::2],
+ output[:, :, 1::2, 1::2],
+ ]
+ )
+ / 4.0
+ )
+ return output
+
+
+class MeanPoolConv(nn.Module):
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ input_dim,
+ output_dim,
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ bias=biases,
+ )
+
+ def forward(self, inputs):
+ output = inputs
+ output = (
+ sum(
+ [
+ output[:, :, ::2, ::2],
+ output[:, :, 1::2, ::2],
+ output[:, :, ::2, 1::2],
+ output[:, :, 1::2, 1::2],
+ ]
+ )
+ / 4.0
+ )
+ return self.conv(output)
+
+
+class UpsampleConv(nn.Module):
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ input_dim,
+ output_dim,
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ bias=biases,
+ )
+ self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
+
+ def forward(self, inputs):
+ output = inputs
+ output = torch.cat([output, output, output, output], dim=1)
+ output = self.pixelshuffle(output)
+ return self.conv(output)
+
+
+class ConditionalResidualBlock(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ output_dim,
+ num_classes,
+ resample=1,
+ act=nn.ELU(),
+ normalization=ConditionalInstanceNorm2dPlus,
+ adjust_padding=False,
+        dilation=1,
+ ):
+ super().__init__()
+ self.non_linearity = act
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.resample = resample
+ self.normalization = normalization
+ if resample == "down":
+ if dilation > 1:
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
+ self.normalize2 = normalization(input_dim, num_classes)
+ self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+ else:
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim)
+ self.normalize2 = normalization(input_dim, num_classes)
+ self.conv2 = ConvMeanPool(
+ input_dim, output_dim, 3, adjust_padding=adjust_padding
+ )
+ conv_shortcut = partial(
+ ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding
+ )
+
+ elif resample is None:
+ if dilation > 1:
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+ self.normalize2 = normalization(output_dim, num_classes)
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
+ else:
+                # nn.Conv2d without a kernel_size would fail when the shortcut
+                # is built; use a 1x1 convolution, as in ResidualBlock below.
+                conv_shortcut = partial(ncsn_conv1x1)
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim)
+ self.normalize2 = normalization(output_dim, num_classes)
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim)
+ else:
+ raise Exception("invalid resample value")
+
+ if output_dim != input_dim or resample is not None:
+ self.shortcut = conv_shortcut(input_dim, output_dim)
+
+ self.normalize1 = normalization(input_dim, num_classes)
+
+ def forward(self, x, y):
+ output = self.normalize1(x, y)
+ output = self.non_linearity(output)
+ output = self.conv1(output)
+ output = self.normalize2(output, y)
+ output = self.non_linearity(output)
+ output = self.conv2(output)
+
+ if self.output_dim == self.input_dim and self.resample is None:
+ shortcut = x
+ else:
+ shortcut = self.shortcut(x)
+
+ return shortcut + output
+
+
+class ResidualBlock(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ output_dim,
+ resample=None,
+ act=nn.ELU(),
+ normalization=nn.InstanceNorm2d,
+ adjust_padding=False,
+ dilation=1,
+ ):
+ super().__init__()
+ self.non_linearity = act
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.resample = resample
+ self.normalization = normalization
+ if resample == "down":
+ if dilation > 1:
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
+ self.normalize2 = normalization(input_dim)
+ self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+ else:
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim)
+ self.normalize2 = normalization(input_dim)
+ self.conv2 = ConvMeanPool(
+ input_dim, output_dim, 3, adjust_padding=adjust_padding
+ )
+ conv_shortcut = partial(
+ ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding
+ )
+
+ elif resample is None:
+ if dilation > 1:
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+ self.normalize2 = normalization(output_dim)
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
+ else:
+                # nn.Conv2d alone would need an explicit kernel_size for the
+                # shortcut, so a 1x1 convolution is used instead.
+                conv_shortcut = partial(ncsn_conv1x1)
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim)
+ self.normalize2 = normalization(output_dim)
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim)
+ else:
+ raise Exception("invalid resample value")
+
+ if output_dim != input_dim or resample is not None:
+ self.shortcut = conv_shortcut(input_dim, output_dim)
+
+ self.normalize1 = normalization(input_dim)
+
+ def forward(self, x):
+ output = self.normalize1(x)
+ output = self.non_linearity(output)
+ output = self.conv1(output)
+ output = self.normalize2(output)
+ output = self.non_linearity(output)
+ output = self.conv2(output)
+
+ if self.output_dim == self.input_dim and self.resample is None:
+ shortcut = x
+ else:
+ shortcut = self.shortcut(x)
+
+ return shortcut + output
+
+
+###########################################################################
+# Functions below are ported over from the DDPM codebase:
+# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
+###########################################################################
+
+
+def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
+ assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
+ half_dim = embedding_dim // 2
+ # magic number 10000 is from transformers
+ emb = math.log(max_positions) / (half_dim - 1)
+ # emb = math.log(2.) / (half_dim - 1)
+ emb = torch.exp(
+ torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb
+ )
+ # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
+ # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = F.pad(emb, (0, 1), mode="constant")
+ assert emb.shape == (timesteps.shape[0], embedding_dim)
+ return emb
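+
+# In formula form: with d = embedding_dim and w_k = max_positions^(-k / (d/2 - 1)),
+# the embedding of timestep t is [sin(t * w_0), ..., sin(t * w_{d/2-1}),
+# cos(t * w_0), ..., cos(t * w_{d/2-1})], zero-padded if d is odd.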
+
+
+def _einsum(a, b, c, x, y):
+ einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
+ return torch.einsum(einsum_str, x, y)
+
+
+def contract_inner(x, y):
+ """tensordot(x, y, 1)."""
+ x_chars = list(string.ascii_lowercase[: len(x.shape)])
+ y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
+ y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
+ out_chars = x_chars[:-1] + y_chars[1:]
+ return _einsum(x_chars, y_chars, out_chars, x, y)
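+
+# Example: for x of shape (B, H, W, C) and y of shape (C, U), contract_inner
+# sums over the shared C axis and returns shape (B, H, W, U) -- exactly how NIN
+# below applies a learned (C, U) matrix pointwise across feature maps.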
+
+
+class NIN(nn.Module):
+ def __init__(self, in_dim, num_units, init_scale=0.1):
+ super().__init__()
+ self.W = nn.Parameter(
+ default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True
+ )
+ self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 3, 1)
+ y = contract_inner(x, self.W) + self.b
+ return y.permute(0, 3, 1, 2)
+
+
+class AttnBlock(nn.Module):
+ """Channel-wise self-attention block."""
+
+ def __init__(self, channels):
+ super().__init__()
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
+ self.NIN_0 = NIN(channels, channels)
+ self.NIN_1 = NIN(channels, channels)
+ self.NIN_2 = NIN(channels, channels)
+ self.NIN_3 = NIN(channels, channels, init_scale=0.0)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ h = self.GroupNorm_0(x)
+ q = self.NIN_0(h)
+ k = self.NIN_1(h)
+ v = self.NIN_2(h)
+
+ w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
+ w = torch.reshape(w, (B, H, W, H * W))
+ w = F.softmax(w, dim=-1)
+ w = torch.reshape(w, (B, H, W, H, W))
+ h = torch.einsum("bhwij,bcij->bchw", w, v)
+ h = self.NIN_3(h)
+ return x + h
+
+
+class Upsample(nn.Module):
+ def __init__(self, channels, with_conv=False):
+ super().__init__()
+ if with_conv:
+ self.Conv_0 = ddpm_conv3x3(channels, channels)
+ self.with_conv = with_conv
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ h = F.interpolate(x, (H * 2, W * 2), mode="nearest")
+ if self.with_conv:
+ h = self.Conv_0(h)
+ return h
+
+
+class Downsample(nn.Module):
+ def __init__(self, channels, with_conv=False):
+ super().__init__()
+ if with_conv:
+ self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
+ self.with_conv = with_conv
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # Emulate 'SAME' padding
+ if self.with_conv:
+ x = F.pad(x, (0, 1, 0, 1))
+ x = self.Conv_0(x)
+ else:
+ x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
+
+ assert x.shape == (B, C, H // 2, W // 2)
+ return x
+
+
+class ResnetBlockDDPM(nn.Module):
+ """The ResNet Blocks used in DDPM."""
+
+ def __init__(
+ self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1
+ ):
+ super().__init__()
+ if out_ch is None:
+ out_ch = in_ch
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
+ self.act = act
+ self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
+ if temb_dim is not None:
+ self.Dense_0 = nn.Linear(temb_dim, out_ch)
+ self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
+ nn.init.zeros_(self.Dense_0.bias)
+
+ self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
+ self.Dropout_0 = nn.Dropout(dropout)
+ self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.0)
+ if in_ch != out_ch:
+ if conv_shortcut:
+ self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
+ else:
+ self.NIN_0 = NIN(in_ch, out_ch)
+ self.out_ch = out_ch
+ self.in_ch = in_ch
+ self.conv_shortcut = conv_shortcut
+
+ def forward(self, x, temb=None):
+ B, C, H, W = x.shape
+ assert C == self.in_ch
+ out_ch = self.out_ch if self.out_ch else self.in_ch
+ h = self.act(self.GroupNorm_0(x))
+ h = self.Conv_0(h)
+ # Add bias to each feature map conditioned on the time embedding
+ if temb is not None:
+ h += self.Dense_0(self.act(temb))[:, :, None, None]
+ h = self.act(self.GroupNorm_1(h))
+ h = self.Dropout_0(h)
+ h = self.Conv_1(h)
+ if C != out_ch:
+ if self.conv_shortcut:
+ x = self.Conv_2(x)
+ else:
+ x = self.NIN_0(x)
+ return x + h
diff --git a/modules/sgmse/ncsnpp_utils/layerspp.py b/modules/sgmse/ncsnpp_utils/layerspp.py
new file mode 100644
index 00000000..793b7e24
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/layerspp.py
@@ -0,0 +1,323 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+"""Layers for defining NCSN++.
+"""
+from . import layers
+from . import up_or_down_sampling
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+conv1x1 = layers.ddpm_conv1x1
+conv3x3 = layers.ddpm_conv3x3
+NIN = layers.NIN
+default_init = layers.default_init
+
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(self, embedding_size=256, scale=1.0):
+ super().__init__()
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ def forward(self, x):
+ x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
+ return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
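+
+    # For t of shape (B,), this returns (B, 2 * embedding_size):
+    # [sin(2*pi*W*t), cos(2*pi*W*t)] with a fixed, non-trainable random W,
+    # so the embedding is deterministic after initialization.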
+
+
+class Combine(nn.Module):
+ """Combine information from skip connections."""
+
+ def __init__(self, dim1, dim2, method="cat"):
+ super().__init__()
+ self.Conv_0 = conv1x1(dim1, dim2)
+ self.method = method
+
+ def forward(self, x, y):
+ h = self.Conv_0(x)
+ if self.method == "cat":
+ return torch.cat([h, y], dim=1)
+ elif self.method == "sum":
+ return h + y
+ else:
+ raise ValueError(f"Method {self.method} not recognized.")
+
+
+class AttnBlockpp(nn.Module):
+ """Channel-wise self-attention block. Modified from DDPM."""
+
+ def __init__(self, channels, skip_rescale=False, init_scale=0.0):
+ super().__init__()
+ self.GroupNorm_0 = nn.GroupNorm(
+ num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6
+ )
+ self.NIN_0 = NIN(channels, channels)
+ self.NIN_1 = NIN(channels, channels)
+ self.NIN_2 = NIN(channels, channels)
+ self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
+ self.skip_rescale = skip_rescale
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ h = self.GroupNorm_0(x)
+ q = self.NIN_0(h)
+ k = self.NIN_1(h)
+ v = self.NIN_2(h)
+
+ w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
+ w = torch.reshape(w, (B, H, W, H * W))
+ w = F.softmax(w, dim=-1)
+ w = torch.reshape(w, (B, H, W, H, W))
+ h = torch.einsum("bhwij,bcij->bchw", w, v)
+ h = self.NIN_3(h)
+ if not self.skip_rescale:
+ return x + h
+ else:
+ return (x + h) / np.sqrt(2.0)
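+
+    # Note: the attention weights form a (B, H, W, H, W) tensor, so memory and
+    # compute grow with (H*W)^2 -- which is why the backbone enables this block
+    # only at the small resolutions listed in attn_resolutions.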
+
+
+class Upsample(nn.Module):
+ def __init__(
+ self,
+ in_ch=None,
+ out_ch=None,
+ with_conv=False,
+ fir=False,
+ fir_kernel=(1, 3, 3, 1),
+ ):
+ super().__init__()
+ out_ch = out_ch if out_ch else in_ch
+ if not fir:
+ if with_conv:
+ self.Conv_0 = conv3x3(in_ch, out_ch)
+ else:
+ if with_conv:
+ self.Conv2d_0 = up_or_down_sampling.Conv2d(
+ in_ch,
+ out_ch,
+ kernel=3,
+ up=True,
+ resample_kernel=fir_kernel,
+ use_bias=True,
+ kernel_init=default_init(),
+ )
+ self.fir = fir
+ self.with_conv = with_conv
+ self.fir_kernel = fir_kernel
+ self.out_ch = out_ch
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ if not self.fir:
+            h = F.interpolate(x, (H * 2, W * 2), mode="nearest")
+ if self.with_conv:
+ h = self.Conv_0(h)
+ else:
+ if not self.with_conv:
+ h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
+ else:
+ h = self.Conv2d_0(x)
+
+ return h
+
+
+class Downsample(nn.Module):
+ def __init__(
+ self,
+ in_ch=None,
+ out_ch=None,
+ with_conv=False,
+ fir=False,
+ fir_kernel=(1, 3, 3, 1),
+ ):
+ super().__init__()
+ out_ch = out_ch if out_ch else in_ch
+ if not fir:
+ if with_conv:
+ self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
+ else:
+ if with_conv:
+ self.Conv2d_0 = up_or_down_sampling.Conv2d(
+ in_ch,
+ out_ch,
+ kernel=3,
+ down=True,
+ resample_kernel=fir_kernel,
+ use_bias=True,
+ kernel_init=default_init(),
+ )
+ self.fir = fir
+ self.fir_kernel = fir_kernel
+ self.with_conv = with_conv
+ self.out_ch = out_ch
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ if not self.fir:
+ if self.with_conv:
+ x = F.pad(x, (0, 1, 0, 1))
+ x = self.Conv_0(x)
+ else:
+ x = F.avg_pool2d(x, 2, stride=2)
+ else:
+ if not self.with_conv:
+ x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
+ else:
+ x = self.Conv2d_0(x)
+
+ return x
+
+
+class ResnetBlockDDPMpp(nn.Module):
+ """ResBlock adapted from DDPM."""
+
+ def __init__(
+ self,
+ act,
+ in_ch,
+ out_ch=None,
+ temb_dim=None,
+ conv_shortcut=False,
+ dropout=0.1,
+ skip_rescale=False,
+ init_scale=0.0,
+ ):
+ super().__init__()
+ out_ch = out_ch if out_ch else in_ch
+ self.GroupNorm_0 = nn.GroupNorm(
+ num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6
+ )
+ self.Conv_0 = conv3x3(in_ch, out_ch)
+ if temb_dim is not None:
+ self.Dense_0 = nn.Linear(temb_dim, out_ch)
+ self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
+ nn.init.zeros_(self.Dense_0.bias)
+ self.GroupNorm_1 = nn.GroupNorm(
+ num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6
+ )
+ self.Dropout_0 = nn.Dropout(dropout)
+ self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+ if in_ch != out_ch:
+ if conv_shortcut:
+ self.Conv_2 = conv3x3(in_ch, out_ch)
+ else:
+ self.NIN_0 = NIN(in_ch, out_ch)
+
+ self.skip_rescale = skip_rescale
+ self.act = act
+ self.out_ch = out_ch
+ self.conv_shortcut = conv_shortcut
+
+ def forward(self, x, temb=None):
+ h = self.act(self.GroupNorm_0(x))
+ h = self.Conv_0(h)
+ if temb is not None:
+ h += self.Dense_0(self.act(temb))[:, :, None, None]
+ h = self.act(self.GroupNorm_1(h))
+ h = self.Dropout_0(h)
+ h = self.Conv_1(h)
+ if x.shape[1] != self.out_ch:
+ if self.conv_shortcut:
+ x = self.Conv_2(x)
+ else:
+ x = self.NIN_0(x)
+ if not self.skip_rescale:
+ return x + h
+ else:
+ return (x + h) / np.sqrt(2.0)
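+
+    # With skip_rescale, the residual sum is divided by sqrt(2) so that, for
+    # roughly independent unit-variance x and h, the output variance stays ~1.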
+
+
+class ResnetBlockBigGANpp(nn.Module):
+ def __init__(
+ self,
+ act,
+ in_ch,
+ out_ch=None,
+ temb_dim=None,
+ up=False,
+ down=False,
+ dropout=0.1,
+ fir=False,
+ fir_kernel=(1, 3, 3, 1),
+ skip_rescale=True,
+ init_scale=0.0,
+ ):
+ super().__init__()
+
+ out_ch = out_ch if out_ch else in_ch
+ self.GroupNorm_0 = nn.GroupNorm(
+ num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6
+ )
+ self.up = up
+ self.down = down
+ self.fir = fir
+ self.fir_kernel = fir_kernel
+
+ self.Conv_0 = conv3x3(in_ch, out_ch)
+ if temb_dim is not None:
+ self.Dense_0 = nn.Linear(temb_dim, out_ch)
+ self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+ nn.init.zeros_(self.Dense_0.bias)
+
+ self.GroupNorm_1 = nn.GroupNorm(
+ num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6
+ )
+ self.Dropout_0 = nn.Dropout(dropout)
+ self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+ if in_ch != out_ch or up or down:
+ self.Conv_2 = conv1x1(in_ch, out_ch)
+
+ self.skip_rescale = skip_rescale
+ self.act = act
+ self.in_ch = in_ch
+ self.out_ch = out_ch
+
+ def forward(self, x, temb=None):
+ h = self.act(self.GroupNorm_0(x))
+
+ if self.up:
+ if self.fir:
+ h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
+ x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
+ else:
+ h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
+ x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
+ elif self.down:
+ if self.fir:
+ h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
+ x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
+ else:
+ h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
+ x = up_or_down_sampling.naive_downsample_2d(x, factor=2)
+
+ h = self.Conv_0(h)
+ # Add bias to each feature map conditioned on the time embedding
+ if temb is not None:
+ h += self.Dense_0(self.act(temb))[:, :, None, None]
+ h = self.act(self.GroupNorm_1(h))
+ h = self.Dropout_0(h)
+ h = self.Conv_1(h)
+
+ if self.in_ch != self.out_ch or self.up or self.down:
+ x = self.Conv_2(x)
+
+ if not self.skip_rescale:
+ return x + h
+ else:
+ return (x + h) / np.sqrt(2.0)
diff --git a/modules/sgmse/ncsnpp_utils/normalization.py b/modules/sgmse/ncsnpp_utils/normalization.py
new file mode 100644
index 00000000..fcc4707e
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/normalization.py
@@ -0,0 +1,243 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Normalization layers."""
+import torch.nn as nn
+import torch
+import functools
+
+
+def get_normalization(config, conditional=False):
+ """Obtain normalization modules from the config file."""
+ norm = config.model.normalization
+ if conditional:
+ if norm == "InstanceNorm++":
+ return functools.partial(
+ ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes
+ )
+ else:
+ raise NotImplementedError(f"{norm} not implemented yet.")
+ else:
+ if norm == "InstanceNorm":
+ return nn.InstanceNorm2d
+ elif norm == "InstanceNorm++":
+ return InstanceNorm2dPlus
+ elif norm == "VarianceNorm":
+ return VarianceNorm2d
+ elif norm == "GroupNorm":
+ return nn.GroupNorm
+ else:
+ raise ValueError("Unknown normalization: %s" % norm)
+
+
+class ConditionalBatchNorm2d(nn.Module):
+ def __init__(self, num_features, num_classes, bias=True):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.bn = nn.BatchNorm2d(num_features, affine=False)
+ if self.bias:
+ self.embed = nn.Embedding(num_classes, num_features * 2)
+ self.embed.weight.data[
+ :, :num_features
+            ].uniform_()  # Initialise scale uniformly in [0, 1)
+ self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
+ else:
+ self.embed = nn.Embedding(num_classes, num_features)
+ self.embed.weight.data.uniform_()
+
+ def forward(self, x, y):
+ out = self.bn(x)
+ if self.bias:
+ gamma, beta = self.embed(y).chunk(2, dim=1)
+ out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(
+ -1, self.num_features, 1, 1
+ )
+ else:
+ gamma = self.embed(y)
+ out = gamma.view(-1, self.num_features, 1, 1) * out
+ return out
+
+
+class ConditionalInstanceNorm2d(nn.Module):
+ def __init__(self, num_features, num_classes, bias=True):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.instance_norm = nn.InstanceNorm2d(
+ num_features, affine=False, track_running_stats=False
+ )
+ if bias:
+ self.embed = nn.Embedding(num_classes, num_features * 2)
+ self.embed.weight.data[
+ :, :num_features
+            ].uniform_()  # Initialise scale uniformly in [0, 1)
+ self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
+ else:
+ self.embed = nn.Embedding(num_classes, num_features)
+ self.embed.weight.data.uniform_()
+
+ def forward(self, x, y):
+ h = self.instance_norm(x)
+ if self.bias:
+ gamma, beta = self.embed(y).chunk(2, dim=-1)
+ out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(
+ -1, self.num_features, 1, 1
+ )
+ else:
+ gamma = self.embed(y)
+ out = gamma.view(-1, self.num_features, 1, 1) * h
+ return out
+
+
+class ConditionalVarianceNorm2d(nn.Module):
+ def __init__(self, num_features, num_classes, bias=False):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.embed = nn.Embedding(num_classes, num_features)
+ self.embed.weight.data.normal_(1, 0.02)
+
+ def forward(self, x, y):
+ vars = torch.var(x, dim=(2, 3), keepdim=True)
+ h = x / torch.sqrt(vars + 1e-5)
+
+ gamma = self.embed(y)
+ out = gamma.view(-1, self.num_features, 1, 1) * h
+ return out
+
+
+class VarianceNorm2d(nn.Module):
+ def __init__(self, num_features, bias=False):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.alpha = nn.Parameter(torch.zeros(num_features))
+ self.alpha.data.normal_(1, 0.02)
+
+ def forward(self, x):
+ vars = torch.var(x, dim=(2, 3), keepdim=True)
+ h = x / torch.sqrt(vars + 1e-5)
+
+ out = self.alpha.view(-1, self.num_features, 1, 1) * h
+ return out
+
+
+class ConditionalNoneNorm2d(nn.Module):
+ def __init__(self, num_features, num_classes, bias=True):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ if bias:
+ self.embed = nn.Embedding(num_classes, num_features * 2)
+ self.embed.weight.data[
+ :, :num_features
+            ].uniform_()  # Initialise scale uniformly in [0, 1)
+ self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
+ else:
+ self.embed = nn.Embedding(num_classes, num_features)
+ self.embed.weight.data.uniform_()
+
+ def forward(self, x, y):
+ if self.bias:
+ gamma, beta = self.embed(y).chunk(2, dim=-1)
+ out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(
+ -1, self.num_features, 1, 1
+ )
+ else:
+ gamma = self.embed(y)
+ out = gamma.view(-1, self.num_features, 1, 1) * x
+ return out
+
+
+class NoneNorm2d(nn.Module):
+ def __init__(self, num_features, bias=True):
+ super().__init__()
+
+ def forward(self, x):
+ return x
+
+
+class InstanceNorm2dPlus(nn.Module):
+ def __init__(self, num_features, bias=True):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.instance_norm = nn.InstanceNorm2d(
+ num_features, affine=False, track_running_stats=False
+ )
+ self.alpha = nn.Parameter(torch.zeros(num_features))
+ self.gamma = nn.Parameter(torch.zeros(num_features))
+ self.alpha.data.normal_(1, 0.02)
+ self.gamma.data.normal_(1, 0.02)
+ if bias:
+ self.beta = nn.Parameter(torch.zeros(num_features))
+
+ def forward(self, x):
+ means = torch.mean(x, dim=(2, 3))
+ m = torch.mean(means, dim=-1, keepdim=True)
+ v = torch.var(means, dim=-1, keepdim=True)
+ means = (means - m) / (torch.sqrt(v + 1e-5))
+ h = self.instance_norm(x)
+
+ if self.bias:
+ h = h + means[..., None, None] * self.alpha[..., None, None]
+ out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(
+ -1, self.num_features, 1, 1
+ )
+ else:
+ h = h + means[..., None, None] * self.alpha[..., None, None]
+ out = self.gamma.view(-1, self.num_features, 1, 1) * h
+ return out
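+
+# InstanceNorm++ (from NCSNv2): plain instance norm discards per-channel means,
+# so the standardized means are re-injected via the learned `alpha` before the
+# affine scale/shift, preserving information that plain instance norm loses.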
+
+
+class ConditionalInstanceNorm2dPlus(nn.Module):
+ def __init__(self, num_features, num_classes, bias=True):
+ super().__init__()
+ self.num_features = num_features
+ self.bias = bias
+ self.instance_norm = nn.InstanceNorm2d(
+ num_features, affine=False, track_running_stats=False
+ )
+ if bias:
+ self.embed = nn.Embedding(num_classes, num_features * 3)
+ self.embed.weight.data[:, : 2 * num_features].normal_(
+ 1, 0.02
+ ) # Initialise scale at N(1, 0.02)
+ self.embed.weight.data[
+ :, 2 * num_features :
+ ].zero_() # Initialise bias at 0
+ else:
+ self.embed = nn.Embedding(num_classes, 2 * num_features)
+ self.embed.weight.data.normal_(1, 0.02)
+
+ def forward(self, x, y):
+ means = torch.mean(x, dim=(2, 3))
+ m = torch.mean(means, dim=-1, keepdim=True)
+ v = torch.var(means, dim=-1, keepdim=True)
+ means = (means - m) / (torch.sqrt(v + 1e-5))
+ h = self.instance_norm(x)
+
+ if self.bias:
+ gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
+ h = h + means[..., None, None] * alpha[..., None, None]
+ out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(
+ -1, self.num_features, 1, 1
+ )
+ else:
+ gamma, alpha = self.embed(y).chunk(2, dim=-1)
+ h = h + means[..., None, None] * alpha[..., None, None]
+ out = gamma.view(-1, self.num_features, 1, 1) * h
+ return out
diff --git a/modules/sgmse/ncsnpp_utils/op/__init__.py b/modules/sgmse/ncsnpp_utils/op/__init__.py
new file mode 100644
index 00000000..d0918d92
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/__init__.py
@@ -0,0 +1,2 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+from .upfirdn2d import upfirdn2d
diff --git a/modules/sgmse/ncsnpp_utils/op/fused_act.py b/modules/sgmse/ncsnpp_utils/op/fused_act.py
new file mode 100644
index 00000000..9f6cd311
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/fused_act.py
@@ -0,0 +1,97 @@
+import os
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+
+module_path = os.path.dirname(__file__)
+fused = load(
+ "fused",
+ sources=[
+ os.path.join(module_path, "fused_bias_act.cpp"),
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
+ ],
+)
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = fused.fused_bias_act(
+ grad_output, empty, out, 3, 1, negative_slope, scale
+ )
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ (out,) = ctx.saved_tensors
+ gradgrad_out = fused.fused_bias_act(
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+ )
+
+ return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ (out,) = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+ grad_output, out, ctx.negative_slope, ctx.scale
+ )
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+ if input.device.type == "cpu":
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
+        return (
+            F.leaky_relu(
+                input + bias.view(1, bias.shape[0], *rest_dim),
+                negative_slope=negative_slope,
+            )
+            * scale
+        )
+
+ else:
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
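+
+
+# Behavior sketch (assumed shapes): for input of shape (B, C, ...) and a
+# per-channel bias of shape (C,), fused_leaky_relu(input, bias) computes
+# leaky_relu(input + bias, negative_slope) * scale -- on GPU via the compiled
+# extension, on CPU via the pure-PyTorch fallback above.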
diff --git a/modules/sgmse/ncsnpp_utils/op/fused_bias_act.cpp b/modules/sgmse/ncsnpp_utils/op/fused_bias_act.cpp
new file mode 100644
index 00000000..a0543187
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/fused_bias_act.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(bias);
+
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
\ No newline at end of file
diff --git a/modules/sgmse/ncsnpp_utils/op/fused_bias_act_kernel.cu b/modules/sgmse/ncsnpp_utils/op/fused_bias_act_kernel.cu
new file mode 100644
index 00000000..8d2f03c7
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/fused_bias_act_kernel.cu
@@ -0,0 +1,99 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+ scalar_t zero = 0.0;
+
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+ scalar_t x = p_x[xi];
+
+ if (use_bias) {
+ x += p_b[(xi / step_b) % size_b];
+ }
+
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+ scalar_t y;
+
+ switch (act * 10 + grad) {
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0; break;
+
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
+ case 32: y = 0.0; break;
+ }
+
+ out[xi] = y * scale;
+ }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ auto x = input.contiguous();
+ auto b = bias.contiguous();
+ auto ref = refer.contiguous();
+
+ int use_bias = b.numel() ? 1 : 0;
+ int use_ref = ref.numel() ? 1 : 0;
+
+ int size_x = x.numel();
+ int size_b = b.numel();
+ int step_b = 1;
+
+    for (int i = 2; i < x.dim(); i++) {  // dims after the channel dim
+ step_b *= x.size(i);
+ }
+
+ int loop_x = 4;
+ int block_size = 4 * 32;
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+ auto y = torch::empty_like(x);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            y.data_ptr<scalar_t>(),
+            x.data_ptr<scalar_t>(),
+            b.data_ptr<scalar_t>(),
+            ref.data_ptr<scalar_t>(),
+ act,
+ grad,
+ alpha,
+ scale,
+ loop_x,
+ size_x,
+ step_b,
+ size_b,
+ use_bias,
+ use_ref
+ );
+ });
+
+ return y;
+}
\ No newline at end of file
diff --git a/modules/sgmse/ncsnpp_utils/op/upfirdn2d.cpp b/modules/sgmse/ncsnpp_utils/op/upfirdn2d.cpp
new file mode 100644
index 00000000..b07aa205
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/upfirdn2d.cpp
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(kernel);
+
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
\ No newline at end of file
diff --git a/modules/sgmse/ncsnpp_utils/op/upfirdn2d.py b/modules/sgmse/ncsnpp_utils/op/upfirdn2d.py
new file mode 100644
index 00000000..e18039c9
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/upfirdn2d.py
@@ -0,0 +1,200 @@
+import os
+
+import torch
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+
+module_path = os.path.dirname(__file__)
+upfirdn2d_op = load(
+ "upfirdn2d",
+ sources=[
+ os.path.join(module_path, "upfirdn2d.cpp"),
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
+ ],
+)
+
+
+class UpFirDn2dBackward(Function):
+ @staticmethod
+ def forward(
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
+ ):
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_op.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ down_x,
+ down_y,
+ up_x,
+ up_y,
+ g_pad_x0,
+ g_pad_x1,
+ g_pad_y0,
+ g_pad_y1,
+ )
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ (kernel,) = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ ctx.up_x,
+ ctx.up_y,
+ ctx.down_x,
+ ctx.down_y,
+ ctx.pad_x0,
+ ctx.pad_x1,
+ ctx.pad_y0,
+ ctx.pad_y1,
+ )
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
+ )
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ batch, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_op.upfirdn2d(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+ )
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ if input.device.type == "cpu":
+ out = upfirdn2d_native(
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
+ )
+
+ else:
+ out = UpFirDn2d.apply(
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+ )
+
+ return out
+
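+# A minimal usage sketch (assumes the CUDA extension built; names and values
+# here are illustrative only): 2x upsampling with a [1, 3, 3, 1] FIR filter.
+#
+#   x = torch.randn(1, 3, 64, 64, device="cuda")
+#   k = torch.tensor([1.0, 3.0, 3.0, 1.0])
+#   k = (torch.outer(k, k) / k.sum() ** 2 * 2**2).cuda()  # gain * factor**2
+#   y = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))  # -> (1, 3, 128, 128)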
+
+def upfirdn2d_native(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+):
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+ )
+ out = out[
+ :,
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape(
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+ )
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
diff --git a/modules/sgmse/ncsnpp_utils/op/upfirdn2d_kernel.cu b/modules/sgmse/ncsnpp_utils/op/upfirdn2d_kernel.cu
new file mode 100644
index 00000000..ed3eea30
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/op/upfirdn2d_kernel.cu
@@ -0,0 +1,369 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+ int c = a / b;
+
+ if (c * b > a) {
+ c--;
+ }
+
+ return c;
+}
+
+struct UpFirDn2DKernelParams {
+ int up_x;
+ int up_y;
+ int down_x;
+ int down_y;
+ int pad_x0;
+ int pad_x1;
+ int pad_y0;
+ int pad_y1;
+
+ int major_dim;
+ int in_h;
+ int in_w;
+ int minor_dim;
+ int kernel_h;
+ int kernel_w;
+ int out_h;
+ int out_w;
+ int loop_major;
+ int loop_x;
+};
+
+template <typename scalar_t>
+__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ int out_y = minor_idx / p.minor_dim;
+ minor_idx -= out_y * p.minor_dim;
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major && major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, out_x = out_x_base;
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+
+ const scalar_t *x_p =
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+ minor_idx];
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+ int x_px = p.minor_dim;
+ int k_px = -p.up_x;
+ int x_py = p.in_w * p.minor_dim;
+ int k_py = -p.up_y * p.kernel_w;
+
+ scalar_t v = 0.0f;
+
+ for (int y = 0; y < h; y++) {
+ for (int x = 0; x < w; x++) {
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
+ x_p += x_px;
+ k_p += k_px;
+ }
+
+ x_p += x_py - w * x_px;
+ k_p += k_py - w * k_px;
+ }
+
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+}
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+ __shared__ volatile float sk[kernel_h][kernel_w];
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+ int minor_idx = blockIdx.x;
+ int tile_out_y = minor_idx / p.minor_dim;
+ minor_idx -= tile_out_y * p.minor_dim;
+ tile_out_y *= tile_out_h;
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
+ tap_idx += blockDim.x) {
+ int ky = tap_idx / kernel_w;
+ int kx = tap_idx - ky * kernel_w;
+ scalar_t v = 0.0;
+
+ if (kx < p.kernel_w & ky < p.kernel_h) {
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+ }
+
+ sk[ky][kx] = v;
+ }
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major & major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
+ loop_x < p.loop_x & tile_out_x < p.out_w;
+ loop_x++, tile_out_x += tile_out_w) {
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+ int tile_in_x = floor_div(tile_mid_x, up_x);
+ int tile_in_y = floor_div(tile_mid_y, up_y);
+
+ __syncthreads();
+
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
+ in_idx += blockDim.x) {
+ int rel_in_y = in_idx / tile_in_w;
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
+ int in_x = rel_in_x + tile_in_x;
+ int in_y = rel_in_y + tile_in_y;
+
+ scalar_t v = 0.0;
+
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
+ p.minor_dim +
+ minor_idx];
+ }
+
+ sx[rel_in_y][rel_in_x] = v;
+ }
+
+ __syncthreads();
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
+ out_idx += blockDim.x) {
+ int rel_out_y = out_idx / tile_out_w;
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
+ int out_x = rel_out_x + tile_out_x;
+ int out_y = rel_out_y + tile_out_y;
+
+ int mid_x = tile_mid_x + rel_out_x * down_x;
+ int mid_y = tile_mid_y + rel_out_y * down_y;
+ int in_x = floor_div(mid_x, up_x);
+ int in_y = floor_div(mid_y, up_y);
+ int rel_in_x = in_x - tile_in_x;
+ int rel_in_y = in_y - tile_in_y;
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+ scalar_t v = 0.0;
+
+#pragma unroll
+ for (int y = 0; y < kernel_h / up_y; y++)
+#pragma unroll
+ for (int x = 0; x < kernel_w / up_x; x++)
+ v += sx[rel_in_y + y][rel_in_x + x] *
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+ if (out_x < p.out_w & out_y < p.out_h) {
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+ }
+ }
+}
+
+torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+ const torch::Tensor &kernel, int up_x, int up_y,
+ int down_x, int down_y, int pad_x0, int pad_x1,
+ int pad_y0, int pad_y1) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ UpFirDn2DKernelParams p;
+
+ auto x = input.contiguous();
+ auto k = kernel.contiguous();
+
+ p.major_dim = x.size(0);
+ p.in_h = x.size(1);
+ p.in_w = x.size(2);
+ p.minor_dim = x.size(3);
+ p.kernel_h = k.size(0);
+ p.kernel_w = k.size(1);
+ p.up_x = up_x;
+ p.up_y = up_y;
+ p.down_x = down_x;
+ p.down_y = down_y;
+ p.pad_x0 = pad_x0;
+ p.pad_x1 = pad_x1;
+ p.pad_y0 = pad_y0;
+ p.pad_y1 = pad_y1;
+
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
+ p.down_y;
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
+ p.down_x;
+
+ auto out =
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+ int mode = -1;
+
+ int tile_out_h = -1;
+ int tile_out_w = -1;
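+
+ // Pick a specialized tiled kernel for the common (up, down, kernel-size)
+ // combinations below; mode -1 falls back to the generic large kernel.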
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 1;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
+ mode = 2;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 3;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 4;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 5;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 6;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ dim3 block_size;
+ dim3 grid_size;
+
+ if (tile_out_h > 0 && tile_out_w > 0) {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 1;
+ block_size = dim3(32 * 8, 1, 1);
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ } else {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 4;
+ block_size = dim3(4, 32, 1);
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+ switch (mode) {
+ case 1:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 2:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 3:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 4:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 5:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 6:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ default:
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+ }
+ });
+
+ return out;
+}
\ No newline at end of file
diff --git a/modules/sgmse/ncsnpp_utils/up_or_down_sampling.py b/modules/sgmse/ncsnpp_utils/up_or_down_sampling.py
new file mode 100644
index 00000000..5d59071f
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/up_or_down_sampling.py
@@ -0,0 +1,273 @@
+"""Layers used for up-sampling or down-sampling images.
+
+Many functions are ported from https://github.com/NVlabs/stylegan2.
+"""
+
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+import numpy as np
+from .op import upfirdn2d
+
+
+# Function ported from StyleGAN2
+def get_weight(module, shape, weight_var="weight", kernel_init=None):
+ """Get/create weight tensor for a convolution or fully-connected layer."""
+
+ return module.param(weight_var, kernel_init, shape)
+
+
+class Conv2d(nn.Module):
+ """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
+
+ def __init__(
+ self,
+ in_ch,
+ out_ch,
+ kernel,
+ up=False,
+ down=False,
+ resample_kernel=(1, 3, 3, 1),
+ use_bias=True,
+ kernel_init=None,
+ ):
+ super().__init__()
+ assert not (up and down)
+ assert kernel >= 1 and kernel % 2 == 1
+ self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
+ if kernel_init is not None:
+ self.weight.data = kernel_init(self.weight.data.shape)
+ if use_bias:
+ self.bias = nn.Parameter(torch.zeros(out_ch))
+
+ self.up = up
+ self.down = down
+ self.resample_kernel = resample_kernel
+ self.kernel = kernel
+ self.use_bias = use_bias
+
+ def forward(self, x):
+ if self.up:
+ x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
+ elif self.down:
+ x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
+ else:
+ x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
+
+ if self.use_bias:
+ x = x + self.bias.reshape(1, -1, 1, 1)
+
+ return x
+
+
+def naive_upsample_2d(x, factor=2):
+ _N, C, H, W = x.shape
+ x = torch.reshape(x, (-1, C, H, 1, W, 1))
+ x = x.repeat(1, 1, 1, factor, 1, factor)
+ return torch.reshape(x, (-1, C, H * factor, W * factor))
+
+
+def naive_downsample_2d(x, factor=2):
+ _N, C, H, W = x.shape
+ x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
+ return torch.mean(x, dim=(3, 5))
+
+
+def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
+ """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
+
+ Padding is performed only once at the beginning, not between the
+ operations.
+ The fused op is considerably more efficient than performing the same
+ calculation
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+ C]`.
+ w: Weight tensor of the shape `[filterH, filterW, inChannels,
+ outChannels]`. Grouped convolution can be performed by `inChannels =
+ x.shape[0] // numGroups`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to
+ nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]` or
+ `[N, H * factor, W * factor, C]`, and same datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+
+ # Check weight shape.
+ assert len(w.shape) == 4
+ convH = w.shape[2]
+ convW = w.shape[3]
+ inC = w.shape[1]
+ outC = w.shape[0]
+
+ assert convW == convH
+
+ # Setup filter kernel.
+ if k is None:
+ k = [1] * factor
+ k = _setup_kernel(k) * (gain * (factor**2))
+ p = (k.shape[0] - factor) - (convW - 1)
+
+ # Determine data dimensions (NCHW). The TF original used a 4-element stride
+ # [1, 1, factor, factor]; PyTorch's conv_transpose2d expects a 2-element one.
+ stride = (factor, factor)
+ output_shape = (
+ (_shape(x, 2) - 1) * factor + convH,
+ (_shape(x, 3) - 1) * factor + convW,
+ )
+ output_padding = (
+ output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
+ output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
+ )
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
+ num_groups = _shape(x, 1) // inC
+
+ # Transpose weights.
+ w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
+ # Torch tensors do not support negative-step slicing; flip explicitly instead.
+ w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4)
+ w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
+
+ x = F.conv_transpose2d(
+ x, w, stride=stride, output_padding=output_padding, padding=0
+ )
+ ## Original TF code.
+ # x = tf.nn.conv2d_transpose(
+ # x,
+ # w,
+ # output_shape=output_shape,
+ # strides=stride,
+ # padding='VALID',
+ # data_format=data_format)
+
+ return upfirdn2d(
+ x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)
+ )
+
+
+def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
+ """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
+
+ Padding is performed only once at the beginning, not between the operations.
+ The fused op is considerably more efficient than performing the same
+ calculation
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+ C]`.
+ w: Weight tensor of the shape `[filterH, filterW, inChannels,
+ outChannels]`. Grouped convolution can be performed by `inChannels =
+ x.shape[0] // numGroups`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to
+ average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]` or
+ `[N, H // factor, W // factor, C]`, and same datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ _outC, _inC, convH, convW = w.shape
+ assert convW == convH
+ if k is None:
+ k = [1] * factor
+ k = _setup_kernel(k) * gain
+ p = (k.shape[0] - factor) + (convW - 1)
+ s = [factor, factor]
+ x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
+ return F.conv2d(x, w, stride=s, padding=0)
+
+
+def _setup_kernel(k):
+ k = np.asarray(k, dtype=np.float32)
+ if k.ndim == 1:
+ k = np.outer(k, k)
+ k /= np.sum(k)
+ assert k.ndim == 2
+ assert k.shape[0] == k.shape[1]
+ return k
+
+
+def _shape(x, dim):
+ return x.shape[dim]
+
+
+def upsample_2d(x, k=None, factor=2, gain=1):
+ r"""Upsample a batch of 2D images with the given filter.
+
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
+ and upsamples each image with the given filter. The filter is normalized so
+ that
+ if the input pixels are constant, they will be scaled by the specified
+ `gain`.
+ Pixels outside the image are assumed to be zero, and the filter is padded
+ with
+ zeros so that its shape is a multiple of the upsampling factor.
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+ C]`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to
+ nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]`
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if k is None:
+ k = [1] * factor
+ k = _setup_kernel(k) * (gain * (factor**2))
+ p = k.shape[0] - factor
+ return upfirdn2d(
+ x,
+ torch.tensor(k, device=x.device),
+ up=factor,
+ pad=((p + 1) // 2 + factor - 1, p // 2),
+ )
+
+
+def downsample_2d(x, k=None, factor=2, gain=1):
+ r"""Downsample a batch of 2D images with the given filter.
+
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
+ and downsamples each image with the given filter. The filter is normalized
+ so that
+ if the input pixels are constant, they will be scaled by the specified
+ `gain`.
+ Pixels outside the image are assumed to be zero, and the filter is padded
+ with
+ zeros so that its shape is a multiple of the downsampling factor.
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+ C]`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to
+ average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]`
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if k is None:
+ k = [1] * factor
+ k = _setup_kernel(k) * gain
+ p = k.shape[0] - factor
+ return upfirdn2d(
+ x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)
+ )
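+
+
+# Minimal sketch (illustrative): with the default box filter these reduce to
+# nearest-neighbor upsampling and average pooling, respectively.
+#
+#   x = torch.randn(2, 4, 32, 32, device="cuda")
+#   upsample_2d(x, factor=2).shape    # torch.Size([2, 4, 64, 64])
+#   downsample_2d(x, factor=2).shape  # torch.Size([2, 4, 16, 16])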
diff --git a/modules/sgmse/ncsnpp_utils/utils.py b/modules/sgmse/ncsnpp_utils/utils.py
new file mode 100644
index 00000000..38333da6
--- /dev/null
+++ b/modules/sgmse/ncsnpp_utils/utils.py
@@ -0,0 +1,192 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""All functions and modules related to model definition.
+"""
+
+import torch
+
+import numpy as np
+from ..sdes import OUVESDE, OUVPSDE
+
+
+_MODELS = {}
+
+
+def register_model(cls=None, *, name=None):
+ """A decorator for registering model classes."""
+
+ def _register(cls):
+ if name is None:
+ local_name = cls.__name__
+ else:
+ local_name = name
+ if local_name in _MODELS:
+ raise ValueError(f"Already registered model with name: {local_name}")
+ _MODELS[local_name] = cls
+ return cls
+
+ if cls is None:
+ return _register
+ else:
+ return _register(cls)
+
+
+def get_model(name):
+ return _MODELS[name]
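+
+
+# Usage sketch (hypothetical model name):
+#
+#   @register_model(name="my_score_net")
+#   class MyScoreNet(torch.nn.Module): ...
+#
+#   model_cls = get_model("my_score_net")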
+
+
+def get_sigmas(sigma_min, sigma_max, num_scales):
+ """Get sigmas --- the set of noise levels for SMLD.
+
+ Args:
+ sigma_min: smallest sigma.
+ sigma_max: largest sigma.
+ num_scales: number of noise levels.
+
+ Returns:
+ sigmas: a numpy array of noise levels, geometrically spaced from sigma_max down to sigma_min.
+ """
+ sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
+
+ return sigmas
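+
+
+# Example: get_sigmas(0.01, 50.0, 3) -> array([50.0, 0.7071, 0.01]), a
+# geometric sequence whose middle value is sqrt(50.0 * 0.01).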
+
+
+def get_ddpm_params(config):
+ """Get betas and alphas --- parameters used in the original DDPM paper."""
+ num_diffusion_timesteps = 1000
+ # parameters need to be adapted if number of time steps differs from 1000
+ beta_start = config.model.beta_min / config.model.num_scales
+ beta_end = config.model.beta_max / config.model.num_scales
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
+ sqrt_1m_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
+
+ return {
+ "betas": betas,
+ "alphas": alphas,
+ "alphas_cumprod": alphas_cumprod,
+ "sqrt_alphas_cumprod": sqrt_alphas_cumprod,
+ "sqrt_1m_alphas_cumprod": sqrt_1m_alphas_cumprod,
+ "beta_min": beta_start * (num_diffusion_timesteps - 1),
+ "beta_max": beta_end * (num_diffusion_timesteps - 1),
+ "num_diffusion_timesteps": num_diffusion_timesteps,
+ }
+
+
+def create_model(config):
+ """Create the score model."""
+ model_name = config.model.name
+ score_model = get_model(model_name)(config)
+ score_model = score_model.to(config.device)
+ score_model = torch.nn.DataParallel(score_model)
+ return score_model
+
+
+def get_model_fn(model, train=False):
+ """Create a function to give the output of the score-based model.
+
+ Args:
+ model: The score model.
+ train: `True` for training and `False` for evaluation.
+
+ Returns:
+ A model function.
+ """
+
+ def model_fn(x, labels):
+ """Compute the output of the score-based model.
+
+ Args:
+ x: A mini-batch of input data.
+ labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
+ for different models.
+
+ Returns:
+ The model output.
+ """
+ if not train:
+ model.eval()
+ return model(x, labels)
+ else:
+ model.train()
+ return model(x, labels)
+
+ return model_fn
+
+
+def get_score_fn(sde, model, train=False, continuous=False):
+ """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
+
+ Args:
+ sde: An `sde_lib.SDE` object that represents the forward SDE.
+ model: A score model.
+ train: `True` for training and `False` for evaluation.
+ continuous: If `True`, the score-based model is expected to directly take continuous time steps.
+
+ Returns:
+ A score function.
+ """
+ model_fn = get_model_fn(model, train=train)
+
+ if isinstance(sde, OUVPSDE):
+
+ def score_fn(x, t):
+ # Scale neural network output by standard deviation and flip sign
+ if continuous:
+ # For VP-trained models, t=0 corresponds to the lowest noise level
+ # The maximum value of the time embedding is assumed to be 999 for
+ # continuously-trained models.
+ labels = t * 999
+ score = model_fn(x, labels)
+ std = sde.marginal_prob(torch.zeros_like(x), t)[1]
+ else:
+ # For VP-trained models, t=0 corresponds to the lowest noise level
+ labels = t * (sde.N - 1)
+ score = model_fn(x, labels)
+ std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
+
+ score = -score / std[:, None, None, None]
+ return score
+
+ elif isinstance(sde, OUVESDE):
+
+ def score_fn(x, t):
+ if continuous:
+ labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
+ else:
+ # For VE-trained models, t=0 corresponds to the highest noise level
+ labels = sde.T - t
+ labels *= sde.N - 1
+ labels = torch.round(labels).long()
+
+ score = model_fn(x, labels)
+ return score
+
+ else:
+ raise NotImplementedError(
+ f"SDE class {sde.__class__.__name__} not yet supported."
+ )
+
+ return score_fn
+
+
+def to_flattened_numpy(x):
+ """Flatten a torch tensor `x` and convert it to numpy."""
+ return x.detach().cpu().numpy().reshape((-1,))
+
+
+def from_flattened_numpy(x, shape):
+ """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
+ return torch.from_numpy(x.reshape(shape))
diff --git a/modules/sgmse/sampling/__init__.py b/modules/sgmse/sampling/__init__.py
new file mode 100644
index 00000000..3248dc62
--- /dev/null
+++ b/modules/sgmse/sampling/__init__.py
@@ -0,0 +1,169 @@
+# Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py
+"""Various sampling methods."""
+from scipy import integrate
+import torch
+
+from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
+from .correctors import Corrector, CorrectorRegistry
+
+
+__all__ = [
+ "PredictorRegistry",
+ "CorrectorRegistry",
+ "Predictor",
+ "Corrector",
+ "get_sampler",
+]
+
+
+def to_flattened_numpy(x):
+ """Flatten a torch tensor `x` and convert it to numpy."""
+ return x.detach().cpu().numpy().reshape((-1,))
+
+
+def from_flattened_numpy(x, shape):
+ """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
+ return torch.from_numpy(x.reshape(shape))
+
+
+def get_pc_sampler(
+ predictor_name,
+ corrector_name,
+ sde,
+ score_fn,
+ y,
+ denoise=True,
+ eps=3e-2,
+ snr=0.1,
+ corrector_steps=1,
+ probability_flow: bool = False,
+ intermediate=False,
+ **kwargs
+):
+ """Create a Predictor-Corrector (PC) sampler.
+
+ Args:
+ predictor_name: The name of a registered `sampling.Predictor`.
+ corrector_name: The name of a registered `sampling.Corrector`.
+ sde: An `sdes.SDE` object representing the forward SDE.
+ score_fn: A function (typically learned model) that predicts the score.
+ y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
+ denoise: If `True`, add one-step denoising to the final samples.
+ eps: A `float` number. The reverse-time SDE and ODE are integrated to `eps` to avoid numerical issues.
+ snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
+ corrector_steps: The number of corrector steps per predictor update. 1 by default.
+ probability_flow: If `True`, use the probability flow formulation for the predictor. `False` by default.
+
+ Returns:
+ A sampling function that returns samples and the number of function evaluations during sampling.
+ """
+ predictor_cls = PredictorRegistry.get_by_name(predictor_name)
+ corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
+ predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
+ corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)
+
+ def pc_sampler():
+ """The PC sampler function."""
+ with torch.no_grad():
+ xt = sde.prior_sampling(y.shape, y).to(y.device)
+ timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
+ for i in range(sde.N):
+ t = timesteps[i]
+ vec_t = torch.ones(y.shape[0], device=y.device) * t
+ xt, xt_mean = corrector.update_fn(xt, vec_t, y)
+ xt, xt_mean = predictor.update_fn(xt, vec_t, y)
+ x_result = xt_mean if denoise else xt
+ ns = sde.N * (corrector.n_steps + 1)
+ return x_result, ns
+
+ return pc_sampler
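+
+
+# Minimal usage sketch (names illustrative; `sde`, `score_fn`, and the complex
+# spectrogram `y` come from the enhancement model):
+#
+#   sampler = get_pc_sampler("reverse_diffusion", "ald", sde, score_fn, y,
+#                            snr=0.5, corrector_steps=1)
+#   sample, nfe = sampler()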
+
+
+def get_ode_sampler(
+ sde,
+ score_fn,
+ y,
+ inverse_scaler=None,
+ denoise=True,
+ rtol=1e-5,
+ atol=1e-5,
+ method="RK45",
+ eps=3e-2,
+ device="cuda",
+ **kwargs
+):
+ """Probability flow ODE sampler with the black-box ODE solver.
+
+ Args:
+ sde: An `sdes.SDE` object representing the forward SDE.
+ score_fn: A function (typically learned model) that predicts the score.
+ y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
+ inverse_scaler: The inverse data normalizer.
+ denoise: If `True`, add one-step denoising to final samples.
+ rtol: A `float` number. The relative tolerance level of the ODE solver.
+ atol: A `float` number. The absolute tolerance level of the ODE solver.
+ method: A `str`. The algorithm used for the black-box ODE solver.
+ See the documentation of `scipy.integrate.solve_ivp`.
+ eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
+ device: PyTorch device.
+
+ Returns:
+ A sampling function that returns samples and the number of function evaluations during sampling.
+ """
+ predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
+ rsde = sde.reverse(score_fn, probability_flow=True)
+
+ def denoise_update_fn(x):
+ vec_eps = torch.ones(x.shape[0], device=x.device) * eps
+ _, x = predictor.update_fn(x, vec_eps, y)
+ return x
+
+ def drift_fn(x, t, y):
+ """Get the drift function of the reverse-time SDE."""
+ return rsde.sde(x, t, y)[0]
+
+ def ode_sampler(z=None, **kwargs):
+ """The probability flow ODE sampler with black-box ODE solver.
+
+ Args:
+ model: A score model.
+ z: If present, generate samples from latent code `z`.
+ Returns:
+ samples, number of function evaluations.
+ """
+ with torch.no_grad():
+ # If z is not provided, sample the latent code from the prior distribution of the SDE.
+ x = sde.prior_sampling(y.shape, y).to(device)
+
+ def ode_func(t, x):
+ x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64)
+ vec_t = torch.ones(y.shape[0], device=x.device) * t
+ drift = drift_fn(x, vec_t, y)
+ return to_flattened_numpy(drift)
+
+ # Black-box ODE solver for the probability flow ODE
+ solution = integrate.solve_ivp(
+ ode_func,
+ (sde.T, eps),
+ to_flattened_numpy(x),
+ rtol=rtol,
+ atol=atol,
+ method=method,
+ **kwargs
+ )
+ nfe = solution.nfev
+ x = (
+ torch.tensor(solution.y[:, -1])
+ .reshape(y.shape)
+ .to(device)
+ .type(torch.complex64)
+ )
+
+ # Denoising is equivalent to running one predictor step without adding noise
+ if denoise:
+ x = denoise_update_fn(x)
+
+ if inverse_scaler is not None:
+ x = inverse_scaler(x)
+ return x, nfe
+
+ return ode_sampler
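+
+
+# Usage mirrors the PC sampler (sketch):
+#
+#   sampler = get_ode_sampler(sde, score_fn, y, rtol=1e-5, atol=1e-5)
+#   sample, nfe = sampler()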
diff --git a/modules/sgmse/sampling/correctors.py b/modules/sgmse/sampling/correctors.py
new file mode 100644
index 00000000..4b995f4b
--- /dev/null
+++ b/modules/sgmse/sampling/correctors.py
@@ -0,0 +1,99 @@
+import abc
+import torch
+
+from modules.sgmse import sdes
+from utils.sgmse_util.registry import Registry
+
+
+CorrectorRegistry = Registry("Corrector")
+
+
+class Corrector(abc.ABC):
+ """The abstract class for a corrector algorithm."""
+
+ def __init__(self, sde, score_fn, snr, n_steps):
+ super().__init__()
+ self.rsde = sde.reverse(score_fn)
+ self.score_fn = score_fn
+ self.snr = snr
+ self.n_steps = n_steps
+
+ @abc.abstractmethod
+ def update_fn(self, x, t, *args):
+ """One update of the corrector.
+
+ Args:
+ x: A PyTorch tensor representing the current state
+ t: A PyTorch tensor representing the current time step.
+ *args: Possibly additional arguments, in particular `y` for OU processes
+
+ Returns:
+ x: A PyTorch tensor of the next state.
+ x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
+ """
+ pass
+
+
+@CorrectorRegistry.register(name="langevin")
+class LangevinCorrector(Corrector):
+ def __init__(self, sde, score_fn, snr, n_steps):
+ super().__init__(sde, score_fn, snr, n_steps)
+ self.score_fn = score_fn
+ self.n_steps = n_steps
+ self.snr = snr
+
+ def update_fn(self, x, t, *args):
+ target_snr = self.snr
+ for _ in range(self.n_steps):
+ grad = self.score_fn(x, t, *args)
+ noise = torch.randn_like(x)
+ grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
+ noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
+ step_size = ((target_snr * noise_norm / grad_norm) ** 2 * 2).unsqueeze(0)
+ x_mean = x + step_size[:, None, None, None] * grad
+ x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
+
+ return x, x_mean
+
+
+@CorrectorRegistry.register(name="ald")
+class AnnealedLangevinDynamics(Corrector):
+ """The original annealed Langevin dynamics predictor in NCSN/NCSNv2."""
+
+ def __init__(self, sde, score_fn, snr, n_steps):
+ super().__init__(sde, score_fn, snr, n_steps)
+ if not isinstance(sde, (sdes.OUVESDE,)):
+ raise NotImplementedError(
+ f"SDE class {sde.__class__.__name__} not yet supported."
+ )
+ self.sde = sde
+ self.score_fn = score_fn
+ self.snr = snr
+ self.n_steps = n_steps
+
+ def update_fn(self, x, t, *args):
+ n_steps = self.n_steps
+ target_snr = self.snr
+ std = self.sde.marginal_prob(x, t, *args)[1]
+
+ for _ in range(n_steps):
+ grad = self.score_fn(x, t, *args)
+ noise = torch.randn_like(x)
+ step_size = (target_snr * std) ** 2 * 2
+ x_mean = x + step_size[:, None, None, None] * grad
+ x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
+
+ return x, x_mean
+
+
+@CorrectorRegistry.register(name="none")
+class NoneCorrector(Corrector):
+ """An empty corrector that does nothing."""
+
+ def __init__(self, *args, **kwargs):
+ self.snr = 0
+ self.n_steps = 0
+
+ def update_fn(self, x, t, *args):
+ return x, x
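+
+
+# New correctors can be added via the registry (sketch, hypothetical name):
+#
+#   @CorrectorRegistry.register(name="my_corrector")
+#   class MyCorrector(Corrector):
+#       def update_fn(self, x, t, *args):
+#           return x, x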
diff --git a/modules/sgmse/sampling/predictors.py b/modules/sgmse/sampling/predictors.py
new file mode 100644
index 00000000..84723f18
--- /dev/null
+++ b/modules/sgmse/sampling/predictors.py
@@ -0,0 +1,78 @@
+import abc
+
+import torch
+import numpy as np
+
+from utils.sgmse_util.registry import Registry
+
+
+PredictorRegistry = Registry("Predictor")
+
+
+class Predictor(abc.ABC):
+ """The abstract class for a predictor algorithm."""
+
+ def __init__(self, sde, score_fn, probability_flow=False):
+ super().__init__()
+ self.sde = sde
+ self.rsde = sde.reverse(score_fn)
+ self.score_fn = score_fn
+ self.probability_flow = probability_flow
+
+ @abc.abstractmethod
+ def update_fn(self, x, t, *args):
+ """One update of the predictor.
+
+ Args:
+ x: A PyTorch tensor representing the current state
+ t: A Pytorch tensor representing the current time step.
+ *args: Possibly additional arguments, in particular `y` for OU processes
+
+ Returns:
+ x: A PyTorch tensor of the next state.
+ x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
+ """
+ pass
+
+ def debug_update_fn(self, x, t, *args):
+ raise NotImplementedError(
+ f"Debug update function not implemented for predictor {self}."
+ )
+
+
+@PredictorRegistry.register("euler_maruyama")
+class EulerMaruyamaPredictor(Predictor):
+ def __init__(self, sde, score_fn, probability_flow=False):
+ super().__init__(sde, score_fn, probability_flow=probability_flow)
+
+ def update_fn(self, x, t, *args):
+ dt = -1.0 / self.rsde.N
+ z = torch.randn_like(x)
+ f, g = self.rsde.sde(x, t, *args)
+ x_mean = x + f * dt
+ x = x_mean + g[:, None, None, None] * np.sqrt(-dt) * z
+ return x, x_mean
+
+
+@PredictorRegistry.register("reverse_diffusion")
+class ReverseDiffusionPredictor(Predictor):
+ def __init__(self, sde, score_fn, probability_flow=False):
+ super().__init__(sde, score_fn, probability_flow=probability_flow)
+
+ def update_fn(self, x, t, *args):
+ f, g = self.rsde.discretize(x, t, *args)
+ z = torch.randn_like(x)
+ x_mean = x - f
+ x = x_mean + g[:, None, None, None] * z
+ return x, x_mean
+
+
+@PredictorRegistry.register("none")
+class NonePredictor(Predictor):
+ """An empty predictor that does nothing."""
+
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def update_fn(self, x, t, *args):
+ return x, x
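+
+
+# Predictors are resolved by name at sampling time (sketch):
+#
+#   predictor_cls = PredictorRegistry.get_by_name("reverse_diffusion")
+#   predictor = predictor_cls(sde, score_fn)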
diff --git a/modules/sgmse/sdes.py b/modules/sgmse/sdes.py
new file mode 100644
index 00000000..441311bb
--- /dev/null
+++ b/modules/sgmse/sdes.py
@@ -0,0 +1,360 @@
+"""
+Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
+
+Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
+"""
+
+import abc
+import warnings
+
+import numpy as np
+from utils.sgmse_util.tensors import batch_broadcast
+import torch
+
+from utils.sgmse_util.registry import Registry
+
+
+SDERegistry = Registry("SDE")
+
+
+class SDE(abc.ABC):
+ """SDE abstract class. Functions are designed for a mini-batch of inputs."""
+
+ def __init__(self, N):
+ """Construct an SDE.
+
+ Args:
+ N: number of discretization time steps.
+ """
+ super().__init__()
+ self.N = N
+
+ @property
+ @abc.abstractmethod
+ def T(self):
+ """End time of the SDE."""
+ pass
+
+ @abc.abstractmethod
+ def sde(self, x, t, *args):
+ pass
+
+ @abc.abstractmethod
+ def marginal_prob(self, x, t, *args):
+ """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
+ pass
+
+ @abc.abstractmethod
+ def prior_sampling(self, shape, *args):
+ """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
+ pass
+
+ @abc.abstractmethod
+ def prior_logp(self, z):
+ """Compute log-density of the prior distribution.
+
+ Useful for computing the log-likelihood via probability flow ODE.
+
+ Args:
+ z: latent code
+ Returns:
+ log probability density
+ """
+ pass
+
+ @staticmethod
+ @abc.abstractmethod
+ def add_argparse_args(parent_parser):
+ """
+ Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
+ """
+ pass
+
+ def discretize(self, x, t, *args):
+ """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
+
+ Useful for reverse diffusion sampling and probability flow sampling.
+ Defaults to Euler-Maruyama discretization.
+
+ Args:
+ x: a torch tensor
+ t: a torch float representing the time step (from 0 to `self.T`)
+
+ Returns:
+ f, G
+ """
+ dt = 1 / self.N
+ drift, diffusion = self.sde(x, t, *args)
+ f = drift * dt
+ G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
+ return f, G
+
+ # `oself` names the forward SDE instance so that `self` inside RSDE can refer to the reverse SDE.
+ def reverse(oself, score_model, probability_flow=False):
+ """Create the reverse-time SDE/ODE.
+
+ Args:
+ score_model: A function that takes x, t and y and returns the score.
+ probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
+ """
+ N = oself.N
+ T = oself.T
+ sde_fn = oself.sde
+ discretize_fn = oself.discretize
+
+ # Build the class for reverse-time SDE.
+ class RSDE(oself.__class__):
+ def __init__(self):
+ self.N = N
+ self.probability_flow = probability_flow
+
+ @property
+ def T(self):
+ return T
+
+ def sde(self, x, t, *args):
+ """Create the drift and diffusion functions for the reverse SDE/ODE."""
+ rsde_parts = self.rsde_parts(x, t, *args)
+ total_drift, diffusion = (
+ rsde_parts["total_drift"],
+ rsde_parts["diffusion"],
+ )
+ return total_drift, diffusion
+
+ def rsde_parts(self, x, t, *args):
+ sde_drift, sde_diffusion = sde_fn(x, t, *args)
+ score = score_model(x, t, *args)
+ score_drift = (
+ -sde_diffusion[:, None, None, None] ** 2
+ * score
+ * (0.5 if self.probability_flow else 1.0)
+ )
+ diffusion = (
+ torch.zeros_like(sde_diffusion)
+ if self.probability_flow
+ else sde_diffusion
+ )
+ total_drift = sde_drift + score_drift
+ return {
+ "total_drift": total_drift,
+ "diffusion": diffusion,
+ "sde_drift": sde_drift,
+ "sde_diffusion": sde_diffusion,
+ "score_drift": score_drift,
+ "score": score,
+ }
+
+ def discretize(self, x, t, *args):
+ """Create discretized iteration rules for the reverse diffusion sampler."""
+ f, G = discretize_fn(x, t, *args)
+ rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, *args) * (
+ 0.5 if self.probability_flow else 1.0
+ )
+ rev_G = torch.zeros_like(G) if self.probability_flow else G
+ return rev_f, rev_G
+
+ return RSDE()
+
+ @abc.abstractmethod
+ def copy(self):
+ pass
+
+
+@SDERegistry.register("ouve")
+class OUVESDE(SDE):
+ @staticmethod
+ def add_argparse_args(parser):
+ parser.add_argument(
+ "--sde-n",
+ type=int,
+ default=1000,
+ help="The number of timesteps in the SDE discretization. 30 by default",
+ )
+ parser.add_argument(
+ "--theta",
+ type=float,
+ default=1.5,
+ help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.",
+ )
+ parser.add_argument(
+ "--sigma-min",
+ type=float,
+ default=0.05,
+ help="The minimum sigma to use. 0.05 by default.",
+ )
+ parser.add_argument(
+ "--sigma-max",
+ type=float,
+ default=0.5,
+ help="The maximum sigma to use. 0.5 by default.",
+ )
+ return parser
+
+ def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs):
+ """Construct an Ornstein-Uhlenbeck Variance Exploding SDE.
+
+ Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
+ to the methods which require it (e.g., `sde` or `marginal_prob`).
+
+ dx = -theta (y-x) dt + sigma(t) dw
+
+ with
+
+ sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))
+
+ Args:
+ theta: stiffness parameter.
+ sigma_min: smallest sigma.
+ sigma_max: largest sigma.
+ N: number of discretization steps
+ """
+ super().__init__(N)
+ self.theta = theta
+ self.sigma_min = sigma_min
+ self.sigma_max = sigma_max
+ self.logsig = np.log(self.sigma_max / self.sigma_min)
+ self.N = N
+
+ def copy(self):
+ return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)
+
+ @property
+ def T(self):
+ return 1
+
+ def sde(self, x, t, y):
+ drift = self.theta * (y - x)
+ # the sqrt(2*logsig) factor is required here so that logsig does not in the end affect the perturbation kernel
+ # standard deviation. this can be understood from solving the integral of [exp(2s) * g(s)^2] from s=0 to t
+ # with g(t) = sigma(t) as defined here, and seeing that `logsig` remains in the integral solution
+ # unless this sqrt(2*logsig) factor is included.
+ sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
+ diffusion = sigma * np.sqrt(2 * self.logsig)
+ return drift, diffusion
+
+ def _mean(self, x0, t, y):
+ theta = self.theta
+ exp_interp = torch.exp(-theta * t)[:, None, None, None]
+ return exp_interp * x0 + (1 - exp_interp) * y
+
+ def _std(self, t):
+ # This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
+ sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
+ # could maybe replace the two torch.exp(... * t) terms here by cached values **t
+ return torch.sqrt(
+ (
+ sigma_min**2
+ * torch.exp(-2 * theta * t)
+ * (torch.exp(2 * (theta + logsig) * t) - 1)
+ * logsig
+ )
+ / (theta + logsig)
+ )
+
+ def marginal_prob(self, x0, t, y):
+ return self._mean(x0, t, y), self._std(t)
+
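+ # Sketch of how the perturbation kernel is typically used for score matching
+ # (illustrative; x0 is the clean spectrogram, y the noisy conditioner):
+ #
+ #   mean, std = sde.marginal_prob(x0, t, y)
+ #   xt = mean + std[:, None, None, None] * torch.randn_like(x0)
+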
+ def prior_sampling(self, shape, y):
+ if shape != y.shape:
+ warnings.warn(
+ f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape."
+ )
+ std = self._std(torch.ones((y.shape[0],), device=y.device))
+ x_T = y + torch.randn_like(y) * std[:, None, None, None]
+ return x_T
+
+ def prior_logp(self, z):
+ raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
+
+
+@SDERegistry.register("ouvp")
+class OUVPSDE(SDE):
+ # !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
+ @staticmethod
+ def add_argparse_args(parser):
+ parser.add_argument(
+ "--sde-n",
+ type=int,
+ default=1000,
+ help="The number of timesteps in the SDE discretization. 1000 by default",
+ )
+ parser.add_argument(
+ "--beta-min", type=float, required=True, help="The minimum beta to use."
+ )
+ parser.add_argument(
+ "--beta-max", type=float, required=True, help="The maximum beta to use."
+ )
+ parser.add_argument(
+ "--stiffness",
+ type=float,
+ default=1,
+ help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.",
+ )
+ return parser
+
+ def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
+ """
+ !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
+
+ Construct an Ornstein-Uhlenbeck Variance Preserving SDE:
+
+ dx = -1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw
+
+ with
+
+ beta(t) = beta_min + t(beta_max - beta_min)
+
+ Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
+ to the methods which require it (e.g., `sde` or `marginal_prob`).
+
+ Args:
+ beta_min: smallest beta.
+ beta_max: largest beta.
+ stiffness: stiffness factor of the drift. 1 by default.
+ N: number of discretization steps
+ """
+ super().__init__(N)
+ self.beta_min = beta_min
+ self.beta_max = beta_max
+ self.stiffness = stiffness
+ self.N = N
+
+ def copy(self):
+ return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)
+
+ @property
+ def T(self):
+ return 1
+
+ def _beta(self, t):
+ return self.beta_min + t * (self.beta_max - self.beta_min)
+
+ def sde(self, x, t, y):
+ drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
+ diffusion = torch.sqrt(self._beta(t))
+ return drift, diffusion
+
+ def _mean(self, x0, t, y):
+ b0, b1, s = self.beta_min, self.beta_max, self.stiffness
+ x0y_fac = torch.exp(-0.25 * s * t * (t * (b1 - b0) + 2 * b0))[
+ :, None, None, None
+ ]
+ return y + x0y_fac * (x0 - y)
+
+ def _std(self, t):
+ b0, b1, s = self.beta_min, self.beta_max, self.stiffness
+ return (1 - torch.exp(-0.5 * s * t * (t * (b1 - b0) + 2 * b0))) / s
+
+ def marginal_prob(self, x0, t, y):
+ return self._mean(x0, t, y), self._std(t)
+
+ def prior_sampling(self, shape, y):
+ if shape != y.shape:
+ warnings.warn(
+ f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape."
+ )
+ std = self._std(torch.ones((y.shape[0],), device=y.device))
+ x_T = y + torch.randn_like(y) * std[:, None, None, None]
+ return x_T
+
+ def prior_logp(self, z):
+ raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
diff --git a/modules/sgmse/shared.py b/modules/sgmse/shared.py
new file mode 100644
index 00000000..8165069f
--- /dev/null
+++ b/modules/sgmse/shared.py
@@ -0,0 +1,132 @@
+import functools
+import numpy as np
+
+import torch
+import torch.nn as nn
+
+from utils.sgmse_util.registry import Registry
+
+
+BackboneRegistry = Registry("Backbone")
+
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian random features for encoding time steps."""
+
+ def __init__(self, embed_dim, scale=16, complex_valued=False):
+ super().__init__()
+ self.complex_valued = complex_valued
+ if not complex_valued:
+ # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
+ # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
+ # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
+ # and this halving is not necessary.
+ embed_dim = embed_dim // 2
+ # Randomly sample weights during initialization. These weights are fixed
+ # during optimization and are not trainable.
+ self.W = nn.Parameter(torch.randn(embed_dim) * scale, requires_grad=False)
+
+ def forward(self, t):
+ t_proj = t[:, None] * self.W[None, :] * 2 * np.pi
+ if self.complex_valued:
+ return torch.exp(1j * t_proj)
+ else:
+ return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)
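+
+
+# Example (sketch): embed 16 diffusion times into 128 features; the sin/cos
+# halves make the real-valued embedding unambiguous.
+#
+#   emb = GaussianFourierProjection(embed_dim=128)
+#   emb(torch.rand(16)).shape  # torch.Size([16, 128])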
+
+
+class DiffusionStepEmbedding(nn.Module):
+ """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017."""
+
+ def __init__(self, embed_dim, complex_valued=False):
+ super().__init__()
+ self.complex_valued = complex_valued
+ if not complex_valued:
+ # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
+ # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
+ # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
+ # and this halving is not necessary.
+ embed_dim = embed_dim // 2
+ self.embed_dim = embed_dim
+
+ def forward(self, t):
+ fac = 10 ** (
+ 4 * torch.arange(self.embed_dim, device=t.device) / (self.embed_dim - 1)
+ )
+ inner = t[:, None] * fac[None, :]
+ if self.complex_valued:
+ return torch.exp(1j * inner)
+ else:
+ return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
+
+
+class ComplexLinear(nn.Module):
+ """A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`."""
+
+ def __init__(self, input_dim, output_dim, complex_valued):
+ super().__init__()
+ self.complex_valued = complex_valued
+ if self.complex_valued:
+ self.re = nn.Linear(input_dim, output_dim)
+ self.im = nn.Linear(input_dim, output_dim)
+ else:
+ self.lin = nn.Linear(input_dim, output_dim)
+
+ def forward(self, x):
+ if self.complex_valued:
+ return (self.re(x.real) - self.im(x.imag)) + 1j * (
+ self.re(x.imag) + self.im(x.real)
+ )
+ else:
+ return self.lin(x)
+
+
+class FeatureMapDense(nn.Module):
+ """A fully connected layer that reshapes outputs to feature maps."""
+
+ def __init__(self, input_dim, output_dim, complex_valued=False):
+ super().__init__()
+ self.complex_valued = complex_valued
+ self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)
+
+ def forward(self, x):
+ return self.dense(x)[..., None, None]
+
+
+def torch_complex_from_reim(re, im):
+ return torch.view_as_complex(torch.stack([re, im], dim=-1))
+
+
+class ArgsComplexMultiplicationWrapper(nn.Module):
+ """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().
+
+ Make a complex-valued module `F` from a real-valued module `f` by applying
+ complex multiplication rules:
+
+ F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))
+
+ where `f1`, `f2` are instances of `f` that do *not* share weights.
+
+ Args:
+ module_cls (callable): A class or function that returns a Torch module/functional.
+ Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
+ to construct the real and imaginary component modules.
+ """
+
+ def __init__(self, module_cls, *args, **kwargs):
+ super().__init__()
+ self.re_module = module_cls(*args, **kwargs)
+ self.im_module = module_cls(*args, **kwargs)
+
+ def forward(self, x, *args, **kwargs):
+ return torch_complex_from_reim(
+ self.re_module(x.real, *args, **kwargs)
+ - self.im_module(x.imag, *args, **kwargs),
+ self.re_module(x.imag, *args, **kwargs)
+ + self.im_module(x.real, *args, **kwargs),
+ )
+
+
+ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
+ComplexConvTranspose2d = functools.partial(
+ ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d
+)
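+
+
+# Sketch (illustrative shapes): the partials build complex layers from two
+# real-valued modules via complex multiplication rules.
+#
+#   conv = ComplexConv2d(4, 8, kernel_size=3, padding=1)
+#   z = torch.randn(2, 4, 32, 32, dtype=torch.complex64)
+#   conv(z).shape  # torch.Size([2, 8, 32, 32])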
diff --git a/preprocessors/wsj0reverb.py b/preprocessors/wsj0reverb.py
new file mode 100644
index 00000000..84b71c01
--- /dev/null
+++ b/preprocessors/wsj0reverb.py
@@ -0,0 +1,187 @@
+import os
+from glob import glob
+
+import numpy as np
+import pyroomacoustics as pra
+import soundfile as sf
+from tqdm import tqdm
+
+SEED = 100
+np.random.seed(SEED)
+
+T60_RANGE = [0.4, 1.0]
+SNR_RANGE = [0, 20]
+DIM_RANGE = [5, 15, 5, 15, 2, 6]
+MIN_DISTANCE_TO_WALL = 1
+MIC_ARRAY_RADIUS = 0.16
+TARGET_T60_SHAPE = {"CI": 0.08, "HA": 0.2}
+TARGET_T60_SHAPE = {"CI": 0.10, "HA": 0.2}
+TARGETS_CROP = {"CI": 16e-3, "HA": 40e-3}
+NB_SAMPLES_PER_ROOM = 1
+CHANNELS = 1
+
+
+def obtain_clean_file(speech_list, i_sample, sample_rate=16000):
+ speech, speech_sr = sf.read(speech_list[i_sample])
+ speech_basename = os.path.basename(speech_list[i_sample])
+ assert (
+ speech_sr == sample_rate
+ ), f"wrong speech sampling rate here: expected {sample_rate} got {speech_sr}"
+ return speech.squeeze(), speech_sr, speech_basename[:-4]
+
+
+def main(output_path, dataset_path):
+ print("-" * 10)
+ print("Dataset splits for {}...\n".format("wsj0reverb"))
+ dataset = "wsj0reverb"
+ sample_rate = 16000
+ save_dir = os.path.join(output_path, dataset)
+ os.makedirs(save_dir, exist_ok=True)
+ wsj0reverb_path = dataset_path
+ splits = ["valid", "train", "test"]
+ dic_split = {"valid": "si_dt_05", "train": "si_tr_s", "test": "si_et_05"}
+ speech_lists = {
+ split: sorted(
+ glob(f"{os.path.join(wsj0reverb_path, dic_split[split])}/**/*.wav")
+ )
+ for split in splits
+ }
+
+ for i_split, split in enumerate(splits):
+ print("Processing split n° {}: {}...".format(i_split + 1, split))
+
+ reverberant_output_dir = os.path.join(save_dir, "audio", split, "reverb")
+ dry_output_dir = os.path.join(save_dir, "audio", split, "anechoic")
+ noisy_reverberant_output_dir = os.path.join(
+ save_dir, "audio", split, "noisy_reverb"
+ )
+ if split == "test":
+ unauralized_output_dir = os.path.join(
+ save_dir, "audio", split, "unauralized"
+ )
+
+ os.makedirs(reverberant_output_dir, exist_ok=True)
+ os.makedirs(dry_output_dir, exist_ok=True)
+ if split == "test":
+ os.makedirs(unauralized_output_dir, exist_ok=True)
+
+ speech_list = speech_lists[split]
+ real_nb_samples = len(speech_list)
+ for i_sample in tqdm(range(real_nb_samples)):
+ if not i_sample % NB_SAMPLES_PER_ROOM: # Generate new room
+ t60 = np.random.uniform(T60_RANGE[0], T60_RANGE[1]) # Draw T60
+ room_dim = np.array(
+ [
+ np.random.uniform(DIM_RANGE[2 * n], DIM_RANGE[2 * n + 1])
+ for n in range(3)
+ ]
+ ) # Draw Dimensions
+ center_mic_position = np.array(
+ [
+ np.random.uniform(
+ MIN_DISTANCE_TO_WALL, room_dim[n] - MIN_DISTANCE_TO_WALL
+ )
+ for n in range(3)
+ ]
+ ] ) # Draw microphone array center
+ source_position = np.array(
+ [
+ np.random.uniform(
+ MIN_DISTANCE_TO_WALL, room_dim[n] - MIN_DISTANCE_TO_WALL
+ )
+ for n in range(3)
+ ]
+ ) # draw source position
+ mic_array_2d = pra.beamforming.circular_2D_array(
+ center_mic_position[:-1], CHANNELS, phi0=0, radius=MIC_ARRAY_RADIUS
+ ) # Compute microphone array
+ mic_array = np.pad(
+ mic_array_2d,
+ ((0, 1), (0, 0)),
+ mode="constant",
+ constant_values=center_mic_position[-1],
+ )
+
+ ### Reverberant Room
+ e_absorption, max_order = pra.inverse_sabine(
+ t60, room_dim
+ ) # Compute absorption coeff
+ reverberant_room = pra.ShoeBox(
+ room_dim,
+ fs=16000,
+ materials=pra.Material(e_absorption),
+ max_order=min(3, max_order),
+ ) # Create room
+ reverberant_room.set_ray_tracing()
+ reverberant_room.add_microphone_array(mic_array) # Add microphone array
+
+ # Pick unauralized files
+ speech, speech_sr, speech_basename = obtain_clean_file(
+ speech_list, i_sample, sample_rate=sample_rate
+ )
+
+ # Generate reverberant room
+ reverberant_room.add_source(source_position, signal=speech)
+ reverberant_room.compute_rir()
+ reverberant_room.simulate()
+ t60_real = np.mean(reverberant_room.measure_rt60()).squeeze()
+ reverberant = np.stack(reverberant_room.mic_array.signals).swapaxes(0, 1)
+
+ e_absorption_dry = 0.99 # For Neural Networks OK but clearly not for WPE
+ dry_room = pra.ShoeBox(
+ room_dim,
+ fs=16000,
+ materials=pra.Material(e_absorption_dry),
+ max_order=0,
+ ) # Create room
+ dry_room.add_microphone_array(mic_array) # Add microphone array
+
+ # Generate dry room
+ dry_room.add_source(source_position, signal=speech)
+ dry_room.compute_rir()
+ dry_room.simulate()
+ t60_real_dry = np.mean(dry_room.measure_rt60()).squeeze()
+ rir_dry = dry_room.rir
+ dry = np.stack(dry_room.mic_array.signals).swapaxes(0, 1)
+ dry = np.pad(
+ dry,
+ ((0, int(0.5 * sample_rate)), (0, 0)),
+ mode="constant",
+ constant_values=0,
+ ) # Pad 0.5 s of silence after the dry signal (it decays fast) so the reverb tail is not cut and all samples share the same length
+
+ min_len_sample = min(reverberant.shape[0], dry.shape[0])
+ dry = dry[:min_len_sample]
+ reverberant = reverberant[:min_len_sample]
+ output_scaling = np.max(reverberant) / 0.9
+
+ drr = 10 * np.log10(
+ np.mean(dry**2) / (np.mean(reverberant**2) + 1e-8) + 1e-8
+ )
+ output_filename = f"{speech_basename}_{i_sample // NB_SAMPLES_PER_ROOM}_{t60_real:.2f}_{drr:.1f}.wav"
+
+ sf.write(
+ os.path.join(dry_output_dir, output_filename),
+ 1 / output_scaling * dry,
+ samplerate=sample_rate,
+ )
+ sf.write(
+ os.path.join(reverberant_output_dir, output_filename),
+ 1 / output_scaling * reverberant,
+ samplerate=sample_rate,
+ )
+
+ if split == "test":
+ sf.write(
+ os.path.join(unauralized_output_dir, output_filename),
+ speech,
+ samplerate=sample_rate,
+ )