Commit b3386f8

add learned residual scaling for main transformer path
lucidrains committed Dec 19, 2024
1 parent fc3af67 commit b3386f8
Showing 2 changed files with 29 additions and 4 deletions.
31 changes: 28 additions & 3 deletions e2_tts_pytorch/e2_tts.py
@@ -500,6 +500,7 @@ def __init__(
         kernel_size = 31,
         dropout = 0.1,
         num_registers = 32,
+        scale_residual = False,
         attn_laser = False,
         attn_laser_softclamp_value = 15.,
         attn_kwargs: dict = dict(
@@ -540,6 +541,10 @@ def __init__(
         self.text_registers = nn.Parameter(torch.zeros(num_registers, dim_text))
         nn.init.normal_(self.text_registers, std = 0.02)
 
+        # maybe residual scales
+
+        residual_scales = []
+
         # rotary embedding
 
         self.rotary_emb = RotaryEmbedding(dim_head)
@@ -575,10 +580,16 @@ def __init__(
             attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, laser = attn_laser, laser_softclamp_value = attn_laser_softclamp_value, **attn_kwargs)
             attn_adaln_zero = postbranch_klass()
 
             ff_norm = rmsnorm_klass(dim)
             ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs)
             ff_adaln_zero = postbranch_klass()
 
+            residual_scales.append(nn.ParameterList([
+                nn.Parameter(torch.ones(dim)),
+                nn.Parameter(torch.ones(dim))
+            ]))
+
             skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None
 
             speech_modules = ModuleList([
@@ -625,6 +636,8 @@ def __init__(
                 text_modules
             ]))
 
+        self.residual_scales = nn.ParameterList(residual_scales) if scale_residual else None
+
         self.final_norm = RMSNorm(dim)
 
     def forward(
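The `if scale_residual else None` gate means the scale parameters only exist (and only appear in the module's `state_dict`) when the feature is turned on. A minimal sketch of that pattern, with illustrative names not taken from the repo:

```python
import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self, dim, scale_residual = False):
        super().__init__()
        scales = [nn.Parameter(torch.ones(dim)) for _ in range(3)]
        # register the parameters only when the flag is set
        self.residual_scales = nn.ParameterList(scales) if scale_residual else None

print(len(list(Toy(8).parameters())))                         # 0
print(len(list(Toy(8, scale_residual = True).parameters())))  # 3
```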
@@ -685,9 +698,14 @@ def forward(
         text_attn_first_values = None
         attn_first_values = None
 
+        # prepare residual scales
+
+        residual_scales = default(self.residual_scales, (None,) * len(self.layers))
+
         # go through the layers
 
-        for ind, (speech_modules, text_modules) in enumerate(self.layers):
+        for ind, ((speech_modules, text_modules), maybe_residual_scales) in enumerate(zip(self.layers, residual_scales)):
 
             layer = ind + 1
 
             (
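When scaling is disabled, `self.residual_scales` is `None`, so the `default(...)` call above substitutes a tuple of `None`s the same length as `self.layers`; the `zip` in the loop header then pairs every layer with a per-layer entry either way. A small sketch of the idea (the `default` helper shown here is a stand-in with the behavior assumed from the repo's usage):

```python
def default(value, fallback):
    # assumed behavior: keep `value` unless it is None
    return value if value is not None else fallback

layers = ['block0', 'block1', 'block2']   # illustrative stand-ins for the real modules
residual_scales = None                    # i.e. scale_residual = False

for block, maybe_scales in zip(layers, default(residual_scales, (None,) * len(layers))):
    print(block, maybe_scales)            # every block pairs with None -> scale of 1.
```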
Expand Down Expand Up @@ -742,17 +760,24 @@ def forward(

x = speech_conv(x, mask = mask) + x

# maybe residual scaling

attn_res_scale, ff_res_scale = 1., 1.

if exists(maybe_residual_scales):
attn_res_scale, ff_res_scale = maybe_residual_scales

# attention and feedforward blocks

attn_out, attn_inter = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = attn_first_values)

attn_first_values = default(attn_first_values, attn_inter.values)

x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs)
x = x * attn_res_scale + maybe_attn_adaln_zero(attn_out, **norm_kwargs)

ff_out = ff(ff_norm(x, **norm_kwargs))

x = x + maybe_ff_adaln_zero(ff_out, **norm_kwargs)
x = x * ff_res_scale + maybe_ff_adaln_zero(ff_out, **norm_kwargs)

assert len(skips) == 0

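Taken together, the change turns each pre-norm residual update `x = x + branch(norm(x))` into `x = x * scale + branch(norm(x))`, where `scale` is a learnable per-dimension vector initialized to ones, so training starts from the ordinary identity residual. A self-contained sketch of the technique (plain PyTorch, illustrative names; the repo's actual branches are attention and feedforward blocks with RMSNorm variants):

```python
import torch
from torch import nn

class ScaledResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)                    # repo uses an RMSNorm variant
        self.branch = nn.Linear(dim, dim)                # stand-in for attention or feedforward
        self.res_scale = nn.Parameter(torch.ones(dim))   # ones init -> identity residual at start

    def forward(self, x):
        # the residual stream is rescaled elementwise before the branch output is added,
        # mirroring `x = x * attn_res_scale + ...` in the diff above
        return x * self.res_scale + self.branch(self.norm(x))

x = torch.randn(2, 16, 64)
block = ScaledResidualBlock(64)
assert block(x).shape == x.shape
```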
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "1.6.2"
+version = "1.6.3"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
