diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index a10ddb3..fcbdc89 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -83,14 +83,17 @@ def __init__( super().__init__() self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token' self.combine = nn.Linear(dim * 2, dim) - self.cond_drop_prob + self.cond_drop_prob = cond_drop_prob def forward( self, x: Float['b n d'], text: Int['b n'], + drop_text_cond = None ): - if self.training and random() < self.cond_drop_prob: + drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob) + + if drop_text_cond: return x max_seq_len = x.shape[1] @@ -356,6 +359,25 @@ def __init__( def device(self): return next(self.parameters()).device + def transformer_with_pred_head( + self, + x: Float['b n d'], + times: Float['b'], + mask: Bool['b n'] | None = None, + text: Int['b nt'] | None = None + ): + if exists(text): + x = self.embed_text(x, text) + + attended = self.transformer( + x, + times = times, + mask = mask + ) + + pred = self.to_pred(attended) + return pred + @torch.no_grad() def sample( self, @@ -407,7 +429,7 @@ def fn(t, x): # predict flow - return self.transformer( + return self.transformer_with_pred_head( x, times = t, mask = mask @@ -425,7 +447,7 @@ def forward( self, inp: Float['b n d'], # is mel in paper *, - text: Int['b n'] | None = None, + text: Int['b nt'] | None = None, times: Int['b'] | None = None, lens: Int['b'] | None = None, ): @@ -436,11 +458,6 @@ def forward( mask = lens_to_mask(lens, length = seq_len) - # text - - if exists(text): - inp = self.embed_text(inp, text) - # get a random span to mask out for training conditionally random_span_frac_indices = inp.new_zeros(2, batch).uniform_(0, 1) @@ -485,13 +502,7 @@ def forward( # transformer and prediction head - attended = self.transformer( - w, - times = times, - mask = mask - ) - - pred = self.to_pred(attended) + pred = self.transformer_with_pred_head(w, times = times, text = text) # flow matching loss diff --git a/pyproject.toml b/pyproject.toml index 6bb24e2..4423c62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "e2-tts-pytorch" -version = "0.0.10" +version = "0.0.11" description = "E2-TTS in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }