From 1616cdfa50ef80ae45f5984a6d70af8b55efcc24 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 13 Sep 2024 09:51:08 -0700 Subject: [PATCH] address https://github.com/lucidrains/soundstorm-pytorch/issues/37 --- setup.py | 2 +- soundstorm_pytorch/attend.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index cc12124..9097c4b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'soundstorm-pytorch', packages = find_packages(exclude=[]), - version = '0.4.8', + version = '0.4.9', license='MIT', description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', author = 'Phil Wang', diff --git a/soundstorm_pytorch/attend.py b/soundstorm_pytorch/attend.py index 9d98ea9..57e114b 100644 --- a/soundstorm_pytorch/attend.py +++ b/soundstorm_pytorch/attend.py @@ -95,8 +95,10 @@ def flash_attn(self, q, k, v, mask = None, attn_bias = None): if exists(attn_bias): mask_value = -torch.finfo(q.dtype).max // 2 - causal_mask = self.get_mask(q_len, k_len, device) - attn_bias = attn_bias.masked_fill(causal_mask, mask_value) + + if causal: + causal_mask = self.get_mask(q_len, k_len, device) + attn_bias = attn_bias.masked_fill(causal_mask, mask_value) if exists(mask): attn_bias = attn_bias.masked_fill(~mask, mask_value)