From 0575c81b98eb0f82f44cf44437ba40d0b6ae779a Mon Sep 17 00:00:00 2001
From: Lucas Newman
Date: Mon, 21 Aug 2023 14:53:57 -0700
Subject: [PATCH] Add support for generating each quantizer level one-by-one,
 from coarse to fine.

Also add support for a single sampling step for a level, as only the
first level is decoded with the full number of forward passes in the
paper.
---
 soundstorm_pytorch/soundstorm.py | 134 +++++++++++++++++++------------
 1 file changed, 83 insertions(+), 51 deletions(-)

diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py
index 83f1688..125de2e 100644
--- a/soundstorm_pytorch/soundstorm.py
+++ b/soundstorm_pytorch/soundstorm.py
@@ -422,9 +422,11 @@ def __init__(
             ff_mult = ff_mult,
             conv_expansion_factor = conv_expansion_factor,
             conv_kernel_size = conv_kernel_size,
+            attn_dropout = attn_dropout,
+            ff_dropout = ff_dropout,
+            conv_dropout = conv_dropout,
             conv_causal = conv_causal,
             attn_flash = attn_flash
-
         ))
 
     def forward(self, x):
@@ -739,6 +741,7 @@ def generate(
         start_temperature = 1.,
         filter_thres = 0.7,
         noise_level_scale = 1.,
+        num_full_sampling_levels = 1,
         text_to_semantic_generate_kwargs: dict = {},
         **kwargs
     ):
@@ -817,59 +820,87 @@
         # slowly demask
 
         all_mask_num_tokens = (self.schedule_fn(times[1:]) * seq_len).long()
+
+        mask_num_tokens_for_q_level = [all_mask_num_tokens if q < num_full_sampling_levels else torch.zeros(1, dtype = torch.long, device = device) for q in range(num_effective_quantizers)]
 
         # self conditioning
 
         has_self_cond = self.self_cond
         last_embed = self.null_embed if has_self_cond else None
 
-        for mask_num_tokens, steps_until_x0 in tqdm(zip(all_mask_num_tokens.tolist(), reversed(range(self.steps))), total = self.steps):
-
-            self_cond = self.to_self_cond(last_embed) if has_self_cond else None
-
-            logits, embeds = self.net(
-                seq,
-                cond = cond_tokens,
-                sum_embeds = self_cond,
-                return_logits_and_embeddings = True,
-                **kwargs
-            )
-
-            if has_self_cond:
-                last_embed = embeds
-
-            if exists(filter_thres):
-                logits = top_k(logits, filter_thres)
-
-            annealing_scale = steps_until_x0 / self.steps
-            temperature = start_temperature * annealing_scale
-
-            sampled_ids = gumbel_sample(logits, temperature = max(temperature, 1e-3))
-
-            seq = torch.where(mask, sampled_ids, seq)
-
-            if exists(self.token_critic):
-                scores = self.token_critic(seq)
-                scores = rearrange(scores, 'b n 1 -> b n')
-                scores = scores + noise_level_scale * gumbel_noise(scores) * annealing_scale
-            else:
-                scores = 1 - logits.softmax(dim = -1)
-                scores = scores.gather(2, rearrange(sampled_ids, 'b n -> b n 1'))
-                scores = rearrange(scores, 'b n 1 -> b n')
-
-            if mask_num_tokens == 0:
-                pass
-
-            if not self.can_mask_prev_unmasked:
-                scores = scores.masked_fill(~mask, -torch.finfo(scores.dtype).max)
-
-            mask_indices = scores.topk(mask_num_tokens, dim = -1).indices
-            mask = torch.zeros_like(scores, dtype = torch.bool).scatter(1, mask_indices, True)
-
-            if exists(prompt_mask):
-                mask = mask & prompt_mask
-
-            seq = seq.masked_fill(mask, self.mask_id)
+        for q_level in range(num_effective_quantizers):
+
+            for mask_num_tokens, steps_until_x0 in tqdm(zip(mask_num_tokens_for_q_level[q_level].tolist(), reversed(range(self.steps))), total = self.steps):
+
+                self_cond = self.to_self_cond(last_embed) if has_self_cond else None
+
+                logits, embeds = self.net(
+                    seq,
+                    cond = cond_tokens,
+                    sum_embeds = self_cond,
+                    return_logits_and_embeddings = True,
+                    **kwargs
+                )
+
+                if has_self_cond:
+                    last_embed = embeds
+
+                if exists(filter_thres):
+                    logits = top_k(logits, filter_thres)
+
+                annealing_scale = steps_until_x0 / self.steps
+                temperature = start_temperature * annealing_scale
+
+                sampled_ids = gumbel_sample(logits, temperature = max(temperature, 1e-3))
+
+                # don't sample for lower quantizer levels
+
+                if q_level > 0:
+                    sample_mask = rearrange(mask, 'b (n q) -> b n q', q = num_effective_quantizers)
+                    sample_mask[:, :, :q_level] = False
+                    sample_mask = rearrange(sample_mask, 'b n q -> b (n q)', q = num_effective_quantizers)
+                else:
+                    sample_mask = mask
+
+                seq = torch.where(sample_mask, sampled_ids, seq)
+
+                if exists(self.token_critic):
+                    scores = self.token_critic(seq)
+                    scores = rearrange(scores, 'b n 1 -> b n')
+                    scores = scores + noise_level_scale * gumbel_noise(scores) * annealing_scale
+                else:
+                    scores = 1 - logits.softmax(dim = -1)
+                    scores = scores.gather(2, rearrange(sampled_ids, 'b n -> b n 1'))
+                    scores = rearrange(scores, 'b n 1 -> b n')
+
+                if mask_num_tokens != 0:
+                    # restrict re-masking to still-masked positions, using the pre-reset mask
+                    if not self.can_mask_prev_unmasked:
+                        scores = scores.masked_fill(~mask, -torch.finfo(scores.dtype).max)
+
+                    mask_indices = scores.topk(mask_num_tokens, dim = -1).indices
+                    mask = torch.zeros_like(scores, dtype = torch.bool).scatter(1, mask_indices, True)
+                else:
+                    mask = torch.zeros_like(scores, dtype = torch.bool)
+
+                mask = rearrange(mask, 'b (n q) -> b n q', q = num_effective_quantizers)
+
+                # mask all upper quantizer levels
+
+                if q_level < num_effective_quantizers - 1:
+                    mask[:, :, q_level + 1:] = True
+
+                # unmask all lower quantizer levels
+
+                if q_level > 0:
+                    mask[:, :, :q_level] = False
+
+                mask = rearrange(mask, 'b n q -> b (n q)', q = num_effective_quantizers)
+
+                if exists(prompt_mask):
+                    mask = mask & prompt_mask
+
+                seq = seq.masked_fill(mask, self.mask_id)
 
         out = seq
@@ -921,9 +952,10 @@ def maybe_get_condition(self, token_ids = None, length = None):
 
         # pytorch does not interpolate 1d, so hack by convert to 2d
 
-        cond_tokens = rearrange(cond_tokens, 'b n d -> b d n 1')
-        cond_tokens = F.interpolate(cond_tokens, (target_cond_length, 1), mode = 'bilinear')
-        cond_tokens = rearrange(cond_tokens, 'b d n 1 -> b n d')
+        if cond_length != target_cond_length:
+            cond_tokens = rearrange(cond_tokens, 'b n d -> b d n 1')
+            cond_tokens = F.interpolate(cond_tokens, (target_cond_length, 1), mode = 'bilinear')
+            cond_tokens = rearrange(cond_tokens, 'b d n 1 -> b n d')
 
         # whether to curtail or pad to length
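
Illustration (not part of the patch): the core of the new loop is the per-level
mask bookkeeping. Below is a minimal, self-contained sketch of the rearrange
round-trip used above; all shapes and values are made up for demonstration.

    import torch
    from einops import rearrange

    num_effective_quantizers = 4   # num_quantizers * grouped_quantizers
    q_level = 1                    # quantizer level currently being demasked
    b, n = 2, 3                    # batch size, timesteps

    # seq is flattened as (n q): quantizer levels vary fastest within each timestep
    mask = torch.zeros(b, n * num_effective_quantizers, dtype = torch.bool)

    mask = rearrange(mask, 'b (n q) -> b n q', q = num_effective_quantizers)
    mask[:, :, q_level + 1:] = True   # finer levels stay fully masked for later passes
    mask[:, :, :q_level] = False      # coarser levels are already finalized
    mask = rearrange(mask, 'b n q -> b (n q)', q = num_effective_quantizers)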
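
Usage sketch (follows the repo README example; the constructor values are
illustrative assumptions, and only num_full_sampling_levels is new here):

    import torch
    from soundstorm_pytorch import SoundStorm, ConformerWrapper

    conformer = ConformerWrapper(
        codebook_size = 1024,
        num_quantizers = 12,
        conformer = dict(dim = 512, depth = 2)
    )

    model = SoundStorm(conformer, steps = 18, schedule = 'cosine')

    codes = torch.randint(0, 1024, (2, 1024, 12))  # (batch, seq, num quantizers)
    loss, _ = model(codes)
    loss.backward()

    # with this patch, only the first num_full_sampling_levels quantizer levels get
    # all `steps` demasking passes; the remaining levels are decoded in a single
    # forward pass each, from coarse to fine
    generated = model.generate(1024, batch_size = 2, num_full_sampling_levels = 1)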