diff --git a/TTS/tts/layers/delightful_tts/acoustic_model.py b/TTS/tts/layers/delightful_tts/acoustic_model.py index 3c0e3a3a76..981d6cdb1f 100644 --- a/TTS/tts/layers/delightful_tts/acoustic_model.py +++ b/TTS/tts/layers/delightful_tts/acoustic_model.py @@ -12,7 +12,6 @@ from TTS.tts.layers.delightful_tts.encoders import ( PhonemeLevelProsodyEncoder, UtteranceLevelProsodyEncoder, - get_mask_from_lengths, ) from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding @@ -20,7 +19,7 @@ from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor from TTS.tts.layers.generic.aligner import AlignmentNetwork -from TTS.tts.utils.helpers import generate_path, sequence_mask +from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask logger = logging.getLogger(__name__) @@ -231,42 +230,6 @@ def _init_d_vector(self): raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") self.embedded_speaker_dim = self.args.d_vector_dim - @staticmethod - def generate_attn(dr, x_mask, y_mask=None): - """Generate an attention mask from the linear scale durations. - - Args: - dr (Tensor): Linear scale durations. - x_mask (Tensor): Mask for the input (character) sequence. - y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations - if None. Defaults to None. - - Shapes - - dr: :math:`(B, T_{en})` - - x_mask: :math:`(B, T_{en})` - - y_mask: :math:`(B, T_{de})` - """ - # compute decode mask from the durations - if y_mask is None: - y_lengths = dr.sum(1).long() - y_lengths[y_lengths < 1] = 1 - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) - return attn - - def _expand_encoder_with_durations( - self, - o_en: torch.FloatTensor, - dr: torch.IntTensor, - x_mask: torch.IntTensor, - y_lengths: torch.IntTensor, - ): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) - attn = self.generate_attn(dr, x_mask, y_mask) - o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en]) - return y_mask, o_en_ex, attn.transpose(1, 2) - def _forward_aligner( self, x: torch.FloatTensor, @@ -340,8 +303,8 @@ def forward( {"d_vectors": d_vectors, "speaker_ids": speaker_idx} ) # pylint: disable=unused-variable - src_mask = get_mask_from_lengths(src_lens) # [B, T_src] - mel_mask = get_mask_from_lengths(mel_lens) # [B, T_mel] + src_mask = ~sequence_mask(src_lens) # [B, T_src] + mel_mask = ~sequence_mask(mel_lens) # [B, T_mel] # Token embeddings token_embeddings = self.src_word_emb(tokens) # [B, T_src, C_hidden] @@ -420,8 +383,8 @@ def forward( encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask) - mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( - o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None] + encoder_outputs_ex, alignments, mel_pred_mask = expand_encoder_outputs( + encoder_outputs, y_lengths=mel_lens, duration=dr, x_mask=~src_mask[:, None] ) x = self.decoder( @@ -435,7 +398,7 @@ def forward( dr = torch.log(dr + 1) dr_pred = torch.exp(log_duration_prediction) - 1 - alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2'] + alignments_dp = generate_attention(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2'] return { "model_outputs": x, @@ -448,7 +411,7 @@ def forward( "p_prosody_pred": p_prosody_pred, "p_prosody_ref": p_prosody_ref, "alignments_dp": alignments_dp, - "alignments": alignments, # [B, T_de, T_en] + "alignments": alignments.transpose(1, 2), # [B, T_de, T_en] "aligner_soft": aligner_soft, "aligner_mas": aligner_mas, "aligner_durations": aligner_durations, @@ -469,7 +432,7 @@ def inference( pitch_transform: Callable = None, energy_transform: Callable = None, ) -> torch.Tensor: - src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device)) + src_mask = ~sequence_mask(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device)) src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device) # pylint: disable=unused-variable sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable {"d_vectors": d_vectors, "speaker_ids": speaker_idx} @@ -536,11 +499,11 @@ def inference( duration_pred = torch.round(duration_pred) # -> [B, T_src] mel_lens = duration_pred.sum(1) # -> [B,] - _, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( - o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None] + encoder_outputs_ex, alignments, _ = expand_encoder_outputs( + encoder_outputs, y_lengths=mel_lens, duration=duration_pred.squeeze(1), x_mask=~src_mask[:, None] ) - mel_mask = get_mask_from_lengths( + mel_mask = ~sequence_mask( torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device) ) @@ -557,7 +520,7 @@ def inference( x = self.to_mel(x) outputs = { "model_outputs": x, - "alignments": alignments, + "alignments": alignments.transpose(1, 2), # "pitch": pitch_emb_pred, "durations": duration_pred, "pitch": pitch_pred, diff --git a/TTS/tts/layers/delightful_tts/encoders.py b/TTS/tts/layers/delightful_tts/encoders.py index 0878f0677a..bd0c319dc1 100644 --- a/TTS/tts/layers/delightful_tts/encoders.py +++ b/TTS/tts/layers/delightful_tts/encoders.py @@ -7,14 +7,7 @@ from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d from TTS.tts.layers.delightful_tts.networks import STL - - -def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: - batch_size = lengths.shape[0] - max_len = torch.max(lengths).item() - ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) - mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) - return mask +from TTS.tts.utils.helpers import sequence_mask def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: @@ -93,7 +86,7 @@ def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor outputs --- [N, E//2] """ - mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1) + mel_masks = ~sequence_mask(mel_lens).unsqueeze(1) x = x.masked_fill(mel_masks, 0) for conv, norm in zip(self.convs, self.norms): x = conv(x) @@ -103,7 +96,7 @@ def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor for _ in range(2): mel_lens = stride_lens(mel_lens) - mel_masks = get_mask_from_lengths(mel_lens) + mel_masks = ~sequence_mask(mel_lens) x = x.masked_fill(mel_masks.unsqueeze(1), 0) x = x.permute((0, 2, 1)) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 1c3d57582e..28a52bc558 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -13,7 +13,7 @@ from TTS.tts.layers.feed_forward.encoder import Encoder from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import generate_path, sequence_mask +from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -169,35 +169,6 @@ def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask): dr_mas = torch.sum(attn, -1) return dr_mas.squeeze(1), log_p - @staticmethod - def generate_attn(dr, x_mask, y_mask=None): - # compute decode mask from the durations - if y_mask is None: - y_lengths = dr.sum(1).long() - y_lengths[y_lengths < 1] = 1 - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) - return attn - - def expand_encoder_outputs(self, en, dr, x_mask, y_mask): - """Generate attention alignment map from durations and - expand encoder outputs - - Examples:: - - encoder output: [a,b,c,d] - - durations: [1, 3, 2, 1] - - - expanded: [a, b, b, b, c, c, d] - - attention map: [[0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 1, 0, 0, 0], - [1, 0, 0, 0, 0, 0, 0]] - """ - attn = self.generate_attn(dr, x_mask, y_mask) - o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) - return o_en_ex, attn - def format_durations(self, o_dr_log, x_mask): o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale o_dr[o_dr < 1] = 1.0 @@ -243,9 +214,8 @@ def _forward_encoder(self, x, x_lengths, g=None): return o_en, o_en_dp, x_mask, g def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) # expand o_en with durations - o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths) # positional encoding if hasattr(self, "pos_encoder"): o_en_ex = self.pos_encoder(o_en_ex, y_mask) @@ -282,7 +252,7 @@ def forward( o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) - attn = self.generate_attn(dr_mas, x_mask, y_mask) + attn = generate_attention(dr_mas, x_mask, y_mask) elif phase == 1: # train decoder o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index d449e580da..d09e3ea91b 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -14,7 +14,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import average_over_durations, generate_path, sequence_mask +from TTS.tts.utils.helpers import average_over_durations, expand_encoder_outputs, generate_attention, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram @@ -310,49 +310,6 @@ def init_multispeaker(self, config: Coqpit): self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) - @staticmethod - def generate_attn(dr, x_mask, y_mask=None): - """Generate an attention mask from the durations. - - Shapes - - dr: :math:`(B, T_{en})` - - x_mask: :math:`(B, T_{en})` - - y_mask: :math:`(B, T_{de})` - """ - # compute decode mask from the durations - if y_mask is None: - y_lengths = dr.sum(1).long() - y_lengths[y_lengths < 1] = 1 - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) - return attn - - def expand_encoder_outputs(self, en, dr, x_mask, y_mask): - """Generate attention alignment map from durations and - expand encoder outputs - - Shapes: - - en: :math:`(B, D_{en}, T_{en})` - - dr: :math:`(B, T_{en})` - - x_mask: :math:`(B, T_{en})` - - y_mask: :math:`(B, T_{de})` - - Examples:: - - encoder output: [a,b,c,d] - durations: [1, 3, 2, 1] - - expanded: [a, b, b, b, c, c, d] - attention map: [[0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 1, 0, 0, 0], - [1, 0, 0, 0, 0, 0, 0]] - """ - attn = self.generate_attn(dr, x_mask, y_mask) - o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) - return o_en_ex, attn - def format_durations(self, o_dr_log, x_mask): """Format predicted durations. 1. Convert to linear scale from log scale @@ -443,9 +400,8 @@ def _forward_decoder( Returns: Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. """ - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) # expand o_en with durations - o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths) # positional encoding if hasattr(self, "pos_encoder"): o_en_ex = self.pos_encoder(o_en_ex, y_mask) @@ -624,7 +580,7 @@ def forward( o_dr_log = self.duration_predictor(o_en, x_mask) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) # generate attn mask from predicted durations - o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + o_attn = generate_attention(o_dr.squeeze(1), x_mask) # aligner o_alignment_dur = None alignment_soft = None diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index d1722501f7..ff10f751f2 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import torch from scipy.stats import betabinom @@ -33,7 +35,7 @@ def inverse_transform(self, X): # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 -def sequence_mask(sequence_length, max_len=None): +def sequence_mask(sequence_length: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor: """Create a sequence mask for filtering padding in a sequence tensor. Args: @@ -44,7 +46,7 @@ def sequence_mask(sequence_length, max_len=None): - mask: :math:`[B, T_max]` """ if max_len is None: - max_len = sequence_length.max() + max_len = int(sequence_length.max()) seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device) # B x T_max return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) @@ -143,22 +145,75 @@ def convert_pad_shape(pad_shape: list[list]) -> list: return [item for sublist in l for item in sublist] -def generate_path(duration, mask): - """ +def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate alignment path based on the given segment durations. + Shapes: - duration: :math:`[B, T_en]` - mask: :math:'[B, T_en, T_de]` - path: :math:`[B, T_en, T_de]` """ b, t_x, t_y = mask.shape - cum_duration = torch.cumsum(duration, 1) + cum_duration = torch.cumsum(duration, dim=1) cum_duration_flat = cum_duration.view(b * t_x) path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) path = path.view(b, t_x, t_y) path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path * mask - return path + return path * mask + + +def generate_attention( + duration: torch.Tensor, x_mask: torch.Tensor, y_mask: Optional[torch.Tensor] = None +) -> torch.Tensor: + """Generate an attention map from the linear scale durations. + + Args: + duration (Tensor): Linear scale durations. + x_mask (Tensor): Mask for the input (character) sequence. + y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations + if None. Defaults to None. + + Shapes + - duration: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = duration.sum(dim=1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = sequence_mask(y_lengths).unsqueeze(1).to(duration.dtype) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + return generate_path(duration, attn_mask.squeeze(1)).to(duration.dtype) + + +def expand_encoder_outputs( + x: torch.Tensor, duration: torch.Tensor, x_mask: torch.Tensor, y_lengths: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate attention alignment map from durations and expand encoder outputs. + + Shapes: + - x: Encoder output :math:`(B, D_{en}, T_{en})` + - duration: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_lengths: :math:`(B)` + + Examples:: + + encoder output: [a,b,c,d] + durations: [1, 3, 2, 1] + + expanded: [a, b, b, b, c, c, d] + attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + y_mask = sequence_mask(y_lengths).unsqueeze(1).to(x.dtype) + attn = generate_attention(duration, x_mask, y_mask) + x_expanded = torch.einsum("kmn, kjm -> kjn", [attn.float(), x]) + return x_expanded, attn, y_mask def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0): diff --git a/tests/aux_tests/test_helpers.py b/tests/aux_tests/test_helpers.py index d07efa3620..6781cbc5d4 100644 --- a/tests/aux_tests/test_helpers.py +++ b/tests/aux_tests/test_helpers.py @@ -1,6 +1,14 @@ import torch as T -from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask +from TTS.tts.utils.helpers import ( + average_over_durations, + expand_encoder_outputs, + generate_attention, + generate_path, + rand_segments, + segment, + sequence_mask, +) def test_average_over_durations(): # pylint: disable=no-self-use @@ -86,3 +94,24 @@ def test_generate_path(): assert all(path[b, t, :current_idx] == 0.0) assert all(path[b, t, current_idx + durations[b, t].item() :] == 0.0) current_idx += durations[b, t].item() + + assert T.all(path == generate_attention(durations, x_mask, y_mask)) + assert T.all(path == generate_attention(durations, x_mask)) + + +def test_expand_encoder_outputs(): + inputs = T.rand(2, 5, 57) + durations = T.randint(1, 4, (2, 57)) + + x_mask = T.ones(2, 1, 57) + y_lengths = T.ones(2) * durations.sum(1).max() + + expanded, _, _ = expand_encoder_outputs(inputs, durations, x_mask, y_lengths) + + for b in range(durations.shape[0]): + index = 0 + for idx, dur in enumerate(durations[b]): + idx_expanded = expanded[b, :, index : index + dur.item()] + diff = (idx_expanded - inputs[b, :, idx].repeat(int(dur)).view(idx_expanded.shape)).sum() + assert abs(diff) < 1e-6, diff + index += dur diff --git a/tests/tts_tests2/test_forward_tts.py b/tests/tts_tests2/test_forward_tts.py index cec0f211c8..13a2c270af 100644 --- a/tests/tts_tests2/test_forward_tts.py +++ b/tests/tts_tests2/test_forward_tts.py @@ -6,29 +6,7 @@ # pylint: disable=unused-variable -def expand_encoder_outputs_test(): - model = ForwardTTS(ForwardTTSArgs(num_chars=10)) - - inputs = T.rand(2, 5, 57) - durations = T.randint(1, 4, (2, 57)) - - x_mask = T.ones(2, 1, 57) - y_mask = T.ones(2, 1, durations.sum(1).max()) - - expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask) - - for b in range(durations.shape[0]): - index = 0 - for idx, dur in enumerate(durations[b]): - diff = ( - expanded[b, :, index : index + dur.item()] - - inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape) - ).sum() - assert abs(diff) < 1e-6, diff - index += dur - - -def model_input_output_test(): +def test_model_input_output(): """Assert the output shapes of the model in different modes""" # VANILLA MODEL