Skip to content

Commit

Permalink
restore the condition when returning the output on .sample
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 23, 2024
1 parent d5f60e5 commit b0b00b7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
11 changes: 8 additions & 3 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,8 @@ def __init__(
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
num_channels = self.mel_spec.n_mel_channels

self.num_channels = num_channels

self.proj_in = nn.Linear(num_channels, dim)
self.to_pred = nn.Linear(dim, num_channels)

Expand Down Expand Up @@ -553,7 +555,7 @@ def sample(
if cond.ndim == 2:
cond = self.mel_spec(cond)
cond = rearrange(cond, 'b d n -> b n d')
assert cond.shape[-1] == self.dim
assert cond.shape[-1] == self.num_channels

batch, cond_seq_len, device = *cond.shape[:2], cond.device

Expand Down Expand Up @@ -583,6 +585,7 @@ def sample(

cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
cond_mask = F.pad(cond_mask, (0, max_duration - cond_seq_len), value = False)
cond_mask = rearrange(cond_mask, '... -> ... 1')

mask = lens_to_mask(duration)

Expand All @@ -591,7 +594,7 @@ def sample(
def fn(t, x):
# at each step, conditioning is fixed

x = torch.where(cond_mask[..., None], cond, x)
x = torch.where(cond_mask, cond, x)

# predict flow

Expand All @@ -611,6 +614,8 @@ def fn(t, x):

out = sampled

out = torch.where(cond_mask, cond, out)

if exists(vocoder):
out = rearrange(out, 'b n d -> b d n')
out = vocoder(out)
Expand All @@ -630,7 +635,7 @@ def forward(
if inp.ndim == 2:
inp = self.mel_spec(inp)
inp = rearrange(inp, 'b d n -> b n d')
assert inp.shape[-1] == self.dim
assert inp.shape[-1] == self.num_channels

batch, seq_len, dtype, device, σ = *inp.shape[:2], inp.dtype, self.device, self.sigma

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "0.0.41"
version = "0.0.42"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit b0b00b7

Please sign in to comment.