Skip to content

Commit

Permalink
address #40
Browse files (browse the repository at this point in the history)
  • Loading branch information
lucidrains committed Sep 14, 2024
1 parent 1616cdf commit a559705
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
version = '0.4.9',
version = '0.4.10',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
Expand Down
13 changes: 9 additions & 4 deletions soundstorm_pytorch/soundstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def __init__(
no_replace_prob = 0.15, # which percentage of the tokens masked will stay the same, done in original MLM paper
random_token_prob = 0.1, # which percentage of tokens to be replaced with random token, done in original MLM paper
schedule = 'linear',
can_mask_prev_unmasked = False, # when unmasking, whether it can remask previously unmasked
can_mask_prev_unmasked = True, # when unmasking, whether it can remask previously unmasked
self_token_critic = False, # https://aclanthology.org/2021.naacl-main.409/
critic_loss_weight = 1.,
num_semantic_token_ids = None,
Expand Down Expand Up @@ -899,6 +899,8 @@ def generate(

all_mask_num_tokens = (rand_mask_probs * seq_len_from_mask).long()

prev_mask = None

# self conditioning

has_self_cond = self.self_cond
Expand Down Expand Up @@ -962,15 +964,18 @@ def generate(

scores = scores.masked_fill(~seq_mask_with_quantizer, mask_value)

if not self.can_mask_prev_unmasked:
scores = scores.masked_fill(~mask, mask_value)
if not self.can_mask_prev_unmasked and exists(prev_mask):
scores = scores.masked_fill(~prev_mask, mask_value)

scores_sorted = scores.argsort(dim = -1, descending = True).argsort(dim = -1)

mask_num_tokens = rearrange(mask_num_tokens, 'b -> b 1')

mask = scores_sorted < mask_num_tokens

if not self.can_mask_prev_unmasked:
prev_mask = mask.clone()

mask = rearrange(mask, 'b (n q) -> b n q', q = num_effective_quantizers)

# mask all upper quantizer levels
Expand All @@ -982,7 +987,7 @@ def generate(

if q_level > 0:
mask[:, :, :q_level] = False

mask = rearrange(mask, 'b n q -> b (n q)', q = num_effective_quantizers)

if exists(prompt_mask):
Expand Down

0 comments on commit a559705

Please sign in to comment.