From dab09399d5205e631ed813ffd34a1c5f24a859dc Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 26 Dec 2024 07:43:46 -0800
Subject: [PATCH] add hyper connections and cite, default to 4 residual streams

---
 README.md                | 11 +++++
 e2_tts_pytorch/e2_tts.py | 99 ++++++++++++++++++++++++++++------------
 pyproject.toml           |  3 +-
 3 files changed, 84 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index 4337f5e..6607c2f 100644
--- a/README.md
+++ b/README.md
@@ -164,3 +164,14 @@ sampled = e2tts.sample(mel[:, :5], text = text)
     url = {https://api.semanticscholar.org/CorpusID:273849947}
 }
 ```
+
+```bibtex
+@article{Zhu2024HyperConnections,
+    title = {Hyper-Connections},
+    author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2409.19606},
+    url = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```
diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index d1b1062..f4b8122 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -44,6 +44,8 @@
 from x_transformers.x_transformers import RotaryEmbedding
 
+from hyper_connections import HyperConnections
+
 from vocos import Vocos
 
 pad_sequence = partial(pad_sequence, batch_first = True)
@@ -503,6 +505,7 @@
         scale_residual = False,
         attn_laser = False,
         attn_laser_softclamp_value = 15.,
+        num_residual_streams = 4,
         attn_kwargs: dict = dict(
             gate_value_heads = True,
             softclamp_logits = True,
@@ -530,7 +533,7 @@
         assert 1 <= text_depth <= depth, 'must have at least 1 layer of text conditioning, but less than total number of speech layers'
 
         self.depth = depth
-        self.layers = ModuleList([])
+        layers = []
 
         # registers
@@ -550,6 +553,12 @@
         self.rotary_emb = RotaryEmbedding(dim_head)
         self.text_rotary_emb = RotaryEmbedding(dim_head)
 
+        # hyper connection related
+
+        init_hyper_conn, self.hyper_conn_expand, self.hyper_conn_reduce = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
+        hyper_conns = []
+
         # time conditioning
 
         # will use adaptive rmsnorm
@@ -580,16 +589,10 @@
             attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, laser = attn_laser, laser_softclamp_value = attn_laser_softclamp_value, **attn_kwargs)
             attn_adaln_zero = postbranch_klass()
-
             ff_norm = rmsnorm_klass(dim)
             ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs)
             ff_adaln_zero = postbranch_klass()
 
-            residual_scales.append(nn.ParameterList([
-                nn.Parameter(torch.ones(dim)),
-                nn.Parameter(torch.ones(dim))
-            ]))
-
             skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None
 
             speech_modules = ModuleList([
@@ -603,7 +606,14 @@
                 ff_adaln_zero,
             ])
 
+            speech_hyper_conns = ModuleList([
+                init_hyper_conn(dim = dim), # conv
+                init_hyper_conn(dim = dim), # attn
+                init_hyper_conn(dim = dim), # ff
+            ])
+
             text_modules = None
+            text_hyper_conns = None
 
             if has_text:
                 # text related
@@ -631,12 +641,25 @@
                     cross_condition
                 ])
 
-            self.layers.append(ModuleList([
+                text_hyper_conns = ModuleList([
+                    init_hyper_conn(dim = dim_text), # conv
+                    init_hyper_conn(dim = dim_text), # attn
+                    init_hyper_conn(dim = dim_text), # ff
+                ])
+
+            hyper_conns.append(ModuleList([
+                speech_hyper_conns,
+                text_hyper_conns
+            ]))
+
+            layers.append(ModuleList([
                 speech_modules,
                 text_modules
             ]))
 
-        self.residual_scales = nn.ParameterList(residual_scales) if scale_residual else None
+        self.layers = ModuleList(layers)
+
+        self.hyper_conns = ModuleList(hyper_conns)
 
         self.final_norm = RMSNorm(dim)
@@ -698,13 +721,16 @@
         text_attn_first_values = None
         attn_first_values = None
 
-        # prepare residual scales
-
-        residual_scales = default(self.residual_scales, (None,) * len(self.layers))
+        # expand hyper connections
+
+        x = self.hyper_conn_expand(x)
+
+        if exists(text_embed):
+            text_embed = self.hyper_conn_expand(text_embed)
 
         # go through the layers
 
-        for ind, ((speech_modules, text_modules), maybe_residual_scales) in enumerate(zip(self.layers, residual_scales)):
+        for ind, ((speech_modules, text_modules), (speech_residual_fns, text_residual_fns)) in enumerate(zip(self.layers, self.hyper_conns)):
 
             layer = ind + 1
@@ -719,6 +745,12 @@
                 maybe_ff_adaln_zero
             ) = speech_modules
 
+            (
+                conv_residual,
+                attn_residual,
+                ff_residual
+            ) = speech_residual_fns
+
             # smaller text transformer
 
             if exists(text_embed) and exists(text_modules):
@@ -732,15 +764,25 @@
                     cross_condition
                 ) = text_modules
 
-                text_embed = text_conv(text_embed, mask = mask) + text_embed
+                (
+                    text_conv_residual,
+                    text_attn_residual,
+                    text_ff_residual
+                ) = text_residual_fns
+
+                text_embed, add_residual = text_conv_residual(text_embed)
+                text_embed = text_conv(text_embed, mask = mask)
+                text_embed = add_residual(text_embed)
 
+                text_embed, add_residual = text_attn_residual(text_embed)
                 text_attn_out, text_attn_inter = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = text_attn_first_values)
-                text_embed = text_attn_out + text_embed
+                text_embed = add_residual(text_attn_out)
 
                 text_attn_first_values = default(text_attn_first_values, text_attn_inter.values)
 
-                text_embed = text_ff(text_ff_norm(text_embed)) + text_embed
-
+                text_embed, add_residual = text_ff_residual(text_embed)
+                text_embed = text_ff(text_ff_norm(text_embed))
+                text_embed = add_residual(text_embed)
+
                 x, text_embed = cross_condition(x, text_embed)
 
             # skip connection logic
@@ -758,31 +800,32 @@
 
             # position generating convolution
 
-            x = speech_conv(x, mask = mask) + x
-
-            # maybe residual scaling
-
-            attn_res_scale, ff_res_scale = 1., 1.
-
-            if exists(maybe_residual_scales):
-                attn_res_scale, ff_res_scale = maybe_residual_scales
+            x, add_residual = conv_residual(x)
+            x = speech_conv(x, mask = mask)
+            x = add_residual(x)
 
             # attention and feedforward blocks
 
+            x, add_residual = attn_residual(x)
             attn_out, attn_inter = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = attn_first_values)
+            attn_out = maybe_attn_adaln_zero(attn_out, **norm_kwargs)
+            x = add_residual(attn_out)
 
             attn_first_values = default(attn_first_values, attn_inter.values)
 
-            x = x * attn_res_scale + maybe_attn_adaln_zero(attn_out, **norm_kwargs)
-
+            x, add_residual = ff_residual(x)
             ff_out = ff(ff_norm(x, **norm_kwargs))
-
-            x = x * ff_res_scale + maybe_ff_adaln_zero(ff_out, **norm_kwargs)
+            ff_out = maybe_ff_adaln_zero(ff_out, **norm_kwargs)
+            x = add_residual(ff_out)
 
         assert len(skips) == 0
 
         _, x = unpack(x, registers_packed_shape, 'b * d')
 
+        # sum all residual streams from hyper connections
+
+        x = self.hyper_conn_reduce(x)
+
         return self.final_norm(x)
 
 # main classes
diff --git a/pyproject.toml b/pyproject.toml
index b58f41b..fac4a98 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "1.6.3"
+version = "1.7.1"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -30,6 +30,7 @@ dependencies = [
     'einops>=0.8.0',
     'einx>=0.3.0',
     'ema-pytorch>=0.5.2',
+    'hyper-connections>=0.0.10',
    'g2p-en',
     'jaxtyping',
     'loguru',
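For context, a minimal sketch of the hyper-connections residual pattern this patch adopts, using only the `hyper_connections` calls that appear in the diff above (`get_init_and_expand_reduce_stream_functions`, `init_hyper_conn`, and the expand/reduce helpers). The `branch` module, `dim`, and the tensor shapes are hypothetical stand-ins for the transformer's conv / attention / feedforward blocks, not part of the patch:

```python
import torch
from torch import nn
from hyper_connections import HyperConnections

dim = 512                 # hypothetical model width
num_residual_streams = 4  # the new default in this patch

# disable = True falls back to a plain single-stream residual,
# mirroring the `disable = num_residual_streams == 1` guard in the diff
init_hyper_conn, expand_streams, reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1
)

branch = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))  # stand-in for speech_conv / attn / ff
residual_fn = init_hyper_conn(dim = dim)

x = torch.randn(2, 1024, dim)

x = expand_streams(x)             # widen the single residual into num_residual_streams streams

x, add_residual = residual_fn(x)  # mix the streams down into one branch input
x = branch(x)
x = add_residual(x)               # route the branch output back into the streams

x = reduce_streams(x)             # sum all residual streams back to the original width
```

In the patch, this read / branch / write triple replaces each `x = block(x) + x` residual in both the speech and text towers (one `init_hyper_conn` per conv, attention, and feedforward block), with `hyper_conn_expand` applied once before the layer loop and `hyper_conn_reduce` applied once before the final norm.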