
complete variable lengthed sequence training
lucidrains committed Aug 23, 2023
1 parent 24c624b commit fa1b7d2
Showing 3 changed files with 59 additions and 24 deletions.
README.md: 2 changes (1 addition, 1 deletion)
@@ -149,8 +149,8 @@ generated_speech = model.generate(
- [x] default flash attention to true
- [x] remove batchnorm, and just use layernorm, but after the swish (as in normformer paper)
- [x] trainer with accelerate - thanks to @lucasnewman
- [x] allow for variable lengthed sequence training and generation, by passing in `mask` at `forward` and `generate`

- [ ] add an option to mask out attention to padding (variable lengthed semantic) - should work as is though
- [ ] option to return list of audio files when generating
- [ ] turn it into a command line tool
- [ ] add cross attention and adaptive layernorm conditioning
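The new checklist item is the point of this commit: variable-length batches are handled by passing a boolean `mask` over the latent frames to both `forward` and `generate`. A hedged sketch of building such a mask follows; the lengths are made up, and the commented model calls only illustrate where the mask goes (their exact return values are an assumption, not something this commit specifies).

```python
import torch

# variable-length batch: two examples padded to 1024 latent frames (lengths are made up)
lengths = torch.tensor([1024, 768])
mask = torch.arange(1024)[None, :] < lengths[:, None]    # (2, 1024) bool, True = real frame

# the mask is then passed alongside the codes, e.g.
#   loss, _ = model(codes, mask = mask)
#   generated = model.generate(1024, batch_size = 2, mask = mask)
# where `model` is a SoundStorm instance built as in the repository README;
# only the `mask` keyword is what this commit adds
```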
setup.py: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.28',
version = '0.0.29',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
soundstorm_pytorch/soundstorm.py: 79 changes (57 additions, 22 deletions)
@@ -1,5 +1,5 @@
import math
from random import random
from random import random, randrange
from functools import wraps
from contextlib import nullcontext
from collections import namedtuple
@@ -79,8 +79,16 @@ def coin_flip():

# tensor helpers

def get_mask_subset_prob(mask, prob, min_mask = 0):
def get_mask_subset_prob(
mask: Tensor,
prob: Union[float, Tensor],
min_mask: int = 0
):
batch, seq, device = *mask.shape, mask.device

if isinstance(prob, Tensor):
prob = rearrange(prob, 'b -> b 1')

num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask)
logits = torch.rand((batch, seq), device = device)
logits = logits.masked_fill(~mask, -1)
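The rest of this helper is collapsed by the diff view, so here is a minimal self-contained sketch of the behavior the visible lines imply: pick a random subset of the valid (`mask == True`) positions, sized by a probability that may now be given per batch element. The rank-based selection at the end is an assumption about the hidden remainder of the function, not the repository's exact code.

```python
import torch
from torch import Tensor
from typing import Union
from einops import rearrange

def mask_subset_prob_sketch(mask: Tensor, prob: Union[float, Tensor], min_mask: int = 0) -> Tensor:
    batch, seq = mask.shape
    device = mask.device

    # a per-example probability tensor is broadcast against each row's valid-token count
    if isinstance(prob, Tensor):
        prob = rearrange(prob, 'b -> b 1')

    num_valid = mask.sum(dim = -1, keepdim = True)
    num_to_mask = (num_valid * prob).clamp(min = min_mask)
    num_to_mask = torch.minimum(num_to_mask, num_valid).long()

    # random scores; padding positions get the lowest score so they are never picked
    scores = torch.rand((batch, seq), device = device)
    scores = scores.masked_fill(~mask, -1)

    # rank of every position under a descending sort (rank 0 = highest random score)
    ranks = scores.argsort(dim = -1, descending = True).argsort(dim = -1)

    return (ranks < num_to_mask) & mask

# per-example probabilities: row 0 masks ~half of its 4 valid tokens, row 1 ~a quarter of its 6
mask = torch.tensor([[True] * 4 + [False] * 2,
                     [True] * 6])
subset = mask_subset_prob_sketch(mask, torch.tensor([0.5, 0.25]))
```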
@@ -753,8 +761,6 @@ def generate(
text_to_semantic_generate_kwargs: dict = {},
**kwargs
):
seq_mask = mask

if self.should_condition and not exists(cond_semantic_token_ids):
assert exists(texts) and exists(self.text_to_semantic)

@@ -803,10 +809,19 @@ def generate(
times = torch.linspace(0., 1., self.steps + 1)

# sequence starts off as all masked
# todo: find a better name for sequence mask vs mask for mask diffusion

shape = (batch_size, seq_len)

seq = torch.full(shape, self.mask_id, device = device)

seq_mask = mask

if not exists(seq_mask):
seq_mask = torch.ones((batch_size, num_latents), device = device, dtype = torch.bool)

seq_mask_with_quantizer = repeat(seq_mask, 'b n -> b (n q)', q = num_effective_quantizers)

mask = torch.full(shape, True, device = device)

# include prompt tokens unmasked as the sequence prefix, starting from the lowest quantizer
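A quick illustration of the `seq_mask_with_quantizer` expansion above, with made-up sizes: the frame-level mask is repeated so that each frame's validity covers all of its quantizer slots in the flattened `(n q)` sequence.

```python
import torch
from einops import repeat

num_effective_quantizers = 2

# frame-level mask for one example: 3 real latent frames, 1 padded frame (made-up sizes)
seq_mask = torch.tensor([[True, True, True, False]])

seq_mask_with_quantizer = repeat(seq_mask, 'b n -> b (n q)', q = num_effective_quantizers)
# tensor([[ True,  True,  True,  True,  True,  True, False, False]])
```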
@@ -828,9 +843,14 @@ def generate(

# slowly demask

all_mask_num_tokens = (self.schedule_fn(times[1:]) * seq_len).long()

mask_num_tokens_for_q_level = [all_mask_num_tokens if q < num_full_sampling_levels else torch.zeros(1, dtype = torch.long, device = device) for q in range(num_effective_quantizers)]
seq_len_from_mask = reduce(seq_mask, 'b n -> b', 'sum')

rand_mask_probs = self.schedule_fn(times[1:])
rand_mask_probs = rearrange(rand_mask_probs, 'n -> n 1')

all_mask_num_tokens = (rand_mask_probs * seq_len_from_mask).long()

mask_num_tokens_for_q_level = [all_mask_num_tokens if q < num_full_sampling_levels else torch.zeros((self.steps, batch_size), dtype = torch.long, device = device) for q in range(num_effective_quantizers)]

# self conditioning

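The reworked schedule above now produces one mask count per example per decoding step, driven by each example's true length rather than the padded `seq_len`. A small self-contained illustration, using a cosine curve as a stand-in for `self.schedule_fn` (an assumption; the repository lets the schedule be configured):

```python
import math
import torch
from einops import rearrange, reduce

steps = 4
times = torch.linspace(0., 1., steps + 1)
schedule_fn = lambda t: torch.cos(t * math.pi * 0.5)       # stand-in for self.schedule_fn

seq_mask = torch.tensor([[True] * 6 + [False] * 2,
                         [True] * 8])                       # valid lengths 6 and 8

seq_len_from_mask = reduce(seq_mask, 'b n -> b', 'sum')     # tensor([6, 8])

rand_mask_probs = rearrange(schedule_fn(times[1:]), 'n -> n 1')      # (steps, 1)
all_mask_num_tokens = (rand_mask_probs * seq_len_from_mask).long()   # (steps, batch)
# each decoding step now carries a per-example number of tokens to keep masked
```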
@@ -839,7 +859,7 @@ def generate(

for q_level in range(num_effective_quantizers):

for mask_num_tokens, steps_until_x0 in tqdm(zip(mask_num_tokens_for_q_level[q_level].tolist(), reversed(range(self.steps))), total = self.steps):
for mask_num_tokens, steps_until_x0 in tqdm(zip(mask_num_tokens_for_q_level[q_level], reversed(range(self.steps))), total = self.steps):

self_cond = self.to_self_cond(last_embed) if has_self_cond else None

@@ -886,18 +906,26 @@ def generate(
mask_indices = torch.zeros(batch_size, 1, dtype = torch.long, device = device)
mask = torch.zeros_like(scores, dtype = torch.bool)

if mask_num_tokens != 0:
if not self.can_mask_prev_unmasked:
scores = scores.masked_fill(~mask, -torch.finfo(scores.dtype).max)
# mask based on highest score

mask_value = -torch.finfo(scores.dtype).max

scores = scores.masked_fill(~seq_mask_with_quantizer, mask_value)

if not self.can_mask_prev_unmasked:
scores = scores.masked_fill(~mask, mask_value)

scores_sorted = scores.argsort(dim = -1, descending = True)

mask_num_tokens = rearrange(mask_num_tokens, 'b -> b 1')

mask = scores_sorted < mask_num_tokens

mask_indices = scores.topk(mask_num_tokens, dim = -1).indices
mask = mask.scatter(1, mask_indices, True)

mask = rearrange(mask, 'b (n q) -> b n q', q = num_effective_quantizers)

# mask all upper quantizer levels

if q_level < num_effective_quantizers - 1:
if q_level < (num_effective_quantizers - 1):
mask[:, :, q_level + 1:] = True

# unmask all lower quantizer levels
Expand All @@ -906,7 +934,7 @@ def generate(
mask[:, :, :q_level] = False

mask = rearrange(mask, 'b n q -> b (n q)', q = num_effective_quantizers)

if exists(prompt_mask):
mask = mask & prompt_mask

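A tiny illustration of why the new `masked_fill` with the dtype minimum matters: once padded slots carry the lowest possible score, no score-driven selection can ever choose them for re-masking. The tensors below are made up.

```python
import torch

scores = torch.tensor([[0.9, 0.2, 0.7, 0.5]])
valid  = torch.tensor([[True, True, True, False]])          # last slot is padding

mask_value = -torch.finfo(scores.dtype).max
scores = scores.masked_fill(~valid, mask_value)

print(scores.argsort(dim = -1, descending = True))          # tensor([[0, 2, 1, 3]]), padding sorts last
```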
@@ -1021,6 +1049,9 @@ def forward(

seq_mask = mask

if not exists(seq_mask):
seq_mask = torch.ones((b, n), device = device, dtype = torch.bool)

# maybe condition

cond_tokens = self.maybe_get_condition(cond_semantic_token_ids, length = x.shape[-2])
@@ -1029,15 +1060,19 @@ def forward(

orig_seq = rearrange(x.clone(), 'b n q -> b (n q)')

t = torch.randint(0, n, (1,)).item()
q = torch.randint(0, gq, (1,)).item()
min_seq_len = seq_mask.sum(dim = -1).amin()
t = randrange(0, min_seq_len - 1)

rand_times = torch.empty(b, device = device).uniform_(0, 1)
batched_randperm = torch.rand((b, n - t), device = device).argsort(dim = -1).float()
mask = seq_mask[:, t:]

rand_times = torch.empty(b, device = device).uniform_(0, 1)
rand_probs = self.schedule_fn(rand_times)
num_tokens_mask = (rand_probs * (n - t)).clamp(min = 1.)
mask = batched_randperm < rearrange(num_tokens_mask, 'b -> b 1')

mask = get_mask_subset_prob(mask, rand_probs)

# random quantizer position, in groups

q = randrange(0, self.num_quantizers) * self.grouped_quantizers

# to ensure all tokens produce embeddings, instead of just the ones with [mask] input, as done in seminal BERT MLM paper
# potentially needed for self-conditioning (on embedding) to work well
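The training-side changes in `forward` can be summarized in isolation: the critical timestep `t` is now bounded by the shortest real sequence in the batch, only real positions after `t` are candidates for masking, and the quantizer level is drawn per group. A hedged sketch with made-up sizes (names mirror the diff, but this is a sketch, not the module itself):

```python
import torch
from random import randrange

num_quantizers, grouped_quantizers = 4, 1

# frame-level validity for two examples with true lengths 5 and 3 (made up)
seq_mask = torch.tensor([[True] * 5 + [False] * 1,
                         [True] * 3 + [False] * 3])

min_seq_len = int(seq_mask.sum(dim = -1).amin())     # 3: shortest real length in the batch
t = randrange(0, min_seq_len - 1)                    # unmasked prefix length, 0 <= t < 2

maskable = seq_mask[:, t:]       # only real frames after the prefix can be chosen for masking

q = randrange(0, num_quantizers) * grouped_quantizers    # random quantizer level, in groups
```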
