Skip to content

Commit

Permalink
fix duration predictor as well, also make sure immiscible diffusion works with gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 22, 2024
1 parent 93a0cc7 commit 96e4f72
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ duration_predictor = DurationPredictor(
)
)

mel = torch.randn(2, 1024, 512)
mel = torch.randn(2, 1024, 80)
text = ['Hello', 'Goodbye']

loss = duration_predictor(mel, text = text)
Expand Down
20 changes: 13 additions & 7 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from random import random

import torch
from torch import nn
from torch import nn, from_numpy
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.nn.utils.rnn import pad_sequence
Expand Down Expand Up @@ -355,10 +355,17 @@ def __init__(
cond_on_time = False
)

# mel spec

self.mel_spec = MelSpec(**mel_spec_kwargs)
self.num_channels = self.mel_spec.n_mel_channels

self.transformer = transformer
dim = transformer.dim

self.dim = dim

self.proj_in = nn.Linear(self.num_channels, self.dim)

self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds)

self.to_pred = nn.Sequential(
Expand All @@ -367,10 +374,6 @@ def __init__(
Rearrange('... 1 -> ...')
)

# mel spec

self.mel_spec = MelSpec(**mel_spec_kwargs)

def forward(
self,
x: Float['b n d'] | Float['b nw'],
Expand All @@ -386,6 +389,8 @@ def forward(
x = rearrange(x, 'b d n -> b n d')
assert x.shape[-1] == self.dim

x = self.proj_in(x)

batch, seq_len, device = *x.shape[:2], x.device

# text
Expand Down Expand Up @@ -670,7 +675,8 @@ def forward(

if self.immiscible:
cost = torch.cdist(x1.flatten(1), x0.flatten(1))
_, reorder_indices = linear_sum_assignment(cost)
_, reorder_indices = linear_sum_assignment(cost.cpu())
reorder_indices = from_numpy(reorder_indices).to(cost.device)
x0 = x0[reorder_indices]

# t is random times from above
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "0.0.38"
version = "0.0.39"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 96e4f72

Please sign in to comment.