
Commit

refactor: move more audio processing into torch_transforms
eginhard committed Nov 23, 2024
1 parent 5edf124 commit a666904
Showing 5 changed files with 77 additions and 198 deletions.
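For orientation, the shared helpers now live in TTS/utils/audio/torch_transforms.py, and callers import wav_to_spec and spec_to_mel from there instead of keeping local copies. A minimal sketch of chaining the two, assuming illustrative STFT and mel parameters (none of the values below are taken from a repository config):

import torch

from TTS.utils.audio.torch_transforms import spec_to_mel, wav_to_spec

# Illustrative parameters; real models read these from their audio config.
n_fft, hop_length, win_length = 1024, 256, 1024
sample_rate, num_mels, fmin, fmax = 22050, 80, 0, 8000

# [B, 1, T] batch of waveforms scaled to roughly [-1, 1].
y = torch.rand(4, 1, 22050) * 2 - 1

# Linear magnitude spectrogram: [B, n_fft // 2 + 1, frames].
spec = wav_to_spec(y, n_fft, hop_length, win_length, center=False)

# Mel spectrogram in dB (via amp_to_db): [B, num_mels, frames].
mel = spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax)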
78 changes: 2 additions & 76 deletions TTS/tts/models/delightful_tts.py
@@ -38,7 +38,7 @@
from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
from TTS.utils.audio.processor import AudioProcessor
from TTS.utils.audio.torch_transforms import amp_to_db
from TTS.utils.audio.torch_transforms import amp_to_db, wav_to_spec
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
@@ -50,62 +50,8 @@
mel_basis = {}


def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
y = y.squeeze(1)

if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))

global hann_window # pylint: disable=global-statement
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_length) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)

spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

return spec


def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
"""
Args Shapes:
- y : :math:`[B, 1, T]`
Return Shapes:
- spec : :math:`[B,C,T]`
"""
spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec


def wav_to_energy(y, n_fft, hop_length, win_length, center=False):
spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
return torch.norm(spec, dim=1, keepdim=True)


@@ -114,26 +60,6 @@ def name_mel_basis(spec, n_fft, fmax):
return n_fft_len


def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
"""
Args Shapes:
- spec : :math:`[B,C,T]`
Return Shapes:
- mel : :math:`[B,C,T]`
"""
global mel_basis # pylint: disable=global-statement
mel_basis_key = name_mel_basis(spec, n_fft, fmax)
# pylint: disable=too-many-function-args
if mel_basis_key not in mel_basis:
# pylint: disable=missing-kwoa
mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax)
mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel = torch.matmul(mel_basis[mel_basis_key], spec)
mel = amp_to_db(mel)
return mel


def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
"""
Args Shapes:
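In delightful_tts.py only wav_to_energy keeps a local definition; it now delegates to the shared wav_to_spec and reduces over the frequency axis. A small usage sketch with illustrative parameters (note that the import pulls in the full model module, so treat this as a sketch rather than a lightweight dependency):

import torch

from TTS.tts.models.delightful_tts import wav_to_energy

# [B, 1, T] batch of waveforms scaled to roughly [-1, 1].
y = torch.rand(2, 1, 16000) * 2 - 1

# Per-frame energy, i.e. the L2 norm over the frequency bins of the shared
# wav_to_spec() output: [B, 1, frames].
energy = wav_to_energy(y, n_fft=1024, hop_length=256, win_length=1024, center=False)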
72 changes: 1 addition & 71 deletions TTS/tts/models/vits.py
@@ -35,7 +35,7 @@
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.audio.torch_transforms import amp_to_db
from TTS.utils.audio.torch_transforms import amp_to_db, spec_to_mel, wav_to_spec
from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
@@ -46,10 +46,6 @@
# IO / Feature extraction
##############################

# pylint: disable=global-statement
hann_window = {}
mel_basis = {}


@torch.no_grad()
def weights_reset(m: nn.Module):
@@ -79,72 +75,6 @@ def load_audio(file_path):
return x, sr


def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
"""
Args Shapes:
- y : :math:`[B, 1, T]`
Return Shapes:
- spec : :math:`[B,C,T]`
"""
y = y.squeeze(1)

if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))

global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_length) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)

spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec


def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
"""
Args Shapes:
- spec : :math:`[B,C,T]`
Return Shapes:
- mel : :math:`[B,C,T]`
"""
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
mel = amp_to_db(mel)
return mel


def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
"""
Args Shapes:
73 changes: 73 additions & 0 deletions TTS/utils/audio/torch_transforms.py
@@ -1,7 +1,15 @@
import logging

import librosa
import torch
from torch import nn

logger = logging.getLogger(__name__)


hann_window = {}
mel_basis = {}


def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
"""Spectral normalization / dynamic range compression."""
@@ -13,6 +21,71 @@ def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
return torch.exp(x) / spec_gain


def wav_to_spec(y: torch.Tensor, n_fft: int, hop_length: int, win_length: int, *, center: bool = False) -> torch.Tensor:
"""
Args Shapes:
- y : :math:`[B, 1, T]`
Return Shapes:
- spec : :math:`[B,C,T]`
"""
y = y.squeeze(1)

if torch.min(y) < -1.0:
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("max value is %.3f", torch.max(y))

global hann_window
wnsize_dtype_device = f"{win_length}_{y.dtype}_{y.device}"
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
y = y.squeeze(1)

spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)


def spec_to_mel(
spec: torch.Tensor, n_fft: int, num_mels: int, sample_rate: int, fmin: float, fmax: float
) -> torch.Tensor:
"""
Args Shapes:
- spec : :math:`[B,C,T]`
Return Shapes:
- mel : :math:`[B,C,T]`
"""
global mel_basis
fmax_dtype_device = f"{fmax}_{spec.dtype}_{spec.device}"
if fmax_dtype_device not in mel_basis:
# TODO: switch librosa to torchaudio
mel = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
return amp_to_db(mel)


class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""Some of the audio processing funtions using Torch for faster batch processing.
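Both new helpers cache their expensive constants (the Hann window and the librosa mel filterbank) in the module-level hann_window and mel_basis dictionaries, keyed by window size or fmax plus dtype and device, so repeated calls with the same configuration reuse them. A small sketch of that behaviour, with illustrative parameters:

import torch

from TTS.utils.audio import torch_transforms as tt

y = torch.rand(2, 1, 8000) * 2 - 1  # [B, 1, T]

# The first call builds a Hann window for this (win_length, dtype, device) key...
_ = tt.wav_to_spec(y, n_fft=512, hop_length=128, win_length=512, center=False)
print(list(tt.hann_window))  # e.g. ['512_torch.float32_cpu']

# ...and later calls with the same key reuse the cached window.
_ = tt.wav_to_spec(y, n_fft=512, hop_length=128, win_length=512, center=False)
print(len(tt.hann_window))  # still 1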
48 changes: 0 additions & 48 deletions TTS/vc/modules/freevc/mel_processing.py
@@ -14,54 +14,6 @@
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.0:
logger.info("Min value is: %.3f", torch.min(y))
if torch.max(y) > 1.0:
logger.info("Max value is: %.3f", torch.max(y))

global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)

spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = amp_to_db(spec)
return spec


def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
logger.info("Min value is: %.3f", torch.min(y))
4 changes: 1 addition & 3 deletions tests/tts_tests/test_vits.py
@@ -14,12 +14,10 @@
VitsArgs,
VitsAudioConfig,
load_audio,
spec_to_mel,
wav_to_mel,
wav_to_spec,
)
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp
from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp, spec_to_mel, wav_to_spec

LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
