Skip to content

Commit

Permalink
complete grouped residual vq support
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 19, 2023
1 parent 2794f9d commit eb32363
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 28 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ generated_audio = model.generate(seconds = 30, batch_size = 2) # generate 30 se

- [x] integrate soundstream
- [x] when generating, length can be defined in seconds (takes sampling freq etc. into account)
- [x] make sure grouped rvq is supported. concat embeddings rather than sum across group dimension

- [ ] option to return list of audio files when generating
- [ ] turn it into a command line tool
- [ ] add cross attention and adaptive layernorm conditioning (just copy paste in the entire conformer repository, if conditioning adds too much cruft to the other repo)
- [ ] make sure grouped rvq is supported. concat embeddings rather than sum across group dimension
- [ ] trainer with accelerate

## Citations
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.8',
version = '0.0.10',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
Expand Down
76 changes: 50 additions & 26 deletions soundstorm_pytorch/soundstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from audiolm_pytorch import SoundStream

from tqdm import tqdm

# helpers

def exists(val):
Expand Down Expand Up @@ -104,7 +106,7 @@ def __init__(
codebook_size,
num_quantizers,
conformer: Union[Conformer, Dict[str, any]],
num_tokens_per_head = None,
grouped_quantizers = 1
):
super().__init__()
self.conformer = conformer
Expand All @@ -114,36 +116,42 @@ def __init__(

dim = self.conformer.dim

self.embedding_proj = nn.Sequential(
nn.Linear(dim * grouped_quantizers, dim),
nn.LayerNorm(dim)
) if grouped_quantizers > 1 else nn.Identity()

num_codes_with_mask = codebook_size + 1
num_effective_quantizers = num_quantizers * grouped_quantizers

self.code_embeds = nn.Embedding(num_codes_with_mask * num_quantizers, dim)
self.code_embeds = nn.Embedding(num_codes_with_mask * num_effective_quantizers, dim)

self.register_buffer('quantizer_offsets', torch.arange(num_quantizers) * num_codes_with_mask, persistent = False)
self.register_buffer('quantizer_offsets', torch.arange(num_effective_quantizers) * num_codes_with_mask, persistent = False)
self.register_buffer('mask_tokens', self.quantizer_offsets + num_codes_with_mask, persistent = False)

self.dim = dim
self.codebook_size = codebook_size

self.num_codes_with_mask = num_codes_with_mask
self.num_quantizers = num_quantizers

self.num_tokens_per_head = default(num_tokens_per_head, num_quantizers)
self.grouped_quantizers = grouped_quantizers

self.heads = nn.Sequential(
nn.Linear(dim, dim * self.num_tokens_per_head),
Rearrange('b n (h d) -> b (n h) d', h = self.num_tokens_per_head)
nn.Linear(dim, dim * num_effective_quantizers),
Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers)
)

# each quantizer codebook would require its own logits weight and bias matrices
# the amazing einops makes this easy with 'EinMix'

self.to_logits = nn.Sequential(
nn.LayerNorm(dim),
Rearrange('b (n q) d -> b n q d', q = self.num_quantizers),
Rearrange('b (n gq) d -> b n gq d', gq = num_effective_quantizers),
EinMix(
'b n q d -> b n q l',
weight_shape = 'q d l',
bias_shape = 'q l',
q = self.num_quantizers,
'b n gq d -> b n gq l',
weight_shape = 'gq d l',
bias_shape = 'gq l',
gq = num_effective_quantizers,
l = codebook_size,
d = dim
),
Expand All @@ -158,27 +166,38 @@ def forward(
return_embeddings = False,
return_logits_and_embeddings = False
):
n, q = x.shape[-1], self.num_quantizers
assert divisible_by(n, q), 'sequence must be divisible by number of quantizers'

x = rearrange(x, 'b (n q) -> b n q', q = q)
"""
einops notation:
b - batch
n - sequence
g - groups
q - quantizers
d - feature dimension
"""

n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers
assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers'

x = rearrange(x, 'b (n gq) -> b n gq', gq = g * q)
x = x + self.quantizer_offsets

x = self.code_embeds(x)

x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g = g)

x = self.embedding_proj(x)

if exists(sum_embeds):
x = x + sum_embeds

x = reduce(x, 'b n q d -> b n d', 'sum')

if exists(cond):
if cond.ndim == 2:
cond = rearrange(cond, 'b d -> b 1 d')

x = x + cond

logits = self.conformer(x)
embeds = self.heads(logits)
x = self.conformer(x)
embeds = self.heads(x)

if return_embeddings or not exists(self.to_logits):
return embeds
Expand Down Expand Up @@ -233,7 +252,7 @@ def __init__(
self.soundstream = soundstream

if exists(self.soundstream):
assert soundstream.rq_groups == 1, 'grouped residual vector quantized soundstream not supported, yet'
assert net.grouped_quantizers == soundstream.rq_groups
assert net.codebook_size == soundstream.codebook_size
assert net.num_quantizers == soundstream.num_quantizers

Expand All @@ -245,6 +264,7 @@ def __init__(
self.dim = dim
self.num_tokens = net.codebook_size
self.num_quantizers = net.num_quantizers
self.grouped_quantizers = net.grouped_quantizers

self.mask_id = net.codebook_size

Expand Down Expand Up @@ -288,7 +308,7 @@ def __init__(
@eval_decorator
def generate(
self,
seq_len = None,
num_latents = None,
*,
seconds = None,
batch_size = None,
Expand All @@ -297,11 +317,11 @@ def generate(
noise_level_scale = 1.,
**kwargs
):
assert exists(seq_len) ^ exists(seconds)
assert exists(num_latents) ^ exists(seconds)

if not exists(seq_len):
if not exists(num_latents):
assert exists(self.soundstream), 'soundstream must be passed in to generate in seconds'
seq_len = (seconds * self.soundstream.target_sample_hz) // self.soundstream.seq_len_multiple_of
num_latents = (seconds * self.soundstream.target_sample_hz) // self.soundstream.seq_len_multiple_of

sample_one = not exists(batch_size)
batch_size = default(batch_size, 1)
Expand All @@ -310,6 +330,10 @@ def generate(

times = torch.linspace(0., 1., self.steps + 1)

# the conformer attends over num_latents positions, but the flattened token sequence
# holds one code per (group, quantizer) for every latent

seq_len = num_latents * self.grouped_quantizers * self.num_quantizers

# sequence starts off as all masked

shape = (batch_size, seq_len)
Expand All @@ -326,7 +350,7 @@ def generate(
has_self_cond = self.self_cond
last_embed = self.null_embed if has_self_cond else None

for mask_num_tokens, steps_until_x0 in zip(all_mask_num_tokens.tolist(), reversed(range(self.steps))):
for mask_num_tokens, steps_until_x0 in tqdm(zip(all_mask_num_tokens.tolist(), reversed(range(self.steps))), total = self.steps):

self_cond = self.to_self_cond(last_embed) if has_self_cond else None

Expand Down

0 comments on commit eb32363

Please sign in to comment.