Skip to content

Commit

Permalink
Merge pull request #16 from chenht2010/fork
Browse files Browse the repository at this point in the history
update mask compute
  • Loading branch information
lucidrains authored Jun 30, 2023
2 parents dc9e7c5 + d0e13ac commit 32d31e8
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion soundstorm_pytorch/soundstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,13 +983,18 @@ def forward(
masked = torch.where(replace_mask_id_mask, self.mask_id, x[:, t:, q])
masked = rearrange(torch.cat((x[:, :t, q], masked), dim=1), 'b n -> b n 1')
masked = torch.cat((x[:, :, :q], masked, x[:, :, q+1:]), dim=2)
masked[:, t:, q+1:] = self.mask_id
masked = rearrange(masked, 'b n q -> b (n q)')

prompt_mask = torch.full((b, t), False, device=device)
lower_quantizers_mask = torch.full((b, n, q), False, device=device)
upper_quantizers_mask = torch.full((b, n, (gq - q - 1)), True, device=device)
mask = rearrange(torch.cat((prompt_mask, mask), dim=1), 'b n -> b n 1')
# upper_quantizers_mask in prompt also should be False
upper_quantizers_mask[:, :t, :] = False
mask = rearrange(torch.cat((prompt_mask, replace_mask_id_mask), dim=1), 'b n -> b n 1')
mask = torch.cat((lower_quantizers_mask, mask, upper_quantizers_mask), dim=2)
# above is the right mask, but when compute loss, only consider level q
mask[:, :, q+1:]=False
mask = rearrange(mask, 'b n q -> b (n q)')

# self conditioning
Expand Down

0 comments on commit 32d31e8

Please sign in to comment.