Skip to content

Commit

Permalink
refactor: move duplicate alignment functions into helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Nov 24, 2024
1 parent 90087b5 commit 950a9a2
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 170 deletions.
61 changes: 12 additions & 49 deletions TTS/tts/layers/delightful_tts/acoustic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
from TTS.tts.layers.delightful_tts.encoders import (
PhonemeLevelProsodyEncoder,
UtteranceLevelProsodyEncoder,
get_mask_from_lengths,
)
from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor
from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding
from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProsodyPredictor
from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor
from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
from TTS.tts.layers.generic.aligner import AlignmentNetwork
from TTS.tts.utils.helpers import generate_path, sequence_mask
from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -231,42 +230,6 @@ def _init_d_vector(self):
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
self.embedded_speaker_dim = self.args.d_vector_dim

@staticmethod
def generate_attn(dr, x_mask, y_mask=None):
"""Generate an attention mask from the linear scale durations.
Args:
dr (Tensor): Linear scale durations.
x_mask (Tensor): Mask for the input (character) sequence.
y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations
if None. Defaults to None.
Shapes
- dr: :math:`(B, T_{en})`
- x_mask: :math:`(B, T_{en})`
- y_mask: :math:`(B, T_{de})`
"""
# compute decode mask from the durations
if y_mask is None:
y_lengths = dr.sum(1).long()
y_lengths[y_lengths < 1] = 1
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
return attn

def _expand_encoder_with_durations(
self,
o_en: torch.FloatTensor,
dr: torch.IntTensor,
x_mask: torch.IntTensor,
y_lengths: torch.IntTensor,
):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
attn = self.generate_attn(dr, x_mask, y_mask)
o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en])
return y_mask, o_en_ex, attn.transpose(1, 2)

def _forward_aligner(
self,
x: torch.FloatTensor,
Expand Down Expand Up @@ -340,8 +303,8 @@ def forward(
{"d_vectors": d_vectors, "speaker_ids": speaker_idx}
) # pylint: disable=unused-variable

src_mask = get_mask_from_lengths(src_lens) # [B, T_src]
mel_mask = get_mask_from_lengths(mel_lens) # [B, T_mel]
src_mask = ~sequence_mask(src_lens) # [B, T_src]
mel_mask = ~sequence_mask(mel_lens) # [B, T_mel]

# Token embeddings
token_embeddings = self.src_word_emb(tokens) # [B, T_src, C_hidden]
Expand Down Expand Up @@ -420,8 +383,8 @@ def forward(
encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb
log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask)

mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None]
encoder_outputs_ex, alignments, mel_pred_mask = expand_encoder_outputs(
encoder_outputs, y_lengths=mel_lens, duration=dr, x_mask=~src_mask[:, None]
)

