diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index 97b93fb..22d2b20 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -461,7 +461,6 @@ def __init__( dim = transformer.dim self.dim = dim - self.to_pred = nn.Linear(dim, dim) self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds, cond_drop_prob = cond_drop_prob) @@ -478,6 +477,10 @@ def __init__( # mel spec self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) + num_channels = self.mel_spec.n_mel_channels + + self.proj_in = nn.Linear(num_channels, dim) + self.to_pred = nn.Linear(dim, num_channels) # immiscible (diffusion / flow) @@ -495,6 +498,7 @@ def transformer_with_pred_head( text: Int['b nt'] | None = None, drop_text_cond: bool | None = None ): + x = self.proj_in(x) if exists(text): x = self.embed_text(x, text, drop_text_cond = drop_text_cond)