Skip to content

Commit

Permalink
fix an issue with g2p and pydantic, and also add value residual learn…
Browse files Browse the repository at this point in the history
…ing paper from iclr 2025
  • Loading branch information
lucidrains committed Oct 31, 2024
1 parent 2825782 commit 88c659a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,12 @@ sampled = e2tts.sample(mel[:, :5], text = text)
url = {https://api.semanticscholar.org/CorpusID:267657558}
}
```

```bibtex
@inproceedings{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```
14 changes: 12 additions & 2 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,11 @@ def forward(

skips = []

# value residual

text_attn_first_values = None
attn_first_values = None

# go through the layers

for ind, (speech_modules, text_modules) in enumerate(self.layers):
Expand Down Expand Up @@ -704,7 +709,10 @@ def forward(

text_embed = text_conv(text_embed, mask = mask) + text_embed

text_embed = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask) + text_embed
text_attn_out, text_attn_inter = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = text_attn_first_values)
text_embed = text_attn_out + text_embed

text_attn_first_values = default(text_attn_first_values, text_attn_inter.values)

text_embed = text_ff(text_ff_norm(text_embed)) + text_embed

Expand All @@ -729,7 +737,9 @@ def forward(

# attention and feedforward blocks

attn_out = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask)
attn_out, attn_inter = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = attn_first_values)

attn_first_values = default(attn_first_values, attn_inter.values)

x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs)

Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "1.2.4"
version = "1.4.0"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down Expand Up @@ -32,13 +32,14 @@ dependencies = [
'g2p-en',
'jaxtyping',
'loguru',
'pydantic<2',
'tensorboard',
'torch>=2.0',
'torchdiffeq',
'torchaudio>=2.3.1',
'tqdm>=4.65.0',
'vocos',
'x-transformers>=1.31.14',
'x-transformers>=1.42.3',
]

[project.urls]
Expand Down

0 comments on commit 88c659a

Please sign in to comment.