Skip to content

Commit

Permalink
use film-like conditioning for text on audio embed, initialized to identity
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 26, 2024
1 parent 92435be commit cb56ac6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
12 changes: 9 additions & 3 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,14 @@ def __init__(
):
super().__init__()
self.dim = dim
self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token'
self.combine = nn.Linear(dim * 2, dim)
self.cond_drop_prob = cond_drop_prob

self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token'
self.to_cond_gamma_beta = nn.Linear(dim * 2, dim * 2)

nn.init.zeros_(self.to_cond_gamma_beta.weight)
nn.init.zeros_(self.to_cond_gamma_beta.bias)

def forward(
self,
x: Float['b n d'],
Expand All @@ -211,7 +215,9 @@ def forward(

concatted = torch.cat((x, text_embed), dim = -1)
assert x.shape[-1] == text_embed.shape[-1] == self.dim, f'expected {self.dim} but received ({x.shape[-1]}, {text_embed.shape[-1]})'
return self.combine(concatted)

gamma, beta = self.to_cond_gamma_beta(concatted).chunk(2, dim = -1)
return x * (gamma + 1.) + beta

# attention and transformer backbone
# for use in both e2tts as well as duration module
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "0.1.8"
version = "0.1.9"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit cb56ac6

Please sign in to comment.