Skip to content

Commit

Permalink
some better transformer defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 14, 2024
1 parent 58288b2 commit b13f520
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ from e2_tts_pytorch import (
duration_predictor = DurationPredictor(
transformer = dict(
dim = 512,
depth = 2,
depth = 8,
)
)

Expand All @@ -44,7 +44,7 @@ e2tts = E2TTS(
duration_predictor = duration_predictor,
transformer = dict(
dim = 512,
depth = 4,
depth = 8,
skip_connect_type = 'concat'
),
)
Expand Down
5 changes: 3 additions & 2 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,12 @@ def __init__(
self,
*,
dim,
depth,
depth = 8,
cond_on_time = False,
skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
abs_pos_emb = True,
max_seq_len = 8192,
heads = 8,
dim_head = 64,
attn_kwargs: dict = dict(),
ff_kwargs: dict = dict()
Expand Down Expand Up @@ -231,7 +232,7 @@ def __init__(

for _ in range(depth):
attn_norm = rmsnorm_klass(dim)
attn = Attention(dim = dim, dim_head = dim_head, **attn_kwargs)
attn = Attention(dim = dim, heads = heads, dim_head = dim_head, **attn_kwargs)

ff_norm = rmsnorm_klass(dim)
ff = FeedForward(dim = dim, **ff_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "0.0.25"
version = "0.0.26"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit b13f520

Please sign in to comment.