diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index 05c8e28..9628e46 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -462,7 +462,6 @@ def __init__( dim = transformer.dim self.dim = dim - self.to_pred = nn.Linear(dim, dim) self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds, cond_drop_prob = cond_drop_prob) @@ -479,6 +478,10 @@ def __init__( # mel spec self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) + num_channels = self.mel_spec.n_mel_channels + + self.proj_in = nn.Linear(num_channels, dim) + self.to_pred = nn.Linear(dim, num_channels) # immiscible (diffusion / flow) @@ -496,6 +499,7 @@ def transformer_with_pred_head( text: Int['b nt'] | None = None, drop_text_cond: bool | None = None ): + x = self.proj_in(x) if exists(text): x = self.embed_text(x, text, drop_text_cond = drop_text_cond)