Skip to content

Commit

Permalink
Project the input channels to the dimension of the transformer.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Jul 21, 2024
1 parent 90bb536 commit 2b57b13
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,6 @@ def __init__(

dim = transformer.dim
self.dim = dim
self.to_pred = nn.Linear(dim, dim)

self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds, cond_drop_prob = cond_drop_prob)

Expand All @@ -478,6 +477,10 @@ def __init__(
# mel spec

self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
num_channels = self.mel_spec.n_mel_channels

self.proj_in = nn.Linear(num_channels, dim)
self.to_pred = nn.Linear(dim, num_channels)

# immiscible (diffusion / flow)

Expand All @@ -495,6 +498,7 @@ def transformer_with_pred_head(
text: Int['b nt'] | None = None,
drop_text_cond: bool | None = None
):
x = self.proj_in(x)

if exists(text):
x = self.embed_text(x, text, drop_text_cond = drop_text_cond)
Expand Down

0 comments on commit 2b57b13

Please sign in to comment.