diff --git a/setup.py b/setup.py index cc12124..9097c4b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'soundstorm-pytorch', packages = find_packages(exclude=[]), - version = '0.4.8', + version = '0.4.9', license='MIT', description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', author = 'Phil Wang', diff --git a/soundstorm_pytorch/attend.py b/soundstorm_pytorch/attend.py index 9d98ea9..57e114b 100644 --- a/soundstorm_pytorch/attend.py +++ b/soundstorm_pytorch/attend.py @@ -95,8 +95,10 @@ def flash_attn(self, q, k, v, mask = None, attn_bias = None): if exists(attn_bias): mask_value = -torch.finfo(q.dtype).max // 2 - causal_mask = self.get_mask(q_len, k_len, device) - attn_bias = attn_bias.masked_fill(causal_mask, mask_value) + + if causal: + causal_mask = self.get_mask(q_len, k_len, device) + attn_bias = attn_bias.masked_fill(causal_mask, mask_value) if exists(mask): attn_bias = attn_bias.masked_fill(~mask, mask_value)