From 6caa246b078983189c46bdb3035e66dd8aa556ef Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 11 Sep 2024 14:27:00 -0700
Subject: [PATCH] address https://github.com/lucidrains/soundstorm-pytorch/issues/35

---
 setup.py                         | 2 +-
 soundstorm_pytorch/soundstorm.py | 5 +++++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 4ed0c22..5011590 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'soundstorm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.3',
+  version = '0.4.6',
   license='MIT',
   description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
   author = 'Phil Wang',
diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py
index 5baa645..303d79c 100644
--- a/soundstorm_pytorch/soundstorm.py
+++ b/soundstorm_pytorch/soundstorm.py
@@ -530,6 +530,8 @@ def __init__(
             Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers)
         )
 
+        self.num_effective_quantizers = num_effective_quantizers
+
         # each quantizer codebook would require its own logits weight and bias matrices
         # the amazing einops makes this easy with 'EinMix'
 
@@ -579,6 +581,9 @@ def forward(
         x = self.embedding_proj(x)
 
         if exists(sum_embeds):
+            if sum_embeds.ndim == 3:
+                sum_embeds = reduce(sum_embeds, 'b (n h) d -> b n d', 'sum', h = self.num_effective_quantizers)
+
             x = x + sum_embeds
 
         if exists(cond):
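
What follows is a minimal sketch, not part of the patch, illustrating what the added reduce(...) call in forward() does: when sum_embeds arrives flattened across the effective quantizers with shape (b, n * h, d), it is summed over the quantizer axis back to (b, n, d), matching the shape of x before the two are added. Storing self.num_effective_quantizers in __init__ is what makes h available to this reduce at forward time. The concrete sizes (b, n, h, d) below are assumptions chosen for illustration.

# Illustrative sketch of the reduce added by this patch; sizes are assumptions.
import torch
from einops import reduce

b, n, h, d = 2, 4, 3, 8   # batch, frames, effective quantizers, embedding dim

# sum_embeds may arrive flattened across quantizers: (b, n * h, d)
sum_embeds = torch.randn(b, n * h, d)

# sum over the quantizer axis so the result matches x's (b, n, d) shape
folded = reduce(sum_embeds, 'b (n h) d -> b n d', 'sum', h = h)
assert folded.shape == (b, n, d)

# equivalent to reshaping out the quantizer axis and summing it manually
assert torch.allclose(folded, sum_embeds.reshape(b, n, h, d).sum(dim = 2))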