x = self.decoder(
Expand All @@ -435,7 +398,7 @@ def forward(
dr = torch.log(dr + 1)

dr_pred = torch.exp(log_duration_prediction) - 1
alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2']
alignments_dp = generate_attention(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2']

return {
"model_outputs": x,
Expand All @@ -448,7 +411,7 @@ def forward(
"p_prosody_pred": p_prosody_pred,
"p_prosody_ref": p_prosody_ref,
"alignments_dp": alignments_dp,
"alignments": alignments, # [B, T_de, T_en]
"alignments": alignments.transpose(1, 2), # [B, T_de, T_en]
"aligner_soft": aligner_soft,
"aligner_mas": aligner_mas,
"aligner_durations": aligner_durations,
Expand All @@ -469,7 +432,7 @@ def inference(
pitch_transform: Callable = None,
energy_transform: Callable = None,
) -> torch.Tensor:
src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
src_mask = ~sequence_mask(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device))
src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device) # pylint: disable=unused-variable
sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable
{"d_vectors": d_vectors, "speaker_ids": speaker_idx}
Expand Down Expand Up @@ -536,11 +499,11 @@ def inference(
duration_pred = torch.round(duration_pred) # -> [B, T_src]
mel_lens = duration_pred.sum(1) # -> [B,]

_, encoder_outputs_ex, alignments = self._expand_encoder_with_durations(
o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
encoder_outputs_ex, alignments, _ = expand_encoder_outputs(
encoder_outputs, y_lengths=mel_lens, duration=duration_pred.squeeze(1), x_mask=~src_mask[:, None]
)

mel_mask = get_mask_from_lengths(
mel_mask = ~sequence_mask(
torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device)
)

Expand All @@ -557,7 +520,7 @@ def inference(
x = self.to_mel(x)
outputs = {
"model_outputs": x,
"alignments": alignments,
"alignments": alignments.transpose(1, 2),
# "pitch": pitch_emb_pred,
"durations": duration_pred,
"pitch": pitch_pred,
Expand Down
13 changes: 3 additions & 10 deletions TTS/tts/layers/delightful_tts/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@
from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
from TTS.tts.layers.delightful_tts.networks import STL


def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
batch_size = lengths.shape[0]
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
return mask
from TTS.tts.utils.helpers import sequence_mask


def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
Expand Down Expand Up @@ -93,7 +86,7 @@ def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor
outputs --- [N, E//2]
"""

mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
mel_masks = ~sequence_mask(mel_lens).unsqueeze(1)
x = x.masked_fill(mel_masks, 0)
for conv, norm in zip(self.convs, self.norms):
x = conv(x)
Expand All @@ -103,7 +96,7 @@ def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor
for _ in range(2):
mel_lens = stride_lens(mel_lens)

mel_masks = get_mask_from_lengths(mel_lens)
mel_masks = ~sequence_mask(mel_lens)

x = x.masked_fill(mel_masks.unsqueeze(1), 0)
x = x.permute((0, 2, 1))
Expand Down
36 changes: 3 additions & 33 deletions TTS/tts/models/align_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, sequence_mask
from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
Expand Down Expand Up @@ -169,35 +169,6 @@ def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):
dr_mas = torch.sum(attn, -1)
return dr_mas.squeeze(1), log_p

@staticmethod
def generate_attn(dr, x_mask, y_mask=None):
# compute decode mask from the durations
if y_mask is None:
y_lengths = dr.sum(1).long()
y_lengths[y_lengths < 1] = 1
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
return attn

def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
"""Generate attention alignment map from durations and
expand encoder outputs
Examples::
- encoder output: [a,b,c,d]
- durations: [1, 3, 2, 1]
- expanded: [a, b, b, b, c, c, d]
- attention map: [[0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]
"""
attn = self.generate_attn(dr, x_mask, y_mask)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn

def format_durations(self, o_dr_log, x_mask):
o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
o_dr[o_dr < 1] = 1.0
Expand Down Expand Up @@ -243,9 +214,8 @@ def _forward_encoder(self, x, x_lengths, g=None):
return o_en, o_en_dp, x_mask, g

def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths)
# positional encoding
if hasattr(self, "pos_encoder"):
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
Expand Down Expand Up @@ -282,7 +252,7 @@ def forward(
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
attn = self.generate_attn(dr_mas, x_mask, y_mask)
attn = generate_attention(dr_mas, x_mask, y_mask)
elif phase == 1:
# train decoder
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
Expand Down
50 changes: 3 additions & 47 deletions TTS/tts/models/forward_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, sequence_mask
from TTS.tts.utils.helpers import average_over_durations, expand_encoder_outputs, generate_attention, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
Expand Down Expand Up @@ -310,49 +310,6 @@ def init_multispeaker(self, config: Coqpit):
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

@staticmethod
def generate_attn(dr, x_mask, y_mask=None):
"""Generate an attention mask from the durations.
Shapes
- dr: :math:`(B, T_{en})`
- x_mask: :math:`(B, T_{en})`
- y_mask: :math:`(B, T_{de})`
"""
# compute decode mask from the durations
if y_mask is None:
y_lengths = dr.sum(1).long()
y_lengths[y_lengths < 1] = 1
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
return attn

def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
"""Generate attention alignment map from durations and
expand encoder outputs
Shapes:
- en: :math:`(B, D_{en}, T_{en})`
- dr: :math:`(B, T_{en})`
- x_mask: :math:`(B, T_{en})`
- y_mask: :math:`(B, T_{de})`
Examples::
encoder output: [a,b,c,d]
durations: [1, 3, 2, 1]
expanded: [a, b, b, b, c, c, d]
attention map: [[0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]
"""
attn = self.generate_attn(dr, x_mask, y_mask)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn

def format_durations(self, o_dr_log, x_mask):
"""Format predicted durations.
1. Convert to linear scale from log scale
Expand Down Expand Up @@ -443,9 +400,8 @@ def _forward_decoder(
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
"""
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths)
# positional encoding
if hasattr(self, "pos_encoder"):
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
Expand Down Expand Up @@ -624,7 +580,7 @@ def forward(
o_dr_log = self.duration_predictor(o_en, x_mask)
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
# generate attn mask from predicted durations
o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
o_attn = generate_attention(o_dr.squeeze(1), x_mask)
# aligner
o_alignment_dur = None
alignment_soft = None
Expand Down
Loading

0 comments on commit 950a9a2

Please sign in to comment.