diff --git a/setup.py b/setup.py index 9097c4b..92a4bb7 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'soundstorm-pytorch', packages = find_packages(exclude=[]), - version = '0.4.9', + version = '0.4.10', license='MIT', description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', author = 'Phil Wang', diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py index 1015e8d..b7aaed3 100644 --- a/soundstorm_pytorch/soundstorm.py +++ b/soundstorm_pytorch/soundstorm.py @@ -649,7 +649,7 @@ def __init__( no_replace_prob = 0.15, # which percentage of the tokens masked will stay the same, done in original MLM paper random_token_prob = 0.1, # which percentage of tokens to be replaced with random token, done in original MLM paper schedule = 'linear', - can_mask_prev_unmasked = False, # when unmasking, whether it can remask previously unmasked + can_mask_prev_unmasked = True, # when unmasking, whether it can remask previously unmasked self_token_critic = False, # https://aclanthology.org/2021.naacl-main.409/ critic_loss_weight = 1., num_semantic_token_ids = None, @@ -899,6 +899,8 @@ def generate( all_mask_num_tokens = (rand_mask_probs * seq_len_from_mask).long() + prev_mask = None + # self conditioning has_self_cond = self.self_cond @@ -962,8 +964,8 @@ def generate( scores = scores.masked_fill(~seq_mask_with_quantizer, mask_value) - if not self.can_mask_prev_unmasked: - scores = scores.masked_fill(~mask, mask_value) + if not self.can_mask_prev_unmasked and exists(prev_mask): + scores = scores.masked_fill(~prev_mask, mask_value) scores_sorted = scores.argsort(dim = -1, descending = True).argsort(dim = -1) @@ -971,6 +973,9 @@ def generate( mask = scores_sorted < mask_num_tokens + if not self.can_mask_prev_unmasked: + prev_mask = mask.clone() + mask = rearrange(mask, 'b (n q) -> b n q', q = num_effective_quantizers) # mask all upper quantizer levels @@ -982,7 +987,7 @@ def generate( if q_level > 0: mask[:, :, :q_level] = False - + mask = rearrange(mask, 'b n q -> b (n q)', q = num_effective_quantizers) if exists(prompt_mask):