From f315bd20324c001813f92d365dd82ea7a2412df6 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 23 Jul 2024 10:10:34 -0700
Subject: [PATCH] readme and cleanup

---
 README.md                | 2 +-
 e2_tts_pytorch/e2_tts.py | 4 +---
 pyproject.toml           | 2 +-
 3 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 75360d8..4c53802 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@ duration_predictor = DurationPredictor(
     )
 )
 
-mel = torch.randn(2, 1024, 80)
+mel = torch.randn(2, 1024, 100)
 text = ['Hello', 'Goodbye']
 
 loss = duration_predictor(mel, text = text)
diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index bd588f7..59d57cb 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -168,10 +168,8 @@ def forward(
             return x
 
         max_seq_len = x.shape[1]
-        text_mask = text == -1
 
-        text = text + 1 # use 0 as filler token
-        text = text.masked_fill(text_mask, 0)
+        text = text + 1 # shift all other token ids up by 1 and use 0 as filler token
 
         text = text[:, :max_seq_len] # just curtail if character tokens are more than the mel spec tokens, one of the edge cases the paper did not address
         text = F.pad(text, (0, max_seq_len - text.shape[1]), value = 0)
diff --git a/pyproject.toml b/pyproject.toml
index 06a10b0..f894ba6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "0.1.0"
+version = "0.1.1"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
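
Note (not part of the patch): a minimal sketch of what the simplified filler-token handling in e2_tts.py does. The example values and the -1 padding convention for character ids are assumptions for illustration; after the +1 shift, -1 padding naturally lands on the 0 filler token, so the separate mask-and-fill step is no longer needed.

    import torch
    import torch.nn.functional as F

    # hypothetical inputs, not taken from the repo
    text = torch.tensor([[7, 4, 11, -1, -1]])  # character ids, -1 marks padding
    max_seq_len = 8                            # length of the mel spectrogram sequence

    text = text + 1                            # padding (-1) becomes the 0 filler token; real ids shift up by 1
    text = text[:, :max_seq_len]               # curtail if there are more character tokens than mel frames
    text = F.pad(text, (0, max_seq_len - text.shape[1]), value = 0)  # pad the remainder with the filler token

    print(text)  # tensor([[ 8,  5, 12,  0,  0,  0,  0,  0]])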