laser attention appears to work well
lucidrains committed Dec 2, 2024
1 parent 70af540 commit f8eca56
Showing 3 changed files with 14 additions and 4 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -155,3 +155,12 @@ sampled = e2tts.sample(mel[:, :5], text = text)
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```

```bibtex
@inproceedings{Duvvuri2024LASERAW,
title = {LASER: Attention with Exponential Transformation},
author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273849947}
}
```
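
For context, LASER attention from the paper cited above applies the usual softmax attention weights to exponentiated values and maps the result back with a log, which the authors argue improves gradient flow through the softmax. Below is a minimal PyTorch sketch of that idea only; it is not the implementation that `laser = True` toggles inside x-transformers, and the `laser_attention` helper name is hypothetical.

```python
import torch

def laser_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, dim_head)
    scale = q.shape[-1] ** -0.5

    # ordinary softmax attention weights
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    attn = sim.softmax(dim = -1)

    # LASER: weight exp(v) instead of v, then return to log space,
    # subtracting the per-channel max over positions for numerical stability
    v_max = v.amax(dim = -2, keepdim = True)
    weighted = torch.einsum('b h i j, b h j d -> b h i d', attn, (v - v_max).exp())
    return weighted.clamp(min = 1e-20).log() + v_max

q = k = v = torch.randn(1, 8, 32, 64)
out = laser_attention(q, k, v)  # (1, 8, 32, 64)
```

Because the attention weights are non-negative and sum to one, the weighted sum of exponentials stays positive, so the log is well defined; the clamp is only a numerical safety net.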
5 changes: 3 additions & 2 deletions e2_tts_pytorch/e2_tts.py
@@ -500,6 +500,7 @@ def __init__(
kernel_size = 31,
dropout = 0.1,
num_registers = 32,
+attn_laser = False,
attn_kwargs: dict = dict(
gate_value_heads = True,
softclamp_logits = True,
@@ -570,7 +571,7 @@ def __init__(
speech_conv = DepthwiseConv(dim, kernel_size = kernel_size)

attn_norm = rmsnorm_klass(dim)
-attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs)
+attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, laser = attn_laser, **attn_kwargs)
attn_adaln_zero = postbranch_klass()

ff_norm = rmsnorm_klass(dim)
@@ -598,7 +599,7 @@ def __init__(
text_conv = DepthwiseConv(dim_text, kernel_size = kernel_size)

text_attn_norm = RMSNorm(dim_text)
-text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, **attn_kwargs)
+text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, laser = attn_laser, **attn_kwargs)

text_ff_norm = RMSNorm(dim_text)
text_ff = FeedForward(dim = dim_text, glu = True, mult = text_ff_mult, dropout = dropout, **ff_kwargs)
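
The new `attn_laser` flag defaults to `False` and is threaded into both the speech and text attention branches. A hedged usage sketch follows, based on the construction pattern in this repository's README; the hyperparameters and tensor shapes are illustrative placeholders, and only `attn_laser = True` is the switch this commit adds.

```python
import torch
from e2_tts_pytorch import E2TTS

# placeholder hyperparameters; attn_laser = True enables LASER attention
# in both the speech and text transformer branches
e2tts = E2TTS(
    transformer = dict(
        dim = 512,
        depth = 8,
        attn_laser = True
    )
)

mel = torch.randn(2, 1024, 100)          # (batch, audio frames, mel channels)
text = ['hello world', 'laser attention']

loss = e2tts(mel, text = text)
loss.backward()
```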
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
-version = "1.6.0"
+version = "1.6.1"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
@@ -40,7 +40,7 @@ dependencies = [
'torchaudio>=2.3.1',
'tqdm>=4.65.0',
'vocos',
-'x-transformers>=1.42.16',
+'x-transformers>=1.42.22',
]

[project.urls]
