diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index 65c4d7f..aa7cfaf 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -439,7 +439,7 @@ def forward( text = list_str_to_tensor(text).to(device) assert text.shape[0] == batch - x = self.embed_text(x, text) + x = self.embed_text(x, x, text) # handle lengths (duration) @@ -623,7 +623,7 @@ def sample( duration = torch.full((batch,), duration, device = device, dtype = torch.long) elif exists(self.duration_predictor): - duration = self.duration_predictor(cond, lens = lens).long() + duration = self.duration_predictor(cond, text = text, lens = lens).long() duration = torch.maximum(lens + 1, duration) # just add one token so something is generated duration = duration.clamp(max = max_duration) diff --git a/pyproject.toml b/pyproject.toml index 084594b..5ba1dba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "e2-tts-pytorch" -version = "0.1.10" +version = "0.2.0" description = "E2-TTS in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }