move transformer with prediction head to own function for classifier free guidance
lucidrains committed Jul 13, 2024
1 parent 5c7d202 commit c7d4da4
Showing 2 changed files with 28 additions and 17 deletions.
43 changes: 27 additions & 16 deletions e2_tts_pytorch/e2_tts.py
@@ -83,14 +83,17 @@ def __init__(
         super().__init__()
         self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token'
         self.combine = nn.Linear(dim * 2, dim)
-        self.cond_drop_prob
+        self.cond_drop_prob = cond_drop_prob

     def forward(
         self,
         x: Float['b n d'],
         text: Int['b n'],
+        drop_text_cond = None
     ):
-        if self.training and random() < self.cond_drop_prob:
+        drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob)
+
+        if drop_text_cond:
             return x

         max_seq_len = x.shape[1]
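The new drop_text_cond argument lets a caller override the training-time coin flip, which classifier free guidance needs in order to force an unconditional pass at sampling time. A minimal sketch of the semantics, assuming the usual exists/default helpers defined elsewhere in this file:

    def exists(v):
        return v is not None

    def default(v, d):
        # use the explicit value when given, otherwise fall back
        return v if exists(v) else d

    # drop_text_cond = True   -> always drop text conditioning (unconditional pass)
    # drop_text_cond = False  -> always keep text conditioning
    # drop_text_cond = None   -> random dropout with cond_drop_prob while training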
@@ -356,6 +359,25 @@ def __init__(
     def device(self):
         return next(self.parameters()).device

+    def transformer_with_pred_head(
+        self,
+        x: Float['b n d'],
+        times: Float['b'],
+        mask: Bool['b n'] | None = None,
+        text: Int['b nt'] | None = None
+    ):
+        if exists(text):
+            x = self.embed_text(x, text)
+
+        attended = self.transformer(
+            x,
+            times = times,
+            mask = mask
+        )
+
+        pred = self.to_pred(attended)
+        return pred
+
     @torch.no_grad()
     def sample(
         self,
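With the transformer and prediction head factored into transformer_with_pred_head, classifier free guidance at sampling time reduces to two calls, one with text and one without, blended by a guidance weight. A minimal sketch under that assumption; cfg_strength is a hypothetical knob not introduced by this commit:

    import torch

    @torch.no_grad()
    def guided_flow_pred(model, x, t, mask, text, cfg_strength = 2.):
        # conditional pass: character embeddings are combined with the input
        cond = model.transformer_with_pred_head(x, times = t, mask = mask, text = text)

        # unconditional pass: omitting text skips embed_text entirely
        null = model.transformer_with_pred_head(x, times = t, mask = mask)

        # push the conditional prediction away from the unconditional one
        return null + (cond - null) * cfg_strength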
@@ -407,7 +429,7 @@ def fn(t, x):

             # predict flow

-            return self.transformer(
+            return self.transformer_with_pred_head(
                 x,
                 times = t,
                 mask = mask
@@ -425,7 +447,7 @@ def forward(
         self,
         inp: Float['b n d'], # is mel in paper
         *,
-        text: Int['b n'] | None = None,
+        text: Int['b nt'] | None = None,
         times: Int['b'] | None = None,
         lens: Int['b'] | None = None,
     ):
@@ -436,11 +458,6 @@ def forward(

         mask = lens_to_mask(lens, length = seq_len)

-        # text
-
-        if exists(text):
-            inp = self.embed_text(inp, text)
-
         # get a random span to mask out for training conditionally

         random_span_frac_indices = inp.new_zeros(2, batch).uniform_(0, 1)
@@ -485,13 +502,7 @@ def forward(

         # transformer and prediction head

-        attended = self.transformer(
-            w,
-            times = times,
-            mask = mask
-        )
-
-        pred = self.to_pred(attended)
+        pred = self.transformer_with_pred_head(w, times = times, text = text)

         # flow matching loss

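On the training side, forward now hands text to transformer_with_pred_head instead of embedding it up front, so the random conditioning dropout inside the character embed runs on every training step. A rough usage sketch of the call shape, assuming an already constructed model (constructor arguments are not part of this diff) and shapes following the jaxtyping hints:

    import torch

    # e2tts = E2TTS(...)  # hypothetical; constructor args not shown in this diff

    mel = torch.randn(2, 1024, 80)           # Float['b n d']: batch, frames, mel channels
    text = torch.randint(1, 256, (2, 64))    # Int['b nt']: character ids (0 is the filler token)

    loss = e2tts(mel, text = text)           # returns the flow matching loss
    loss.backward()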
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "0.0.10"
+version = "0.0.11"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
