From 6b3fc09135713cece689c9ec8a6ca7844c10579e Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 12 Sep 2024 11:51:17 -0700 Subject: [PATCH] address https://github.com/lucidrains/soundstorm-pytorch/issues/36 again --- soundstorm_pytorch/soundstorm.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py index 733286c..1015e8d 100644 --- a/soundstorm_pytorch/soundstorm.py +++ b/soundstorm_pytorch/soundstorm.py @@ -86,14 +86,21 @@ def coin_flip(): def get_mask_subset_prob( mask: Tensor, prob: Union[float, Tensor], - min_mask: int = 0 + min_mask: int = 0, + min_keep_mask: int = 0 ): batch, seq, device = *mask.shape, mask.device if isinstance(prob, Tensor): prob = rearrange(prob, 'b -> b 1') - num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask) + total = mask.sum(dim = -1, keepdim = True) + + max_mask = (total - min_keep_mask).clamp(min = 0) + + num_to_mask = (total * prob).long().clamp(min = min_mask) + num_to_mask = torch.minimum(num_to_mask, max_mask) + logits = torch.rand((batch, seq), device = device) logits = logits.masked_fill(~mask, -1) @@ -1138,7 +1145,7 @@ def forward( replace_mask_id_mask &= ~no_replace_prob_mask if self.random_token_prob > 0. and coin_flip(): - random_token_prob_mask = get_mask_subset_prob(replace_mask_id_mask, self.random_token_prob * frac_seq_left) + random_token_prob_mask = get_mask_subset_prob(replace_mask_id_mask, self.random_token_prob * frac_seq_left, min_keep_mask = 1) random_tokens = torch.randint(0, self.num_tokens, (b, n - t), device = device) x[:, t:, q] = torch.where(random_token_prob_mask, random_tokens, x[:, t:, q])