From b13f520271dad14ea9159aaae5b3ccf2de45fc0c Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 14 Jul 2024 15:04:55 -0700
Subject: [PATCH] some better transformer defaults

---
 README.md                | 4 ++--
 e2_tts_pytorch/e2_tts.py | 5 +++--
 pyproject.toml           | 2 +-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 1028191..d0d4f48 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ from e2_tts_pytorch import (
 duration_predictor = DurationPredictor(
     transformer = dict(
         dim = 512,
-        depth = 2,
+        depth = 8,
     )
 )
 
@@ -44,7 +44,7 @@ e2tts = E2TTS(
     duration_predictor = duration_predictor,
     transformer = dict(
         dim = 512,
-        depth = 4,
+        depth = 8,
         skip_connect_type = 'concat'
     ),
 )
diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index 6b85aa7..0af3ee3 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -186,11 +186,12 @@ def __init__(
         self,
         *,
         dim,
-        depth,
+        depth = 8,
         cond_on_time = False,
         skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
         abs_pos_emb = True,
         max_seq_len = 8192,
+        heads = 8,
         dim_head = 64,
         attn_kwargs: dict = dict(),
         ff_kwargs: dict = dict()
@@ -231,7 +232,7 @@ def __init__(
 
         for _ in range(depth):
             attn_norm = rmsnorm_klass(dim)
-            attn = Attention(dim = dim, dim_head = dim_head, **attn_kwargs)
+            attn = Attention(dim = dim, heads = heads, dim_head = dim_head, **attn_kwargs)
 
             ff_norm = rmsnorm_klass(dim)
             ff = FeedForward(dim = dim, **ff_kwargs)
diff --git a/pyproject.toml b/pyproject.toml
index 503e589..4ec5d78 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "0.0.25"
+version = "0.0.26"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
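
Usage sketch, taken from the patched README hunks above and assuming e2_tts_pytorch exports DurationPredictor and E2TTS as shown in the hunk context. With this change, transformer depth defaults to 8 and a new heads argument (default 8) is forwarded to every Attention layer, so both values below are spelled out only for clarity and can be overridden through the transformer dict.

from e2_tts_pytorch import DurationPredictor, E2TTS

# duration predictor - depth now defaults to 8, so this matches the new default
duration_predictor = DurationPredictor(
    transformer = dict(
        dim = 512,
        depth = 8,
    )
)

# main model - heads = 8 is the new default and is passed through to each Attention layer;
# pass heads = <n> inside the transformer dict to change it
e2tts = E2TTS(
    duration_predictor = duration_predictor,
    transformer = dict(
        dim = 512,
        depth = 8,
        skip_connect_type = 'concat'
    ),
)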