diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index ed0489f..d1b1062 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -500,6 +500,7 @@ def __init__(
         kernel_size = 31,
         dropout = 0.1,
         num_registers = 32,
+        scale_residual = False,
         attn_laser = False,
         attn_laser_softclamp_value = 15.,
         attn_kwargs: dict = dict(
@@ -540,6 +541,10 @@ def __init__(
         self.text_registers = nn.Parameter(torch.zeros(num_registers, dim_text))
         nn.init.normal_(self.text_registers, std = 0.02)
 
+        # maybe residual scales
+
+        residual_scales = []
+
         # rotary embedding
 
         self.rotary_emb = RotaryEmbedding(dim_head)
@@ -575,10 +580,16 @@ def __init__(
             attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, laser = attn_laser, laser_softclamp_value = attn_laser_softclamp_value, **attn_kwargs)
             attn_adaln_zero = postbranch_klass()
+
             ff_norm = rmsnorm_klass(dim)
             ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs)
             ff_adaln_zero = postbranch_klass()
 
+            residual_scales.append(nn.ParameterList([
+                nn.Parameter(torch.ones(dim)),
+                nn.Parameter(torch.ones(dim))
+            ]))
+
             skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None
 
             speech_modules = ModuleList([
@@ -625,6 +636,8 @@ def __init__(
                 text_modules
             ]))
 
+        self.residual_scales = nn.ParameterList(residual_scales) if scale_residual else None
+
         self.final_norm = RMSNorm(dim)
 
     def forward(
@@ -685,9 +698,14 @@ def forward(
         text_attn_first_values = None
         attn_first_values = None
 
+        # prepare residual scales
+
+        residual_scales = default(self.residual_scales, (None,) * len(self.layers))
+
         # go through the layers
 
-        for ind, (speech_modules, text_modules) in enumerate(self.layers):
+        for ind, ((speech_modules, text_modules), maybe_residual_scales) in enumerate(zip(self.layers, residual_scales)):
+
             layer = ind + 1
 
             (
@@ -742,17 +760,24 @@ def forward(
             x = speech_conv(x, mask = mask) + x
 
+            # maybe residual scaling
+
+            attn_res_scale, ff_res_scale = 1., 1.
+
+            if exists(maybe_residual_scales):
+                attn_res_scale, ff_res_scale = maybe_residual_scales
+
             # attention and feedforward blocks
 
             attn_out, attn_inter = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = attn_first_values)
             attn_first_values = default(attn_first_values, attn_inter.values)
 
-            x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs)
+            x = x * attn_res_scale + maybe_attn_adaln_zero(attn_out, **norm_kwargs)
 
             ff_out = ff(ff_norm(x, **norm_kwargs))
 
-            x = x + maybe_ff_adaln_zero(ff_out, **norm_kwargs)
+            x = x * ff_res_scale + maybe_ff_adaln_zero(ff_out, **norm_kwargs)
 
             assert len(skips) == 0
diff --git a/pyproject.toml b/pyproject.toml
index 44f3c7b..b58f41b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "1.6.2"
+version = "1.6.3"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
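
For context on the change above: the new `scale_residual` option registers a pair of learnable per-channel scales per layer (one for the attention residual, one for the feedforward residual), each initialized to ones, and multiplies the residual stream by the scale before the branch output is added back, i.e. `x = x * scale + branch(norm(x))`. The sketch below illustrates that pattern in isolation; the `ScaledResidualBlock` class, the toy branch, and the dimensions are illustrative assumptions, not part of the e2-tts-pytorch API.

```python
import torch
from torch import nn

class ScaledResidualBlock(nn.Module):
    """
    Minimal sketch of the scale_residual idea from the diff: gate the residual
    stream with a learnable per-channel scale (initialized to ones, so training
    starts from a plain residual connection) before adding the branch output.
    Names here are hypothetical and only meant to illustrate the mechanism.
    """
    def __init__(self, dim: int, branch: nn.Module):
        super().__init__()
        self.branch = branch
        self.res_scale = nn.Parameter(torch.ones(dim))  # one scale per feature channel

    def forward(self, x):
        # x: (batch, seq, dim) -> scaled residual plus branch output
        return x * self.res_scale + self.branch(x)

# usage: wrap a pre-norm feedforward as the branch
dim = 512
ff = nn.Sequential(
    nn.LayerNorm(dim),
    nn.Linear(dim, dim * 4),
    nn.GELU(),
    nn.Linear(dim * 4, dim)
)
block = ScaledResidualBlock(dim, ff)
out = block(torch.randn(2, 16, dim))  # shape preserved: (2, 16, 512)
```

With the default `scale_residual = False`, `self.residual_scales` stays `None` and the scales fall back to `1.`, so the forward path reduces to the previous plain residual `x + branch_out`.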