Skip to content

Commit

Permalink
Merge pull request #11 from lucasnewman/inout-projection
Browse files Browse the repository at this point in the history
Project the input to the dimension of the transformer
  • Loading branch information
lucidrains authored Jul 21, 2024
2 parents e02c233 + 2b57b13 commit d7f6741
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,6 @@ def __init__(

dim = transformer.dim
self.dim = dim
self.to_pred = nn.Linear(dim, dim)

self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds, cond_drop_prob = cond_drop_prob)

Expand All @@ -479,6 +478,10 @@ def __init__(
# mel spec

self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
num_channels = self.mel_spec.n_mel_channels

self.proj_in = nn.Linear(num_channels, dim)
self.to_pred = nn.Linear(dim, num_channels)

# immiscible (diffusion / flow)

Expand All @@ -496,6 +499,7 @@ def transformer_with_pred_head(
text: Int['b nt'] | None = None,
drop_text_cond: bool | None = None
):
x = self.proj_in(x)

if exists(text):
x = self.embed_text(x, text, drop_text_cond = drop_text_cond)
Expand Down

0 comments on commit d7f6741

Please sign in to comment.