From cb56ac6cb63061ffb50c15d7cf1abba1307de92a Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 25 Jul 2024 17:57:59 -0700
Subject: [PATCH] use film-like conditioning for text on audio embed, initialized to identity

---
 e2_tts_pytorch/e2_tts.py | 12 +++++++++---
 pyproject.toml           |  2 +-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index ddebbfb..cb25526 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -186,10 +186,14 @@ def __init__(
     ):
         super().__init__()
         self.dim = dim
-        self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token'
-        self.combine = nn.Linear(dim * 2, dim)
         self.cond_drop_prob = cond_drop_prob
 
+        self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token'
+        self.to_cond_gamma_beta = nn.Linear(dim * 2, dim * 2)
+
+        nn.init.zeros_(self.to_cond_gamma_beta.weight)
+        nn.init.zeros_(self.to_cond_gamma_beta.bias)
+
     def forward(
         self,
         x: Float['b n d'],
@@ -211,7 +215,9 @@ def forward(
         concatted = torch.cat((x, text_embed), dim = -1)
 
         assert x.shape[-1] == text_embed.shape[-1] == self.dim, f'expected {self.dim} but received ({x.shape[-1]}, {text_embed.shape[-1]})'
-        return self.combine(concatted)
+
+        gamma, beta = self.to_cond_gamma_beta(concatted).chunk(2, dim = -1)
+        return x * (gamma + 1.) + beta
 
 # attention and transformer backbone
 # for use in both e2tts as well as duration module
diff --git a/pyproject.toml b/pyproject.toml
index 0d93ac2..fe2d17e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "0.1.8"
+version = "0.1.9"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
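
Note: below is a minimal, standalone sketch of the FiLM-style conditioning this patch introduces: the audio and text embeds are concatenated, projected to per-channel gamma and beta, and the projection is zero-initialized so the conditioning starts as an identity map over the audio embed. The class name FiLMTextCondition and the dimensions in the sanity check are illustrative only, not the repository's exact module or API.

    # sketch only -- names and shapes here are assumptions for illustration
    import torch
    from torch import nn

    class FiLMTextCondition(nn.Module):
        def __init__(self, dim):
            super().__init__()
            # project concat(audio_embed, text_embed) -> per-channel gamma and beta
            self.to_gamma_beta = nn.Linear(dim * 2, dim * 2)
            # zero init => gamma = beta = 0 => output equals the audio embed at initialization
            nn.init.zeros_(self.to_gamma_beta.weight)
            nn.init.zeros_(self.to_gamma_beta.bias)

        def forward(self, audio_embed, text_embed):
            concatted = torch.cat((audio_embed, text_embed), dim = -1)
            gamma, beta = self.to_gamma_beta(concatted).chunk(2, dim = -1)
            return audio_embed * (gamma + 1.) + beta

    # sanity check: at initialization the conditioning is an identity over the audio embed
    cond = FiLMTextCondition(dim = 512)
    audio = torch.randn(1, 128, 512)
    text = torch.randn(1, 128, 512)
    assert torch.allclose(cond(audio, text), audio)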