Skip to content

Commit

Permalink
Add rotary embedding and fix self-attention mask
Browse files — browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 14, 2024
1 parent 52129b9 commit 58288b2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
15 changes: 13 additions & 2 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
AdaptiveRMSNorm
)

from x_transformers.x_transformers import RotaryEmbedding

from e2_tts_pytorch.tensor_typing import (
Float,
Int,
Expand Down Expand Up @@ -189,6 +191,7 @@ def __init__(
skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
abs_pos_emb = True,
max_seq_len = 8192,
dim_head = 64,
attn_kwargs: dict = dict(),
ff_kwargs: dict = dict()
):
Expand All @@ -207,6 +210,10 @@ def __init__(
self.depth = depth
self.layers = ModuleList([])

# rotary embedding

self.rotary_emb = RotaryEmbedding(dim_head)

# time conditioning
# will use adaptive rmsnorm

Expand All @@ -224,7 +231,7 @@ def __init__(

for _ in range(depth):
attn_norm = rmsnorm_klass(dim)
attn = Attention(dim = dim, **attn_kwargs)
attn = Attention(dim = dim, dim_head = dim_head, **attn_kwargs)

ff_norm = rmsnorm_klass(dim)
ff = FeedForward(dim = dim, **ff_kwargs)
Expand Down Expand Up @@ -269,6 +276,10 @@ def forward(
times = self.time_cond_mlp(times)
norm_kwargs.update(condition = times)

# rotary embedding

rotary_pos_emb = self.rotary_emb.forward_from_seq_len(seq_len)

# skip connection related stuff

skip_connect_type = self.skip_connect_type
Expand Down Expand Up @@ -301,7 +312,7 @@ def forward(

# attention and feedforward blocks

x = attn(attn_norm(x, **norm_kwargs)) + x
x = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask) + x
x = ff(ff_norm(x, **norm_kwargs)) + x

assert len(skips) == 0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "0.0.24"
version = "0.0.25"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 58288b2

Please sign in to comment.