add hyper connections and cite, default to 4 residual streams
lucidrains committed Dec 26, 2024
1 parent b3386f8 commit dab0939
Showing 3 changed files with 84 additions and 29 deletions.
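This commit swaps the transformer's plain additive residuals (and the optional learned residual scales) for the hyper-connections scheme of Zhu et al. (2024): the single residual is expanded into `num_residual_streams` parallel streams (default 4), every conv / attention / feedforward branch reads its input from those streams and writes its output back into them, and the streams are summed again before the final norm. The sketch below mirrors the calls that appear in the e2_tts.py diff; the `block` module, tensor shapes, and `dim` value are illustrative stand-ins, not part of the commit.

```python
import torch
from torch import nn
from hyper_connections import HyperConnections

dim = 512
num_residual_streams = 4

# per-branch init plus the expand / reduce functions for the residual streams,
# obtained the same way the Transformer constructor does in the diff below
init_hyper_conn, expand_streams, reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1   # a single stream falls back to a plain residual
)

branch = nn.Linear(dim, dim)              # stand-in for a conv / attention / feedforward branch
residual_fn = init_hyper_conn(dim = dim)  # one hyper connection per branch

x = torch.randn(2, 1024, dim)

x = expand_streams(x)                     # widen the residual into num_residual_streams streams

x, add_residual = residual_fn(x)          # read the branch input off the streams
x = branch(x)                             # run the branch
x = add_residual(x)                       # write the branch output back into the streams

x = reduce_streams(x)                     # sum the streams back down before the final norm
```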
11 changes: 11 additions & 0 deletions README.md
@@ -164,3 +164,14 @@ sampled = e2tts.sample(mel[:, :5], text = text)
url = {https://api.semanticscholar.org/CorpusID:273849947}
}
```

```bibtex
@article{Zhu2024HyperConnections,
title = {Hyper-Connections},
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
journal = {ArXiv},
year = {2024},
volume = {abs/2409.19606},
url = {https://api.semanticscholar.org/CorpusID:272987528}
}
```
99 changes: 71 additions & 28 deletions e2_tts_pytorch/e2_tts.py
Expand Up @@ -44,6 +44,8 @@

from x_transformers.x_transformers import RotaryEmbedding

from hyper_connections import HyperConnections

from vocos import Vocos

pad_sequence = partial(pad_sequence, batch_first = True)
@@ -503,6 +505,7 @@ def __init__(
scale_residual = False,
attn_laser = False,
attn_laser_softclamp_value = 15.,
num_residual_streams = 4,
attn_kwargs: dict = dict(
gate_value_heads = True,
softclamp_logits = True,
@@ -530,7 +533,7 @@ def __init__(
assert 1 <= text_depth <= depth, 'must have at least 1 layer of text conditioning, but less than total number of speech layers'

self.depth = depth
self.layers = ModuleList([])
layers = []

# registers

@@ -550,6 +553,12 @@ def __init__(
self.rotary_emb = RotaryEmbedding(dim_head)
self.text_rotary_emb = RotaryEmbedding(dim_head)

# hyper connection related

init_hyper_conn, self.hyper_conn_expand, self.hyper_conn_reduce = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

hyper_conns = []

# time conditioning
# will use adaptive rmsnorm

@@ -580,16 +589,10 @@ def __init__(
attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, laser = attn_laser, laser_softclamp_value = attn_laser_softclamp_value, **attn_kwargs)
attn_adaln_zero = postbranch_klass()


ff_norm = rmsnorm_klass(dim)
ff = FeedForward(dim = dim, glu = True, mult = ff_mult, dropout = dropout, **ff_kwargs)
ff_adaln_zero = postbranch_klass()

residual_scales.append(nn.ParameterList([
nn.Parameter(torch.ones(dim)),
nn.Parameter(torch.ones(dim))
]))

skip_proj = Linear(dim * 2, dim, bias = False) if is_later_half else None

speech_modules = ModuleList([
@@ -603,7 +606,14 @@ def __init__(
ff_adaln_zero,
])

speech_hyper_conns = ModuleList([
init_hyper_conn(dim = dim), # conv
init_hyper_conn(dim = dim), # attn
init_hyper_conn(dim = dim), # ff
])

text_modules = None
text_hyper_conns = None

if has_text:
# text related
@@ -631,12 +641,25 @@ def __init__(
cross_condition
])

self.layers.append(ModuleList([
text_hyper_conns = ModuleList([
init_hyper_conn(dim = dim_text), # conv
init_hyper_conn(dim = dim_text), # attn
init_hyper_conn(dim = dim_text), # ff
])

hyper_conns.append(ModuleList([
speech_hyper_conns,
text_hyper_conns
]))

layers.append(ModuleList([
speech_modules,
text_modules
]))

self.residual_scales = nn.ParameterList(residual_scales) if scale_residual else None
self.layers = ModuleList(layers)

self.hyper_conns = ModuleList(hyper_conns)

self.final_norm = RMSNorm(dim)

@@ -698,13 +721,16 @@ def forward(
text_attn_first_values = None
attn_first_values = None

# prepare residual scales
# expand hyper connections

residual_scales = default(self.residual_scales, (None,) * len(self.layers))
x = self.hyper_conn_expand(x)

if exists(text_embed):
text_embed = self.hyper_conn_expand(text_embed)

# go through the layers

for ind, ((speech_modules, text_modules), maybe_residual_scales) in enumerate(zip(self.layers, residual_scales)):
for ind, ((speech_modules, text_modules), (speech_residual_fns, text_residual_fns)) in enumerate(zip(self.layers, self.hyper_conns)):

layer = ind + 1

@@ -719,6 +745,12 @@ def forward(
maybe_ff_adaln_zero
) = speech_modules

(
conv_residual,
attn_residual,
ff_residual
) = speech_residual_fns

# smaller text transformer

if exists(text_embed) and exists(text_modules):
@@ -732,15 +764,25 @@ def forward(
cross_condition
) = text_modules

text_embed = text_conv(text_embed, mask = mask) + text_embed
(
text_conv_residual,
text_attn_residual,
text_ff_residual
) = text_residual_fns

text_embed, add_residual = text_conv_residual(text_embed)
text_embed = text_conv(text_embed, mask = mask)
text_embed = add_residual(text_embed)

text_embed, add_residual = text_attn_residual(text_embed)
text_attn_out, text_attn_inter = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = text_attn_first_values)
text_embed = text_attn_out + text_embed
text_embed = add_residual(text_attn_out)

text_attn_first_values = default(text_attn_first_values, text_attn_inter.values)

text_embed = text_ff(text_ff_norm(text_embed)) + text_embed

text_embed, add_residual = text_ff_residual(text_embed)
text_embed = text_ff(text_ff_norm(text_embed))
text_embed = add_residual(text_embed)
x, text_embed = cross_condition(x, text_embed)

# skip connection logic
@@ -758,31 +800,32 @@ def forward(

# position generating convolution

x = speech_conv(x, mask = mask) + x

# maybe residual scaling

attn_res_scale, ff_res_scale = 1., 1.

if exists(maybe_residual_scales):
attn_res_scale, ff_res_scale = maybe_residual_scales
x, add_residual = conv_residual(x)
x = speech_conv(x, mask = mask)
x = add_residual(x)

# attention and feedforward blocks

x, add_residual = attn_residual(x)
attn_out, attn_inter = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask, return_intermediates = True, value_residual = attn_first_values)
attn_out = maybe_attn_adaln_zero(attn_out, **norm_kwargs)
x = add_residual(attn_out)

attn_first_values = default(attn_first_values, attn_inter.values)

x = x * attn_res_scale + maybe_attn_adaln_zero(attn_out, **norm_kwargs)

x, add_residual = ff_residual(x)
ff_out = ff(ff_norm(x, **norm_kwargs))

x = x * ff_res_scale + maybe_ff_adaln_zero(ff_out, **norm_kwargs)
ff_out = maybe_ff_adaln_zero(ff_out, **norm_kwargs)
x = add_residual(ff_out)

assert len(skips) == 0

_, x = unpack(x, registers_packed_shape, 'b * d')

# sum all residual streams from hyper connections

x = self.hyper_conn_reduce(x)

return self.final_norm(x)

# main classes
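Because the constructor wires `disable = num_residual_streams == 1` into the library, a single residual stream should degrade gracefully to an ordinary additive residual while the three-call interface used throughout `forward` stays the same. A minimal sketch under that assumption (`branch` is a hypothetical stand-in module, not part of the commit):

```python
import torch
from torch import nn
from hyper_connections import HyperConnections

dim = 512
num_residual_streams = 1   # disables the extra streams per the constructor above

init_hyper_conn, expand_streams, reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1
)

branch = nn.Linear(dim, dim)              # stand-in for conv / attention / feedforward
residual_fn = init_hyper_conn(dim = dim)

x = torch.randn(2, 64, dim)

x = expand_streams(x)                     # expected to be a no-op when disabled

x, add_residual = residual_fn(x)          # same interface as the multi-stream case
x = add_residual(branch(x))               # effectively branch(x) + x

x = reduce_streams(x)                     # expected to be a no-op when disabled
```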
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "1.6.3"
version = "1.7.1"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
@@ -30,6 +30,7 @@ dependencies = [
'einops>=0.8.0',
'einx>=0.3.0',
'ema-pytorch>=0.5.2',
'hyper-connections>=0.0.10',
'g2p-en',
'jaxtyping',
'loguru',
