From 88c659a4af264ef3fc90c7a4680c15792e549a3c Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 31 Oct 2024 07:25:13 -0700 Subject: [PATCH] fix an issue with g2p and pydantic, and also add value residual learning paper from iclr 2025 --- README.md | 9 +++++++++ e2_tts_pytorch/e2_tts.py | 14 ++++++++++++-- pyproject.toml | 5 +++-- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b2fe95d..4ba5310 100644 --- a/README.md +++ b/README.md @@ -146,3 +146,12 @@ sampled = e2tts.sample(mel[:, :5], text = text) url = {https://api.semanticscholar.org/CorpusID:267657558} } ``` + +```bibtex +@inproceedings{Zhou2024ValueRL, + title = {Value Residual Learning For Alleviating Attention Concentration In Transformers}, + author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan}, + year = {2024}, + url = {https://api.semanticscholar.org/CorpusID:273532030} +} +``` diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index 91e3971..5a6bc3e 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -673,6 +673,11 @@ def forward( skips = [] + # value residual + + text_attn_first_values = None + attn_first_values = None + # go through the layers for ind, (speech_modules, text_modules) in enumerate(self.layers): @@ -704,7 +709,10 @@ def forward( text_embed = text_conv(text_embed, mask = mask) + text_embed - text_embed = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask) + text_embed + text_attn_out, text_attn_inter = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = text_attn_first_values) + text_embed = text_attn_out + text_embed + + text_attn_first_values = default(text_attn_first_values, text_attn_inter.values) text_embed = text_ff(text_ff_norm(text_embed)) + text_embed @@ -729,7 +737,9 @@ def forward( # attention and feedforward blocks - attn_out = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask) + attn_out, attn_inter = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = attn_first_values) + + attn_first_values = default(attn_first_values, attn_inter.values) x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs) diff --git a/pyproject.toml b/pyproject.toml index f20a644..10d099b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "e2-tts-pytorch" -version = "1.2.4" +version = "1.4.0" description = "E2-TTS in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } @@ -32,13 +32,14 @@ dependencies = [ 'g2p-en', 'jaxtyping', 'loguru', + 'pydantic<2', 'tensorboard', 'torch>=2.0', 'torchdiffeq', 'torchaudio>=2.3.1', 'tqdm>=4.65.0', 'vocos', - 'x-transformers>=1.31.14', + 'x-transformers>=1.42.3', ] [project.urls]