diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index e5351d2..4cb3554 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -189,6 +189,15 @@ def pad_to_length(
 
     return t[..., :length]
 
+def interpolate_1d(
+    x: Tensor,
+    length: int,
+    mode = 'bilinear'
+):
+    x = rearrange(x, 'n d -> 1 d n 1')
+    x = F.interpolate(x, (length, 1), mode = mode)
+    return rearrange(x, '1 d n 1 -> n d')
+
 # to mel spec
 
 class MelSpec(Module):
@@ -287,6 +296,7 @@ def forward(
         self,
         text: Int['b nt'],
         max_seq_len: int,
+        **kwargs
     ) -> Float['b n d']:
 
         text = text + 1 # shift all other token ids up by 1 and use 0 as filler token
@@ -296,6 +306,77 @@ def forward(
 
         return self.embed(text)
 
+class InterpolatedCharacterEmbed(Module):
+    def __init__(
+        self,
+        dim,
+        num_embeds = 256,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.embed = nn.Embedding(num_embeds, dim)
+
+        self.abs_pos_mlp = Sequential(
+            Rearrange('... -> ... 1'),
+            Linear(1, dim),
+            nn.SiLU(),
+            Linear(dim, dim)
+        )
+
+    def forward(
+        self,
+        text: Int['b nt'],
+        max_seq_len: int,
+        mask: Bool['b n'] | None = None
+    ) -> Float['b n d']:
+
+        device = text.device
+        seq = torch.arange(max_seq_len, device = device)
+
+        mask = default(mask, (None,))
+
+        interp_embeds = []
+        interp_abs_positions = []
+
+        for one_text, one_mask in zip_longest(text, mask):
+
+            valid_text = one_text >= 0
+            one_text = one_text[valid_text]
+            one_text_embed = self.embed(one_text)
+
+            # save the absolute positions
+
+            text_seq_len = one_text.shape[0]
+
+            # determine audio sequence length from mask
+
+            audio_seq_len = max_seq_len
+            if exists(one_mask):
+                audio_seq_len = one_mask.sum().long().item()
+
+            # interpolate text embedding to audio embedding length
+
+            interp_text_embed = interpolate_1d(one_text_embed, audio_seq_len)
+            interp_abs_pos = torch.linspace(0, text_seq_len, audio_seq_len, device = device)
+
+            interp_embeds.append(interp_text_embed)
+            interp_abs_positions.append(interp_abs_pos)
+
+        interp_embeds = pad_sequence(interp_embeds)
+        interp_abs_positions = pad_sequence(interp_abs_positions)
+
+        interp_embeds = F.pad(interp_embeds, (0, 0, 0, max_seq_len - interp_embeds.shape[-2]))
+        interp_abs_positions = pad_to_length(interp_abs_positions, max_seq_len)
+
+        # pass interp absolute positions through mlp for implicit positions
+
+        interp_embeds = interp_embeds + self.abs_pos_mlp(interp_abs_positions)
+
+        if exists(mask):
+            interp_embeds = einx.where('b n, b n d, -> b n d', mask, interp_embeds, 0.)
+
+        return interp_embeds
+
 # text audio cross conditioning in multistream setup
 
 class TextAudioCrossCondition(Module):
@@ -303,7 +384,7 @@ def __init__(
         self,
         dim,
         dim_text,
-        cond_audio_to_text = True
+        cond_audio_to_text = True,
     ):
         super().__init__()
         self.text_to_audio = nn.Linear(dim_text + dim, dim, bias = False)
@@ -612,7 +693,10 @@ def __init__(
 
         self.num_channels = default(num_channels, self.mel_spec.n_mel_channels)
 
         self.transformer = transformer
+        dim = transformer.dim
+        dim_text = transformer.dim_text
+
         self.dim = dim
 
         self.proj_in = Linear(self.num_channels, self.dim)
@@ -630,7 +714,7 @@ def __init__(
         else:
             raise ValueError(f'unknown tokenizer string {tokenizer}')
 
-        self.embed_text = CharacterEmbed(transformer.dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
+        self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
 
         # to prediction
 
@@ -726,6 +810,7 @@ def __init__(
         mel_spec_kwargs: dict = dict(),
         frac_lengths_mask: tuple[float, float] = (0.7, 1.),
         concat_cond = False,
+        interpolated_text = False,
         text_num_embeds: int | None = None,
         tokenizer: (
             Literal['char_utf8', 'phoneme_en'] |
@@ -801,7 +886,9 @@ def __init__(
 
         # text embedding
 
-        self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
+        text_embed_klass = CharacterEmbed if not interpolated_text else InterpolatedCharacterEmbed
+
+        self.embed_text = text_embed_klass(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
 
         # default vocos for mel -> audio
 
@@ -840,7 +927,7 @@ def transformer_with_pred_head(
         text_embed = None
 
         if exists(text) and not drop_text_cond:
-            text_embed = self.embed_text(text, seq_len)
+            text_embed = self.embed_text(text, seq_len, mask = mask)
 
         # attend
 
diff --git a/pyproject.toml b/pyproject.toml
index 5d7d752..d2db022 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "1.0.4"
+version = "1.0.5"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
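
For reviewers, a minimal sketch of how the new `InterpolatedCharacterEmbed` path could be exercised on its own. The token ids, the `-1` padding value, the mask shapes, and `dim = 512` below are illustrative only, and the expected output shape assumes the repo's `pad_sequence` helper pads batch-first, which the `einx.where('b n, b n d, -> b n d', ...)` pattern implies.

```python
import torch
from e2_tts_pytorch.e2_tts import InterpolatedCharacterEmbed

embed = InterpolatedCharacterEmbed(dim = 512, num_embeds = 256)

# two character sequences, padded with -1 (dropped by `valid_text = one_text >= 0`)
text = torch.tensor([
    [ 5, 12, 99,  3, -1, -1],
    [ 7, 42, -1, -1, -1, -1]
])

# audio frame mask: 10 valid frames for the first sample, 6 for the second
mask = torch.zeros(2, 12, dtype = torch.bool)
mask[0, :10] = True
mask[1, :6] = True

# each sample's text embeddings are interpolated to its own unpadded audio length
text_embed = embed(text, max_seq_len = 12, mask = mask)
print(text_embed.shape)  # expected: torch.Size([2, 12, 512])
```

At the model level, the same path is selected by constructing `E2TTS(..., interpolated_text = True)`, which swaps `CharacterEmbed` for `InterpolatedCharacterEmbed` via `text_embed_klass`; `transformer_with_pred_head` now also forwards the audio `mask` into `embed_text`, while `CharacterEmbed.forward` accepts `**kwargs` so both classes share the same call signature.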