Merge pull request #20 from lucasnewman/quantizer-level-generation
Add support for generating quantizer codes one level at a time
lucidrains authored Aug 22, 2023
2 parents 9c92550 + 0575c81 commit 0b43936
Showing 1 changed file with 83 additions and 51 deletions.
134 changes: 83 additions & 51 deletions soundstorm_pytorch/soundstorm.py
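This commit adds a num_full_sampling_levels argument to SoundStorm.generate, so only the first N quantizer levels run the full iterative demasking schedule while the remaining levels are filled in one level at a time. A rough usage sketch follows; the ConformerWrapper / SoundStorm constructor arguments and the positional sequence-length and batch_size arguments are assumptions loosely based on the project README, not part of this diff.

import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper

# hypothetical configuration, loosely following the project README (assumption)
conformer = ConformerWrapper(
    codebook_size = 1024,
    num_quantizers = 4,
    conformer = dict(
        dim = 512,
        depth = 2
    )
)

model = SoundStorm(
    conformer,
    steps = 18,          # demasking steps per fully sampled quantizer level
    schedule = 'cosine'
)

generated = model.generate(
    256,                            # sequence length, per the README example (assumption)
    batch_size = 2,                 # assumption, per the README example
    num_full_sampling_levels = 1    # new in this commit: levels beyond this get one pass each
)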
@@ -422,9 +422,11 @@ def __init__(
    ff_mult = ff_mult,
    conv_expansion_factor = conv_expansion_factor,
    conv_kernel_size = conv_kernel_size,
    attn_dropout = attn_dropout,
    ff_dropout = ff_dropout,
    conv_dropout = conv_dropout,
    conv_causal = conv_causal,
    attn_flash = attn_flash
))

def forward(self, x):
@@ -739,6 +741,7 @@ def generate(
    start_temperature = 1.,
    filter_thres = 0.7,
    noise_level_scale = 1.,
    num_full_sampling_levels = 1,
    text_to_semantic_generate_kwargs: dict = {},
    **kwargs
):
@@ -817,59 +820,87 @@ def generate(
# slowly demask

all_mask_num_tokens = (self.schedule_fn(times[1:]) * seq_len).long()

mask_num_tokens_for_q_level = [all_mask_num_tokens if q < num_full_sampling_levels else torch.zeros(1, dtype = torch.long, device = device) for q in range(num_effective_quantizers)]

# self conditioning

has_self_cond = self.self_cond
last_embed = self.null_embed if has_self_cond else None

for mask_num_tokens, steps_until_x0 in tqdm(zip(all_mask_num_tokens.tolist(), reversed(range(self.steps))), total = self.steps):

    self_cond = self.to_self_cond(last_embed) if has_self_cond else None

    logits, embeds = self.net(
        seq,
        cond = cond_tokens,
        sum_embeds = self_cond,
        return_logits_and_embeddings = True,
        **kwargs
    )

    if has_self_cond:
        last_embed = embeds

    if exists(filter_thres):
        logits = top_k(logits, filter_thres)

    annealing_scale = steps_until_x0 / self.steps
    temperature = start_temperature * annealing_scale

    sampled_ids = gumbel_sample(logits, temperature = max(temperature, 1e-3))

    seq = torch.where(mask, sampled_ids, seq)

    if exists(self.token_critic):
        scores = self.token_critic(seq)
        scores = rearrange(scores, 'b n 1 -> b n')
        scores = scores + noise_level_scale * gumbel_noise(scores) * annealing_scale
    else:
        scores = 1 - logits.softmax(dim = -1)
        scores = scores.gather(2, rearrange(sampled_ids, 'b n -> b n 1'))
        scores = rearrange(scores, 'b n 1 -> b n')

    if mask_num_tokens == 0:
        pass

    if not self.can_mask_prev_unmasked:
        scores = scores.masked_fill(~mask, -torch.finfo(scores.dtype).max)

    mask_indices = scores.topk(mask_num_tokens, dim = -1).indices
    mask = torch.zeros_like(scores, dtype = torch.bool).scatter(1, mask_indices, True)

    if exists(prompt_mask):
        mask = mask & prompt_mask

for q_level in range(num_effective_quantizers):

    seq = seq.masked_fill(mask, self.mask_id)

    for mask_num_tokens, steps_until_x0 in tqdm(zip(mask_num_tokens_for_q_level[q_level].tolist(), reversed(range(self.steps))), total = self.steps):

        self_cond = self.to_self_cond(last_embed) if has_self_cond else None

        logits, embeds = self.net(
            seq,
            cond = cond_tokens,
            sum_embeds = self_cond,
            return_logits_and_embeddings = True,
            **kwargs
        )

        if has_self_cond:
            last_embed = embeds

        if exists(filter_thres):
            logits = top_k(logits, filter_thres)

        annealing_scale = steps_until_x0 / self.steps
        temperature = start_temperature * annealing_scale

        sampled_ids = gumbel_sample(logits, temperature = max(temperature, 1e-3))

        # don't sample for lower quantizer levels

        if q_level > 0:
            sample_mask = rearrange(mask, 'b (n q) -> b n q', q = num_effective_quantizers)
            sample_mask[:, :, :q_level] = False
            sample_mask = rearrange(sample_mask, 'b n q -> b (n q)', q = num_effective_quantizers)
        else:
            sample_mask = mask

        seq = torch.where(sample_mask, sampled_ids, seq)

        if exists(self.token_critic):
            scores = self.token_critic(seq)
            scores = rearrange(scores, 'b n 1 -> b n')
            scores = scores + noise_level_scale * gumbel_noise(scores) * annealing_scale
        else:
            scores = 1 - logits.softmax(dim = -1)
            scores = scores.gather(2, rearrange(sampled_ids, 'b n -> b n 1'))
            scores = rearrange(scores, 'b n 1 -> b n')

        mask_indices = torch.zeros(batch_size, 1, dtype = torch.long, device = device)
        mask = torch.zeros_like(scores, dtype = torch.bool)

        if mask_num_tokens != 0:
            if not self.can_mask_prev_unmasked:
                scores = scores.masked_fill(~mask, -torch.finfo(scores.dtype).max)

            mask_indices = scores.topk(mask_num_tokens, dim = -1).indices
            mask = mask.scatter(1, mask_indices, True)

        mask = rearrange(mask, 'b (n q) -> b n q', q = num_effective_quantizers)

        # mask all upper quantizer levels

        if q_level < num_effective_quantizers - 1:
            mask[:, :, q_level + 1:] = True

        # unmask all lower quantizer levels

        if q_level > 0:
            mask[:, :, :q_level] = False

        mask = rearrange(mask, 'b n q -> b (n q)', q = num_effective_quantizers)

        if exists(prompt_mask):
            mask = mask & prompt_mask

        seq = seq.masked_fill(mask, self.mask_id)

out = seq
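To make the per-level masking above easier to follow, here is a small self-contained sketch (not library code) of how the flattened (batch, seq * quantizers) mask is viewed per quantizer level, keeping upper levels masked and lower levels unmasked; the sizes and names are illustrative only.

import torch
from einops import rearrange

batch, n, q = 2, 4, 3          # batch, timesteps, effective quantizer levels
q_level = 1                    # level currently being generated

mask = torch.zeros(batch, n * q, dtype = torch.bool)

# view the flat (b, n * q) mask as (b, n, q) so each quantizer level is a column
mask = rearrange(mask, 'b (n q) -> b n q', q = q)

mask[:, :, q_level + 1:] = True   # upper levels stay masked (not yet generated)
mask[:, :, :q_level] = False      # lower levels stay unmasked (already generated)
# positions at the current level would be chosen by top-k over confidence scores

# flatten back before feeding the sequence to the network
mask = rearrange(mask, 'b n q -> b (n q)', q = q)
print(mask.shape)  # torch.Size([2, 12])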

@@ -921,9 +952,10 @@ def maybe_get_condition(self, token_ids = None, length = None):

# pytorch does not interpolate 1d, so hack by convert to 2d

cond_tokens = rearrange(cond_tokens, 'b n d -> b d n 1')
cond_tokens = F.interpolate(cond_tokens, (target_cond_length, 1), mode = 'bilinear')
cond_tokens = rearrange(cond_tokens, 'b d n 1 -> b n d')
if cond_length != target_cond_length:
    cond_tokens = rearrange(cond_tokens, 'b n d -> b d n 1')
    cond_tokens = F.interpolate(cond_tokens, (target_cond_length, 1), mode = 'bilinear')
    cond_tokens = rearrange(cond_tokens, 'b d n 1 -> b n d')

# whether to curtail or pad to length

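The new cond_length != target_cond_length guard above simply skips resampling when the lengths already match. Below is a minimal, self-contained reproduction of the guarded interpolate-via-2D pattern; the tensor names and sizes are illustrative only.

import torch
import torch.nn.functional as F
from einops import rearrange

cond_tokens = torch.randn(2, 100, 512)   # (batch, length, dim)
target_cond_length = 250

cond_length = cond_tokens.shape[-2]

if cond_length != target_cond_length:
    # treat the (b, n, d) sequence as a (b, d, n, 1) "image" so bilinear
    # 2d interpolation can resample along the sequence axis n
    cond_tokens = rearrange(cond_tokens, 'b n d -> b d n 1')
    cond_tokens = F.interpolate(cond_tokens, (target_cond_length, 1), mode = 'bilinear')
    cond_tokens = rearrange(cond_tokens, 'b d n 1 -> b n d')

print(cond_tokens.shape)  # torch.Size([2, 250, 512])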
