From 96e4f72bb7194c7fe0493315e9e172bad7ee7842 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 21 Jul 2024 17:27:30 -0700
Subject: [PATCH] fix duration predictor as well, also make sure immiscible
 diffusion works with gpu

---
 README.md                | 2 +-
 e2_tts_pytorch/e2_tts.py | 20 +++++++++++++-------
 pyproject.toml           | 2 +-
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index fb437f2..75360d8 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@ duration_predictor = DurationPredictor(
     )
 )
 
-mel = torch.randn(2, 1024, 512)
+mel = torch.randn(2, 1024, 80)
 text = ['Hello', 'Goodbye']
 
 loss = duration_predictor(mel, text = text)
diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index 9628e46..bda1f63 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -12,7 +12,7 @@ from random import random
 
 import torch
-from torch import nn
+from torch import nn, from_numpy
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList
 from torch.nn.utils.rnn import pad_sequence
 
@@ -355,10 +355,17 @@ def __init__(
                 cond_on_time = False
             )
 
+        # mel spec
+
+        self.mel_spec = MelSpec(**mel_spec_kwargs)
+        self.num_channels = self.mel_spec.n_mel_channels
+
         self.transformer = transformer
 
         dim = transformer.dim
-        self.dim = dim
+
+        self.proj_in = nn.Linear(self.num_channels, self.dim)
+
         self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds)
 
         self.to_pred = nn.Sequential(
@@ -367,10 +374,6 @@ def __init__(
             Rearrange('... 1 -> ...')
         )
 
-        # mel spec
-
-        self.mel_spec = MelSpec(**mel_spec_kwargs)
-
     def forward(
         self,
         x: Float['b n d'] | Float['b nw'],
@@ -386,6 +389,8 @@ def forward(
             x = rearrange(x, 'b d n -> b n d')
             assert x.shape[-1] == self.dim
 
+        x = self.proj_in(x)
+
         batch, seq_len, device = *x.shape[:2], x.device
 
         # text
@@ -670,7 +675,8 @@ def forward(
 
         if self.immiscible:
             cost = torch.cdist(x1.flatten(1), x0.flatten(1))
-            _, reorder_indices = linear_sum_assignment(cost)
+            _, reorder_indices = linear_sum_assignment(cost.cpu())
+            reorder_indices = from_numpy(reorder_indices).to(cost.device)
             x0 = x0[reorder_indices]
 
         # t is random times from above
diff --git a/pyproject.toml b/pyproject.toml
index d1ebfd9..1d3e976 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "0.0.38"
+version = "0.0.39"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
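
A note on the linear_sum_assignment change in the e2_tts.py hunk above: scipy's Hungarian solver only accepts CPU / NumPy input, so a cost matrix sitting on the GPU has to be moved to the CPU before the assignment, and the NumPy indices it returns have to be converted back to a tensor on the original device before they can index a CUDA tensor. Below is a minimal, self-contained sketch of that pattern, not code lifted from the repository; the helper name immiscible_pairing and the example shapes are made up for illustration.

import torch
from torch import from_numpy
from scipy.optimize import linear_sum_assignment

def immiscible_pairing(x1: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
    # pairwise distances between data samples x1 and noise samples x0
    cost = torch.cdist(x1.flatten(1), x0.flatten(1))

    # the Hungarian assignment runs on CPU / NumPy, so move the cost matrix off the GPU
    _, reorder_indices = linear_sum_assignment(cost.cpu())

    # the solver returns NumPy arrays; bring the indices back onto the original device
    reorder_indices = from_numpy(reorder_indices).to(cost.device)

    # reorder the noise batch so each data sample is paired with its nearest noise
    return x0[reorder_indices]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x1 = torch.randn(4, 1024, 80, device = device)  # e.g. a batch of mel spectrogram targets
x0 = torch.randn(4, 1024, 80, device = device)  # the sampled noise
x0 = immiscible_pairing(x1, x0)

Pairing each training sample with its nearest noise via linear assignment is the "immiscible diffusion" trick the commit message refers to; the assignment itself is only a batch-by-batch problem, so the CPU round trip is cheap, and the rest of the training step stays on the GPU.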