Skip to content

Commit

Permalink
address #36 again
Browse files (browse the repository at this point in the history)
  • Loading branch information
lucidrains committed Sep 12, 2024
1 parent 2b9263e commit 6b3fc09
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions soundstorm_pytorch/soundstorm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -86,14 +86,21 @@ def coin_flip():
def get_mask_subset_prob(
mask: Tensor,
prob: Union[float, Tensor],
min_mask: int = 0
min_mask: int = 0,
min_keep_mask: int = 0
):
batch, seq, device = *mask.shape, mask.device

if isinstance(prob, Tensor):
prob = rearrange(prob, 'b -> b 1')

num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask)
total = mask.sum(dim = -1, keepdim = True)

max_mask = (total - min_keep_mask).clamp(min = 0)

num_to_mask = (total * prob).long().clamp(min = min_mask)
num_to_mask = torch.minimum(num_to_mask, max_mask)

logits = torch.rand((batch, seq), device = device)
logits = logits.masked_fill(~mask, -1)

Expand Down Expand Up @@ -1138,7 +1145,7 @@ def forward(
replace_mask_id_mask &= ~no_replace_prob_mask

if self.random_token_prob > 0. and coin_flip():
random_token_prob_mask = get_mask_subset_prob(replace_mask_id_mask, self.random_token_prob * frac_seq_left)
random_token_prob_mask = get_mask_subset_prob(replace_mask_id_mask, self.random_token_prob * frac_seq_left, min_keep_mask = 1)
random_tokens = torch.randint(0, self.num_tokens, (b, n - t), device = device)

x[:, t:, q] = torch.where(random_token_prob_mask, random_tokens, x[:, t:, q])
Expand Down

0 comments on commit 6b3fc09

Please sign in to comment.