diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0202868
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,132 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# VSCode project settings
+.vscode
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..6eb2af0
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Benjamin van Niekerk
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..567a2ac
--- /dev/null
+++ b/README.md
@@ -0,0 +1,92 @@
+# HiFi-GAN
+
+A 16kHz implementation of HiFi-GAN for [soft-vc](https://github.com/bshall/soft-vc).
+
+Relevant links:
+- [Official HiFi-GAN repo](https://github.com/jik876/hifi-gan)
+- [HiFi-GAN paper](https://arxiv.org/abs/2010.05646)
+- [Soft-VC repo](https://github.com/bshall/soft-vc)
+- [Soft-VC paper]()
+
+## Example Usage
+
+```python
+import torch
+import numpy as np
+
+# Load checkpoint
+hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft").cuda()
+# Load mel-spectrogram
+mel = torch.from_numpy(np.load("path/to/mel")).unsqueeze(0).cuda()
+# Generate
+wav, sr = hifigan.generate(mel)
+```
+
+## Train
+
+**Step 1**: Download and extract the [LJ-Speech dataset](https://keithito.com/LJ-Speech-Dataset/)
+
+**Step 2**: Resample the audio to 16kHz:
+```
+usage: resample.py [-h] [--sample-rate SAMPLE_RATE] in-dir out-dir
+
+Resample an audio dataset.
+
+positional arguments:
+  in-dir                path to the dataset directory
+  out-dir               path to the output directory
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --sample-rate SAMPLE_RATE
+                        target sample rate (default 16kHz)
+```
+
+**Step 3**: Download the dataset splits and move them into the root of the dataset directory.
+After steps 2 and 3 your dataset directory should look like this:
+```
+LJSpeech-1.1
+│   test.txt
+│   train.txt
+│   validation.txt
+├───mels
+└───wavs
+```
+Note: the mels directory is optional. If you want to fine-tune HiFi-GAN, the mels directory should contain ground-truth aligned spectrograms from an acoustic model.
+
+**Step 4**: Train HiFi-GAN:
+```
+usage: train.py [-h] [--resume RESUME] [--finetune] dataset-dir checkpoint-dir
+
+Train or finetune HiFi-GAN.
+
+positional arguments:
+  dataset-dir           path to the preprocessed data directory
+  checkpoint-dir        path to the checkpoint directory
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --resume RESUME       path to the checkpoint to resume from
+  --finetune            whether to finetune (note that a resume path must be given)
+```
+
+## Generate
+To generate using the trained HiFi-GAN models, see [Example Usage](#example-usage) or use the `generate.py` script:
+
+```
+usage: generate.py [-h] [--model-name {hifigan,hifigan_hubert_soft,hifigan_hubert_discrete}] in-dir out-dir
+
+Generate audio for a directory of mel-spectrograms using HiFi-GAN.
+
+positional arguments:
+  in-dir                path to directory containing the mel-spectrograms
+  out-dir               path to output directory
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --model-name {hifigan,hifigan_hubert_soft,hifigan_hubert_discrete}
+                        available models
+```
+
+## Acknowledgements
+This repo is based heavily on [https://github.com/jik876/hifi-gan](https://github.com/jik876/hifi-gan).
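+
+## Mel-spectrogram format
+
+`generate.py` and the example above expect 128-band log-mel-spectrograms computed with the same settings used during training (see `LogMelSpectrogram` in `hifigan/dataset.py`): 16kHz audio, a 1024-point FFT, a hop of 160 samples, Slaney-style mel filters, and a log taken after clamping at 1e-5. As a rough sketch of preparing a compatible input from a waveform (the audio path is illustrative):
+
+```python
+import torch
+import torch.nn.functional as F
+import torchaudio
+import torchaudio.transforms as transforms
+
+melspectrogram = transforms.MelSpectrogram(
+    sample_rate=16000,
+    n_fft=1024,
+    win_length=1024,
+    hop_length=160,
+    center=False,
+    power=1.0,
+    norm="slaney",
+    onesided=True,
+    n_mels=128,
+    mel_scale="slaney",
+)
+
+wav, sr = torchaudio.load("path/to/audio.wav")  # must already be at 16kHz
+wav = wav.unsqueeze(0)  # (1, 1, time)
+# center=False, so pad by (n_fft - hop_length) // 2 on both sides, as in LogMelSpectrogram
+wav = F.pad(wav, ((1024 - 160) // 2, (1024 - 160) // 2), "reflect")
+mel = torch.log(torch.clamp(melspectrogram(wav), min=1e-5))  # (1, 1, 128, frames)
+mel = mel.squeeze(1)  # (1, 128, frames), ready for hifigan.generate
+```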
\ No newline at end of file
diff --git a/generate.py b/generate.py
new file mode 100644
index 0000000..76544af
--- /dev/null
+++ b/generate.py
@@ -0,0 +1,50 @@
+from pathlib import Path
+import numpy as np
+import argparse
+import torch
+import torchaudio
+from tqdm import tqdm
+
+
+def generate(args):
+    args.out_dir.mkdir(exist_ok=True, parents=True)
+
+    print("Loading checkpoint")
+    hifigan = torch.hub.load("bshall/hifigan:main", args.model_name).cuda()
+
+    print(f"Generating audio from {args.in_dir}")
+    for path in tqdm(list(args.in_dir.rglob("*.npy"))):
+        mel = torch.from_numpy(np.load(path))
+        mel = mel.unsqueeze(0).cuda()
+
+        wav, sr = hifigan.generate(mel)
+        wav = wav.squeeze(0).cpu()
+
+        out_path = args.out_dir / path.relative_to(args.in_dir)
+        out_path.parent.mkdir(exist_ok=True, parents=True)
+        torchaudio.save(out_path.with_suffix(".wav"), wav, sr)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Generate audio for a directory of mel-spectrograms using HiFi-GAN."
+    )
+    # positional dests use underscores so the values are accessible as args.in_dir / args.out_dir
+    parser.add_argument(
+        "in_dir",
+        metavar="in-dir",
+        help="path to directory containing the mel-spectrograms",
+        type=Path,
+    )
+    parser.add_argument(
+        "out_dir",
+        metavar="out-dir",
+        help="path to output directory",
+        type=Path,
+    )
+    parser.add_argument(
+        "--model-name",
+        help="available models",
+        choices=["hifigan", "hifigan_hubert_soft", "hifigan_hubert_discrete"],
+        default="hifigan_hubert_soft",
+    )
+    args = parser.parse_args()
+
+    generate(args)
diff --git a/hifigan/dataset.py b/hifigan/dataset.py
new file mode 100644
index 0000000..317996f
--- /dev/null
+++ b/hifigan/dataset.py
@@ -0,0 +1,113 @@
+from pathlib import Path
+import math
+import random
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from torch.utils.data import Dataset
+
+import torchaudio
+import torchaudio.transforms as transforms
+
+
+class LogMelSpectrogram(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.melspectrogram = transforms.MelSpectrogram(
+            sample_rate=16000,
+            n_fft=1024,
+            win_length=1024,
+            hop_length=160,
+            center=False,
+            power=1.0,
+            norm="slaney",
+            onesided=True,
+            n_mels=128,
+            mel_scale="slaney",
+        )
+
+    def forward(self, wav):
+        wav = F.pad(wav, ((1024 - 160) // 2, (1024 - 160) // 2), "reflect")
+        mel = self.melspectrogram(wav)
+        logmel = torch.log(torch.clamp(mel, min=1e-5))
+        return logmel
+
+
+class MelDataset(Dataset):
+    def __init__(
+        self, root, segment_length, sample_rate, hop_length, train=True, finetune=False
+    ):
+        self.root = Path(root)
+        self.segment_length = segment_length
+        self.sample_rate = sample_rate
+        self.hop_length = hop_length
+        self.train = train
+        self.finetune = finetune
+
+        split = "train.txt" if train else "validation.txt"
+        with open(self.root / split) as file:
+            self.metadata = [line.strip() for line in file]
+
+        self.logmel = LogMelSpectrogram()
+
+    def __len__(self):
+        return len(self.metadata)
+
+    def __getitem__(self, index):
+        path = self.metadata[index]
+        wav_path = self.root / "wavs" / path
+
+        info = torchaudio.info(wav_path.with_suffix(".wav"))
+        if info.sample_rate != self.sample_rate:
+            raise ValueError(
+                f"Sample rate {info.sample_rate} doesn't match target of {self.sample_rate}"
+            )
+
+        if self.finetune:
+            mel_path = self.root / "mels" / path
+            src_logmel = torch.from_numpy(np.load(mel_path.with_suffix(".npy")))
+            src_logmel = src_logmel.unsqueeze(0)
+
+            mel_frames_per_segment = math.ceil(self.segment_length / self.hop_length)
+            mel_diff = src_logmel.size(-1) - mel_frames_per_segment if self.train else 0
+            mel_offset = random.randint(0,
max(mel_diff, 0)) + + frame_offset = self.hop_length * mel_offset + else: + frame_diff = info.num_frames - self.segment_length + frame_offset = random.randint(0, max(frame_diff, 0)) + + wav, _ = torchaudio.load( + filepath=wav_path.with_suffix(".wav"), + frame_offset=frame_offset if self.train else 0, + num_frames=self.segment_length if self.train else -1, + ) + + if wav.size(-1) < self.segment_length: + wav = F.pad(wav, (0, self.segment_length - wav.size(-1))) + + if not self.finetune and self.train: + gain = random.random() * (0.99 - 0.4) + 0.4 + flip = -1 if random.random() > 0.5 else 1 + wav = flip * gain * wav / wav.abs().max() + + tgt_logmel = self.logmel(wav.unsqueeze(0)).squeeze(0) + + if self.finetune: + if self.train: + src_logmel = src_logmel[ + :, :, mel_offset : mel_offset + mel_frames_per_segment + ] + + if src_logmel.size(-1) < mel_frames_per_segment: + src_logmel = F.pad( + src_logmel, + (0, mel_frames_per_segment - src_logmel.size(-1)), + "constant", + src_logmel.min(), + ) + else: + src_logmel = tgt_logmel.clone() + + return wav, src_logmel, tgt_logmel diff --git a/hifigan/discriminator.py b/hifigan/discriminator.py new file mode 100644 index 0000000..8f53c3d --- /dev/null +++ b/hifigan/discriminator.py @@ -0,0 +1,262 @@ +# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, List + +from hifigan.utils import get_padding + + +LRELU_SLOPE = 0.1 + + +class PeriodDiscriminator(torch.nn.Module): + """HiFiGAN Period Discriminator""" + + def __init__( + self, + period: int, + kernel_size: int = 5, + stride: int = 3, + use_spectral_norm: bool = False, + ) -> None: + super().__init__() + self.period = period + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm + self.convs = nn.ModuleList( + [ + norm_f( + nn.Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + nn.Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + nn.Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + nn.Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x (Tensor): input waveform. + Returns: + [Tensor]: discriminator scores per sample in the batch. + [List[Tensor]]: list of features from each convolutional layer. 
+ """ + feat = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + + return x, feat + + +class MultiPeriodDiscriminator(torch.nn.Module): + """HiFiGAN Multi-Period Discriminator (MPD)""" + + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + PeriodDiscriminator(2), + PeriodDiscriminator(3), + PeriodDiscriminator(5), + PeriodDiscriminator(7), + PeriodDiscriminator(11), + ] + ) + + def forward( + self, x: torch.Tensor + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: + """ + Args: + x (Tensor): input waveform. + Returns: + [List[Tensor]]: list of scores from each discriminator. + [List[List[Tensor]]]: list of features from each discriminator's convolutional layers. + """ + scores = [] + feats = [] + for _, d in enumerate(self.discriminators): + score, feat = d(x) + scores.append(score) + feats.append(feat) + return scores, feats + + +class ScaleDiscriminator(torch.nn.Module): + """HiFiGAN Scale Discriminator.""" + + def __init__(self, use_spectral_norm: bool = False) -> None: + super().__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)), + norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x (Tensor): input waveform. + Returns: + Tensor: discriminator scores. + List[Tensor]: list of features from the convolutional layers. + """ + feat = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + return x, feat + + +class MultiScaleDiscriminator(torch.nn.Module): + """HiFiGAN Multi-Scale Discriminator.""" + + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + ScaleDiscriminator(use_spectral_norm=True), + ScaleDiscriminator(), + ScaleDiscriminator(), + ] + ) + self.meanpools = nn.ModuleList( + [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)] + ) + + def forward( + self, x: torch.Tensor + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: + """ + Args: + x (Tensor): input waveform. + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of features from each discriminator's convolutional layers. 
+ """ + scores = [] + feats = [] + for i, d in enumerate(self.discriminators): + if i != 0: + x = self.meanpools[i - 1](x) + score, feat = d(x) + scores.append(score) + feats.append(feat) + return scores, feats + + +class HifiganDiscriminator(nn.Module): + """HiFiGAN discriminator""" + + def __init__(self): + super().__init__() + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + + def forward( + self, x: torch.Tensor + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: + """ + Args: + x (Tensor): input waveform. + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of features from from each discriminator's convolutional layers. + """ + scores, feats = self.mpd(x) + scores_, feats_ = self.msd(x) + return scores + scores_, feats + feats_ + + +def feature_loss( + features_real: List[List[torch.Tensor]], features_generate: List[List[torch.Tensor]] +) -> float: + loss = 0 + for r, g in zip(features_real, features_generate): + for rl, gl in zip(r, g): + loss += torch.mean(torch.abs(rl - gl)) + return loss * 2 + + +def discriminator_loss(real, generated): + loss = 0 + real_losses = [] + generated_losses = [] + for r, g in zip(real, generated): + r_loss = torch.mean((1 - r) ** 2) + g_loss = torch.mean(g ** 2) + loss += r_loss + g_loss + real_losses.append(r_loss.item()) + generated_losses.append(g_loss.item()) + + return loss, real_losses, generated_losses + + +def generator_loss(discriminator_outputs): + loss = 0 + generator_losses = [] + for x in discriminator_outputs: + l = torch.mean((1 - x) ** 2) + generator_losses.append(l) + loss += l + + return loss, generator_losses diff --git a/hifigan/generator.py b/hifigan/generator.py new file mode 100644 index 0000000..eda2156 --- /dev/null +++ b/hifigan/generator.py @@ -0,0 +1,277 @@ +# adapted from https://github.com/jik876/hifi-gan/blob/master/models.py +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import remove_weight_norm, weight_norm +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present +from typing import Tuple + +from hifigan.utils import get_padding + + +URLS = { + "hifigan": "https://github.com/bshall/hifigan/releases/download/v0.1/hifigan-67926ec6.pt", + "hifigan-hubert-soft": "https://github.com/bshall/hifigan/releases/download/v0.1/hifigan-hubert-discrete-bbad3043.pt", + "hifigan-hubert-discrete": "https://github.com/bshall/hifigan/releases/download/v0.1/hifigan-hubert-soft-65f03469.pt", +} + +LRELU_SLOPE = 0.1 + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels: int = 128, + resblock_dilation_sizes: Tuple[Tuple[int, ...], ...] = ( + (1, 3, 5), + (1, 3, 5), + (1, 3, 5), + ), + resblock_kernel_sizes: Tuple[int, ...] = (3, 7, 11), + upsample_kernel_sizes: Tuple[int, ...] = (20, 8, 4, 4), + upsample_initial_channel: int = 512, + upsample_factors: int = (10, 4, 2, 2), + inference_padding: int = 5, + sample_rate: int = 16000, + ) -> None: + r"""HiFiGAN Generator + Args: + in_channels (int): number of input channels. + resblock_dilation_sizes (Tuple[Tuple[int, ...], ...]): list of dilation values in each layer of a `ResBlock`. + resblock_kernel_sizes (Tuple[int, ...]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (Tuple[int, ...]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. 
+ upsample_factors (Tuple[int, ...]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. + sample_rate (int): sample rate of the generated audio. + """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + self.sample_rate = sample_rate + # initial upsampling layers + self.conv_pre = weight_norm( + nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + ) + + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + nn.ConvTranspose1d( + upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(ResBlock1(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + o = self.conv_pre(x) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def generate(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, (self.inference_padding, self.inference_padding), "replicate") + return self(x), self.sample_rate + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class ResBlock1(torch.nn.Module): + def __init__( + self, channels: int, kernel_size: int = 3, dilation: Tuple[int, ...] 
= (1, 3, 5) + ) -> None: + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__( + self, channels: int, kernel_size: int = 3, dilation: Tuple[int, ...] = (1, 3) + ) -> None: + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +def _hifigan( + name: str, pretrained: bool = True, progress: bool = True +) -> HifiganGenerator: + hifigan = HifiganGenerator() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress) + consume_prefix_in_state_dict_if_present( + checkpoint["generator"]["model"], "module." 
+        )
+        hifigan.load_state_dict(checkpoint["generator"]["model"])
+    hifigan.eval()
+    hifigan.remove_weight_norm()
+    return hifigan
+
+
+def hifigan(pretrained: bool = True, progress: bool = True) -> HifiganGenerator:
+    return _hifigan("hifigan", pretrained, progress)
+
+
+def hifigan_hubert_soft(
+    pretrained: bool = True, progress: bool = True
+) -> HifiganGenerator:
+    return _hifigan("hifigan-hubert-soft", pretrained, progress)
+
+
+def hifigan_hubert_discrete(
+    pretrained: bool = True, progress: bool = True
+) -> HifiganGenerator:
+    return _hifigan("hifigan-hubert-discrete", pretrained, progress)
diff --git a/hifigan/utils.py b/hifigan/utils.py
new file mode 100644
index 0000000..3585663
--- /dev/null
+++ b/hifigan/utils.py
@@ -0,0 +1,84 @@
+import torch
+import matplotlib
+
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+
+
+def get_padding(k, d):
+    return int((k * d - d) / 2)
+
+
+def plot_spectrogram(spectrogram):
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+
+    fig.canvas.draw()
+    plt.close()
+
+    return fig
+
+
+def save_checkpoint(
+    checkpoint_dir,
+    generator,
+    discriminator,
+    optimizer_generator,
+    optimizer_discriminator,
+    scheduler_generator,
+    scheduler_discriminator,
+    step,
+    loss,
+    best,
+    logger,
+):
+    state = {
+        "generator": {
+            "model": generator.state_dict(),
+            "optimizer": optimizer_generator.state_dict(),
+            "scheduler": scheduler_generator.state_dict(),
+        },
+        "discriminator": {
+            "model": discriminator.state_dict(),
+            "optimizer": optimizer_discriminator.state_dict(),
+            "scheduler": scheduler_discriminator.state_dict(),
+        },
+        "step": step,
+        "loss": loss,
+    }
+    checkpoint_dir.mkdir(exist_ok=True, parents=True)
+    checkpoint_path = checkpoint_dir / f"model-{step}.pt"
+    torch.save(state, checkpoint_path)
+    if best:
+        best_path = checkpoint_dir / "model-best.pt"
+        torch.save(state, best_path)
+    logger.info(f"Saved checkpoint: {checkpoint_path.stem}")
+
+
+def load_checkpoint(
+    load_path,
+    generator,
+    discriminator,
+    optimizer_generator,
+    optimizer_discriminator,
+    scheduler_generator,
+    scheduler_discriminator,
+    rank,
+    logger,
+    finetune=False,
+):
+    logger.info(f"Loading checkpoint from {load_path}")
+    checkpoint = torch.load(load_path, map_location={"cuda:0": f"cuda:{rank}"})
+    generator.load_state_dict(checkpoint["generator"]["model"])
+    discriminator.load_state_dict(checkpoint["discriminator"]["model"])
+    if not finetune:
+        optimizer_generator.load_state_dict(checkpoint["generator"]["optimizer"])
+        scheduler_generator.load_state_dict(checkpoint["generator"]["scheduler"])
+        optimizer_discriminator.load_state_dict(
+            checkpoint["discriminator"]["optimizer"]
+        )
+        scheduler_discriminator.load_state_dict(
+            checkpoint["discriminator"]["scheduler"]
+        )
+    return checkpoint["step"], checkpoint["loss"]
diff --git a/hubconf.py b/hubconf.py
new file mode 100644
index 0000000..2f4e3bf
--- /dev/null
+++ b/hubconf.py
@@ -0,0 +1,3 @@
+dependencies = ["torch", "torchaudio"]
+
+from hifigan.generator import hifigan, hifigan_hubert_discrete, hifigan_hubert_soft
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..20dfcf8
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+tensorboard==2.7.0
+torch==1.9.1
+torchaudio==0.9.1
+tqdm==4.62.3
\ No newline at end of file
diff --git a/resample.py b/resample.py
new file mode 100644
index 0000000..440b682
--- /dev/null
+++ b/resample.py
@@ -0,0 +1,52 @@
+import argparse
+from pathlib import Path
+from concurrent.futures import ProcessPoolExecutor
+from multiprocessing import cpu_count
+
+import torchaudio
+from torchaudio.functional import resample
+
+from tqdm import tqdm
+
+
+def process_wav(in_path, out_path, sample_rate):
+    wav, sr = torchaudio.load(in_path)
+    wav = resample(wav, sr, sample_rate)
+    torchaudio.save(out_path, wav, sample_rate)
+    return out_path, wav.size(-1) / sample_rate
+
+
+def preprocess_dataset(args):
+    args.out_dir.mkdir(parents=True, exist_ok=True)
+
+    futures = []
+    executor = ProcessPoolExecutor(max_workers=cpu_count())
+    print(f"Resampling audio in {args.in_dir}")
+    for in_path in args.in_dir.rglob("*.wav"):
+        relative_path = in_path.relative_to(args.in_dir)
+        out_path = args.out_dir / relative_path
+        out_path.parent.mkdir(parents=True, exist_ok=True)
+        futures.append(
+            executor.submit(process_wav, in_path, out_path, args.sample_rate)
+        )
+
+    results = [future.result() for future in tqdm(futures)]
+
+    lengths = {path.stem: length for path, length in results}
+    seconds = sum(lengths.values())
+    hours = seconds / 3600
+    print(f"Wrote {len(lengths)} utterances ({hours:.2f} hours)")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Resample an audio dataset.")
+    # positional dests use underscores so the values are accessible as args.in_dir / args.out_dir
+    parser.add_argument("in_dir", metavar="in-dir", help="path to the dataset directory", type=Path)
+    parser.add_argument("out_dir", metavar="out-dir", help="path to the output directory", type=Path)
+    parser.add_argument(
+        "--sample-rate",
+        help="target sample rate (default 16kHz)",
+        type=int,
+        default=16000,
+    )
+    args = parser.parse_args()
+    preprocess_dataset(args)
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..cc72fe2
--- /dev/null
+++ b/train.py
@@ -0,0 +1,343 @@
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from hifigan.generator import HifiganGenerator
+from hifigan.discriminator import (
+    HifiganDiscriminator,
+    feature_loss,
+    discriminator_loss,
+    generator_loss,
+)
+from hifigan.dataset import MelDataset, LogMelSpectrogram
+from hifigan.utils import load_checkpoint, save_checkpoint, plot_spectrogram
+
+
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+
+BATCH_SIZE = 8
+SEGMENT_LENGTH = 8320
+HOP_LENGTH = 160
+SAMPLE_RATE = 16000
+BASE_LEARNING_RATE = 2e-4
+FINETUNE_LEARNING_RATE = 1e-4
+BETAS = (0.8, 0.99)
+LEARNING_RATE_DECAY = 0.999
+WEIGHT_DECAY = 1e-5
+EPOCHS = 3100
+LOG_INTERVAL = 5
+VALIDATION_INTERVAL = 1000
+NUM_GENERATED_EXAMPLES = 10
+CHECKPOINT_INTERVAL = 5000
+
+
+def train_model(rank, world_size, args):
+    dist.init_process_group(
+        "nccl",
+        rank=rank,
+        world_size=world_size,
+        init_method="tcp://localhost:54321",
+    )
+
+    log_dir = args.checkpoint_dir / "logs"
+    log_dir.mkdir(exist_ok=True, parents=True)
+
+    if rank == 0:
+        logger.setLevel(logging.DEBUG)
+        handler = logging.FileHandler(log_dir / f"{args.checkpoint_dir.stem}.log")
+        handler.setLevel(logging.DEBUG)
+        formatter = logging.Formatter(
+            "%(asctime)s [%(levelname)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S"
+        )
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+    else:
+        logger.setLevel(logging.ERROR)
+
+    writer = SummaryWriter(log_dir) if rank == 0 else None
+
+    generator = HifiganGenerator().to(rank)
+    discriminator = HifiganDiscriminator().to(rank)
+
+    generator = DDP(generator, device_ids=[rank])
+    discriminator = DDP(discriminator, device_ids=[rank])
+
+    optimizer_generator = optim.AdamW(
+        generator.parameters(),
+        lr=BASE_LEARNING_RATE if not args.finetune else FINETUNE_LEARNING_RATE,
+        betas=BETAS,
+        weight_decay=WEIGHT_DECAY,
+    )
+    optimizer_discriminator = optim.AdamW(
+        discriminator.parameters(),
+        lr=BASE_LEARNING_RATE if not args.finetune else FINETUNE_LEARNING_RATE,
+        betas=BETAS,
+        weight_decay=WEIGHT_DECAY,
+    )
+
+    scheduler_generator = optim.lr_scheduler.ExponentialLR(
+        optimizer_generator, gamma=LEARNING_RATE_DECAY
+    )
+    scheduler_discriminator = optim.lr_scheduler.ExponentialLR(
+        optimizer_discriminator, gamma=LEARNING_RATE_DECAY
+    )
+
+    train_dataset = MelDataset(
+        root=args.dataset_dir,
+        segment_length=SEGMENT_LENGTH,
+        sample_rate=SAMPLE_RATE,
+        hop_length=HOP_LENGTH,
+        train=True,
+        finetune=args.finetune,
+    )
+    train_sampler = DistributedSampler(train_dataset, drop_last=True)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=BATCH_SIZE,
+        sampler=train_sampler,
+        num_workers=8,
+        pin_memory=True,
+        shuffle=False,
+        drop_last=True,
+    )
+
+    validation_dataset = MelDataset(
+        root=args.dataset_dir,
+        segment_length=SEGMENT_LENGTH,
+        sample_rate=SAMPLE_RATE,
+        hop_length=HOP_LENGTH,
+        train=False,
+        finetune=args.finetune,
+    )
+    validation_loader = DataLoader(
+        validation_dataset,
+        batch_size=1,
+        shuffle=False,
+        num_workers=8,
+        pin_memory=True,
+    )
+
+    melspectrogram = LogMelSpectrogram().to(rank)
+
+    # start from scratch by default; a checkpoint (if given) overrides these
+    global_step, best_loss = 0, float("inf")
+    if args.resume is not None:
+        global_step, best_loss = load_checkpoint(
+            load_path=args.resume,
+            generator=generator,
+            discriminator=discriminator,
+            optimizer_generator=optimizer_generator,
+            optimizer_discriminator=optimizer_discriminator,
+            scheduler_generator=scheduler_generator,
+            scheduler_discriminator=scheduler_discriminator,
+            rank=rank,
+            logger=logger,
+            finetune=args.finetune,
+        )
+
+    if args.finetune:
+        global_step, best_loss = 0, float("inf")
+
+    n_epochs = EPOCHS
+    start_epoch = global_step // len(train_loader) + 1
+
+    logger.info("**" * 40)
+    logger.info(f"batch size: {BATCH_SIZE}")
+    logger.info(f"iterations per epoch: {len(train_loader)}")
+    logger.info(f"total number of epochs: {n_epochs}")
+    logger.info(f"started at epoch: {start_epoch}")
+    logger.info("**" * 40 + "\n")
+
+    for epoch in range(start_epoch, n_epochs + 1):
+        train_sampler.set_epoch(epoch)
+
+        generator.train()
+        discriminator.train()
+        average_loss_mel = average_loss_discriminator = average_loss_generator = 0
+        for i, (wavs, mels, tgts) in enumerate(train_loader, 1):
+            wavs, mels, tgts = wavs.to(rank), mels.to(rank), tgts.to(rank)
+
+            # Discriminator
+            optimizer_discriminator.zero_grad()
+
+            wavs_ = generator(mels.squeeze(1))
+            mels_ = melspectrogram(wavs_)
+
+            scores, _ = discriminator(wavs)
+            scores_, _ = discriminator(wavs_.detach())
+
+            loss_discriminator, _, _ = discriminator_loss(scores, scores_)
+
+            loss_discriminator.backward()
+            optimizer_discriminator.step()
+
+            # Generator
+            optimizer_generator.zero_grad()
+
+            scores, features = discriminator(wavs)
+            scores_, features_ = discriminator(wavs_)
+
+            loss_mel = F.l1_loss(mels_, tgts)
+            loss_features = feature_loss(features, features_)
+            loss_generator_adversarial, _ = generator_loss(scores_)
+            loss_generator = 45 * loss_mel + loss_features + loss_generator_adversarial
+
+            loss_generator.backward()
+            optimizer_generator.step()
+
+            global_step += 1
+
+            average_loss_mel += (loss_mel.item() - average_loss_mel) / i
+            average_loss_discriminator += (
+                loss_discriminator.item() - average_loss_discriminator
+            ) / i
+            average_loss_generator += (
+                loss_generator.item() - average_loss_generator
+            ) / i
+
+            if rank == 0:
+                if global_step % LOG_INTERVAL == 0:
+                    writer.add_scalar(
+                        "train/loss_mel",
+                        loss_mel.item(),
+                        global_step,
+                    )
+                    writer.add_scalar(
+                        "train/loss_generator",
+                        loss_generator.item(),
+                        global_step,
+                    )
+                    writer.add_scalar(
+                        "train/loss_discriminator",
+                        loss_discriminator.item(),
+                        global_step,
+                    )
+
+            if global_step % VALIDATION_INTERVAL == 0:
+                generator.eval()
+
+                average_validation_loss = 0
+                for j, (wavs, mels, tgts) in enumerate(validation_loader, 1):
+                    wavs, mels, tgts = wavs.to(rank), mels.to(rank), tgts.to(rank)
+
+                    with torch.no_grad():
+                        wavs_ = generator(mels.squeeze(1))
+                        mels_ = melspectrogram(wavs_)
+
+                    length = min(mels_.size(-1), tgts.size(-1))
+
+                    loss_mel = F.l1_loss(mels_[..., :length], tgts[..., :length])
+
+                    average_validation_loss += (
+                        loss_mel.item() - average_validation_loss
+                    ) / j
+
+                    if rank == 0:
+                        if j <= NUM_GENERATED_EXAMPLES:
+                            writer.add_audio(
+                                f"generated/wav_{j}",
+                                wavs_.squeeze(0),
+                                global_step,
+                                sample_rate=16000,
+                            )
+                            writer.add_figure(
+                                f"generated/mel_{j}",
+                                plot_spectrogram(mels_.squeeze().cpu().numpy()),
+                                global_step,
+                            )
+
+                generator.train()
+                discriminator.train()
+
+                if rank == 0:
+                    writer.add_scalar(
+                        "validation/mel_loss", average_validation_loss, global_step
+                    )
+                    logger.info(
+                        f"valid -- epoch: {epoch}, mel loss: {average_validation_loss:.4f}"
+                    )
+
+                new_best = best_loss > average_validation_loss
+                if new_best or global_step % CHECKPOINT_INTERVAL == 0:
+                    if new_best:
+                        logger.info("-------- new best model found!")
+                        best_loss = average_validation_loss
+
+                    if rank == 0:
+                        save_checkpoint(
+                            checkpoint_dir=args.checkpoint_dir,
+                            generator=generator,
+                            discriminator=discriminator,
+                            optimizer_generator=optimizer_generator,
+                            optimizer_discriminator=optimizer_discriminator,
+                            scheduler_generator=scheduler_generator,
+                            scheduler_discriminator=scheduler_discriminator,
+                            step=global_step,
+                            loss=average_validation_loss,
+                            best=new_best,
+                            logger=logger,
+                        )
+
+        scheduler_discriminator.step()
+        scheduler_generator.step()
+
+        logger.info(
+            f"train -- epoch: {epoch}, mel loss: {average_loss_mel:.4f}, generator loss: {average_loss_generator:.4f}, discriminator loss: {average_loss_discriminator:.4f}"
+        )
+
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Train or finetune HiFi-GAN.")
+    # positional dests use underscores so the values are accessible as args.dataset_dir / args.checkpoint_dir
+    parser.add_argument(
+        "dataset_dir",
+        metavar="dataset-dir",
+        help="path to the preprocessed data directory",
+        type=Path,
+    )
+    parser.add_argument(
+        "checkpoint_dir",
+        metavar="checkpoint-dir",
+        help="path to the checkpoint directory",
+        type=Path,
+    )
+    parser.add_argument(
+        "--resume",
+        help="path to the checkpoint to resume from",
+        type=Path,
+    )
+    parser.add_argument(
+        "--finetune",
+        help="whether to finetune (note that a resume path must be given)",
+        action="store_true",
+    )
+    args = parser.parse_args()
+
+    # display training setup info
+    logger.info(f"PyTorch version: {torch.__version__}")
+    logger.info(f"CUDA version: {torch.version.cuda}")
+    logger.info(f"CUDNN version: {torch.backends.cudnn.version()}")
+    logger.info(f"CUDNN enabled: {torch.backends.cudnn.enabled}")
+    logger.info(f"CUDNN deterministic: {torch.backends.cudnn.deterministic}")
+    logger.info(f"CUDNN benchmark: {torch.backends.cudnn.benchmark}")
+    logger.info(f"# of GPUS: {torch.cuda.device_count()}")
+
+    # clear handlers
+    logger.handlers.clear()
+
+    world_size = torch.cuda.device_count()
+    mp.spawn(
+        train_model,
+        args=(world_size, args),
+        nprocs=world_size,
+        join=True,
+    )
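
Note: `train.py` spawns one process per visible GPU via `mp.spawn`, so a plain `python train.py ...` invocation is all that is needed. For reference, an end-to-end run of the scripts above might look like this (all paths are illustrative):

```
python resample.py path/to/LJSpeech-1.1-raw/wavs path/to/LJSpeech-1.1/wavs
python train.py path/to/LJSpeech-1.1 checkpoints/hifigan
python train.py --finetune --resume checkpoints/hifigan/model-best.pt path/to/LJSpeech-1.1 checkpoints/hifigan-finetune
python generate.py --model-name hifigan_hubert_soft path/to/mels path/to/generated
```

(`model-best.pt` is the file written by `save_checkpoint` whenever the validation mel loss improves.)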