diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index bda1f63..38fe1a8 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -485,6 +485,8 @@ def __init__( self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) num_channels = self.mel_spec.n_mel_channels + self.num_channels = num_channels + self.proj_in = nn.Linear(num_channels, dim) self.to_pred = nn.Linear(dim, num_channels) @@ -553,7 +555,7 @@ def sample( if cond.ndim == 2: cond = self.mel_spec(cond) cond = rearrange(cond, 'b d n -> b n d') - assert cond.shape[-1] == self.dim + assert cond.shape[-1] == self.num_channels batch, cond_seq_len, device = *cond.shape[:2], cond.device @@ -583,6 +585,7 @@ def sample( cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.) cond_mask = F.pad(cond_mask, (0, max_duration - cond_seq_len), value = False) + cond_mask = rearrange(cond_mask, '... -> ... 1') mask = lens_to_mask(duration) @@ -591,7 +594,7 @@ def sample( def fn(t, x): # at each step, conditioning is fixed - x = torch.where(cond_mask[..., None], cond, x) + x = torch.where(cond_mask, cond, x) # predict flow @@ -611,6 +614,8 @@ def fn(t, x): out = sampled + out = torch.where(cond_mask, cond, out) + if exists(vocoder): out = rearrange(out, 'b n d -> b d n') out = vocoder(out) @@ -630,7 +635,7 @@ def forward( if inp.ndim == 2: inp = self.mel_spec(inp) inp = rearrange(inp, 'b d n -> b n d') - assert inp.shape[-1] == self.dim + assert inp.shape[-1] == self.num_channels batch, seq_len, dtype, device, σ = *inp.shape[:2], inp.dtype, self.device, self.sigma diff --git a/pyproject.toml b/pyproject.toml index a06ed9d..622289d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "e2-tts-pytorch" -version = "0.0.41" +version = "0.0.42" description = "E2-TTS in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }