diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index 6f83fcb..6b85aa7 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -33,6 +33,8 @@ AdaptiveRMSNorm ) +from x_transformers.x_transformers import RotaryEmbedding + from e2_tts_pytorch.tensor_typing import ( Float, Int, @@ -189,6 +191,7 @@ def __init__( skip_connect_type: Literal['add', 'concat', 'none'] = 'concat', abs_pos_emb = True, max_seq_len = 8192, + dim_head = 64, attn_kwargs: dict = dict(), ff_kwargs: dict = dict() ): @@ -207,6 +210,10 @@ def __init__( self.depth = depth self.layers = ModuleList([]) + # rotary embedding + + self.rotary_emb = RotaryEmbedding(dim_head) + # time conditioning # will use adaptive rmsnorm @@ -224,7 +231,7 @@ def __init__( for _ in range(depth): attn_norm = rmsnorm_klass(dim) - attn = Attention(dim = dim, **attn_kwargs) + attn = Attention(dim = dim, dim_head = dim_head, **attn_kwargs) ff_norm = rmsnorm_klass(dim) ff = FeedForward(dim = dim, **ff_kwargs) @@ -269,6 +276,10 @@ def forward( times = self.time_cond_mlp(times) norm_kwargs.update(condition = times) + # rotary embedding + + rotary_pos_emb = self.rotary_emb.forward_from_seq_len(seq_len) + # skip connection related stuff skip_connect_type = self.skip_connect_type @@ -301,7 +312,7 @@ def forward( # attention and feedforward blocks - x = attn(attn_norm(x, **norm_kwargs)) + x + x = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask) + x x = ff(ff_norm(x, **norm_kwargs)) + x assert len(skips) == 0 diff --git a/pyproject.toml b/pyproject.toml index 960a885..503e589 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "e2-tts-pytorch" -version = "0.0.24" +version = "0.0.25" description = "E2-TTS in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }