From d0e13ac78c7fa87af74e7c01365df5d854180b87 Mon Sep 17 00:00:00 2001 From: chenhaitao Date: Fri, 30 Jun 2023 11:24:17 +0800 Subject: [PATCH] update mask compute --- soundstorm_pytorch/soundstorm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py index 5b21715..f35de78 100644 --- a/soundstorm_pytorch/soundstorm.py +++ b/soundstorm_pytorch/soundstorm.py @@ -983,13 +983,18 @@ def forward( masked = torch.where(replace_mask_id_mask, self.mask_id, x[:, t:, q]) masked = rearrange(torch.cat((x[:, :t, q], masked), dim=1), 'b n -> b n 1') masked = torch.cat((x[:, :, :q], masked, x[:, :, q+1:]), dim=2) + masked[:, t:, q+1:] = self.mask_id masked = rearrange(masked, 'b n q -> b (n q)') prompt_mask = torch.full((b, t), False, device=device) lower_quantizers_mask = torch.full((b, n, q), False, device=device) upper_quantizers_mask = torch.full((b, n, (gq - q - 1)), True, device=device) - mask = rearrange(torch.cat((prompt_mask, mask), dim=1), 'b n -> b n 1') + # upper_quantizers_mask should also be False within the prompt + upper_quantizers_mask[:, :t, :] = False + mask = rearrange(torch.cat((prompt_mask, replace_mask_id_mask), dim=1), 'b n -> b n 1') mask = torch.cat((lower_quantizers_mask, mask, upper_quantizers_mask), dim=2) + # the mask above is the correct one, but when computing the loss, only level q is considered + mask[:, :, q+1:]=False mask = rearrange(mask, 'b n q -> b (n q)') # self conditioning