diff --git a/setup.py b/setup.py index a8e075b..3b29bdc 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'soundstorm-pytorch', packages = find_packages(exclude=[]), - version = '0.0.21', + version = '0.0.22', license='MIT', description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', author = 'Phil Wang', diff --git a/soundstorm_pytorch/attend.py b/soundstorm_pytorch/attend.py index 27cb93e..9d98ea9 100644 --- a/soundstorm_pytorch/attend.py +++ b/soundstorm_pytorch/attend.py @@ -150,6 +150,13 @@ def forward(self, q, k, v, mask = None, attn_bias = None): causal_mask = self.get_mask(q_len, k_len, device) sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + # key padding mask + + if exists(mask): + if mask.ndim != 4: + mask = rearrange(mask, 'b j -> b 1 1 j') + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + # attention attn = sim.softmax(dim=-1)