Skip to content

Commit

Permalink
add back interpolated text for @manmay-nakhashi
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 11, 2024
1 parent d07a255 commit 85a8c0b
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 5 deletions.
95 changes: 91 additions & 4 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ def pad_to_length(

return t[..., :length]

def interpolate_1d(
    x: Tensor,
    length: int,
    mode = 'bilinear'
):
    """Resample a (seq, dim) tensor along its sequence axis to `length` steps.

    F.interpolate wants a 4d (batch, channel, h, w) input, so the sequence is
    staged as a 1 x dim x seq x 1 image, resized to `length` rows, and unstaged.
    """
    as_image = x.t().unsqueeze(0).unsqueeze(-1)           # (n, d) -> (1, d, n, 1)
    resized = F.interpolate(as_image, (length, 1), mode = mode)
    return resized.squeeze(-1).squeeze(0).t()             # (1, d, length, 1) -> (length, d)

# to mel spec

class MelSpec(Module):
Expand Down Expand Up @@ -287,6 +296,7 @@ def forward(
self,
text: Int['b nt'],
max_seq_len: int,
**kwargs
) -> Float['b n d']:

text = text + 1 # shift all other token ids up by 1 and use 0 as filler token
Expand All @@ -296,14 +306,85 @@ def forward(

return self.embed(text)

class InterpolatedCharacterEmbed(Module):
    """Character embedding stretched to the audio frame length.

    Rather than right-padding character tokens out to the audio length, each
    sample's character embeddings are interpolated along the sequence axis to
    that sample's audio length (taken from `mask` when given), and an MLP over
    the fractional absolute character position is added as an implicit
    positional encoding.
    """

    def __init__(
        self,
        dim,
        num_embeds = 256,
    ):
        super().__init__()
        self.dim = dim
        self.embed = nn.Embedding(num_embeds, dim)

        # maps a scalar absolute position to a dim-sized implicit position encoding
        self.abs_pos_mlp = Sequential(
            Rearrange('... -> ... 1'),
            Linear(1, dim),
            nn.SiLU(),
            Linear(dim, dim)
        )

    def forward(
        self,
        text: Int['b nt'],
        max_seq_len: int,
        mask: Bool['b n'] | None = None
    ) -> Float['b n d']:

        device = text.device

        # remember whether a real mask was provided - the tuple below is only a
        # zip_longest filler and must not later be mistaken for an actual mask
        has_mask = exists(mask)
        batch_masks = mask if has_mask else (None,)

        interp_embeds = []
        interp_abs_positions = []

        for one_text, one_mask in zip_longest(text, batch_masks):

            # text is assumed padded with negative token ids; keep only real tokens
            valid_text = one_text >= 0
            one_text = one_text[valid_text]
            one_text_embed = self.embed(one_text)

            # save the absolute positions

            text_seq_len = one_text.shape[0]

            # determine audio sequence length from mask

            audio_seq_len = max_seq_len
            if exists(one_mask):
                audio_seq_len = one_mask.sum().long().item()

            # interpolate text embedding to audio embedding length

            interp_text_embed = interpolate_1d(one_text_embed, audio_seq_len)
            interp_abs_pos = torch.linspace(0, text_seq_len, audio_seq_len, device = device)

            interp_embeds.append(interp_text_embed)
            interp_abs_positions.append(interp_abs_pos)

        interp_embeds = pad_sequence(interp_embeds)
        interp_abs_positions = pad_sequence(interp_abs_positions)

        # right-pad every sample out to the full audio length

        interp_embeds = F.pad(interp_embeds, (0, 0, 0, max_seq_len - interp_embeds.shape[-2]))
        interp_abs_positions = pad_to_length(interp_abs_positions, max_seq_len)

        # pass interp absolute positions through mlp for implicit positions

        interp_embeds = interp_embeds + self.abs_pos_mlp(interp_abs_positions)

        # fix: gate on the original mask, not the zip filler - previously a None
        # mask became the tuple (None,), which passed exists() and crashed einx.where
        if has_mask:
            interp_embeds = einx.where('b n, b n d, -> b n d', mask, interp_embeds, 0.)

        return interp_embeds

# text audio cross conditioning in multistream setup

class TextAudioCrossCondition(Module):
def __init__(
self,
dim,
dim_text,
cond_audio_to_text = True
cond_audio_to_text = True,
):
super().__init__()
self.text_to_audio = nn.Linear(dim_text + dim, dim, bias = False)
Expand Down Expand Up @@ -612,7 +693,10 @@ def __init__(
self.num_channels = default(num_channels, self.mel_spec.n_mel_channels)

self.transformer = transformer

dim = transformer.dim
dim_text = transformer.dim_text

self.dim = dim

self.proj_in = Linear(self.num_channels, self.dim)
Expand All @@ -630,7 +714,7 @@ def __init__(
else:
raise ValueError(f'unknown tokenizer string {tokenizer}')

self.embed_text = CharacterEmbed(transformer.dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)

# to prediction

Expand Down Expand Up @@ -726,6 +810,7 @@ def __init__(
mel_spec_kwargs: dict = dict(),
frac_lengths_mask: tuple[float, float] = (0.7, 1.),
concat_cond = False,
interpolated_text = False,
text_num_embeds: int | None = None,
tokenizer: (
Literal['char_utf8', 'phoneme_en'] |
Expand Down Expand Up @@ -801,7 +886,9 @@ def __init__(

# text embedding

self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)
text_embed_klass = CharacterEmbed if not interpolated_text else InterpolatedCharacterEmbed

self.embed_text = text_embed_klass(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)

# default vocos for mel -> audio

Expand Down Expand Up @@ -840,7 +927,7 @@ def transformer_with_pred_head(

text_embed = None
if exists(text) and not drop_text_cond:
text_embed = self.embed_text(text, seq_len)
text_embed = self.embed_text(text, seq_len, mask = mask)

# attend

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "1.0.4"
version = "1.0.5"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 85a8c0b

Please sign in to comment.