
Commit

fix key padding mask for non-flash attention, readying for variable length training
lucidrains committed Aug 1, 2023
1 parent e336755 commit e7e7dc4
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'soundstorm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.21',
+  version = '0.0.22',
   license='MIT',
   description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
   author = 'Phil Wang',
7 changes: 7 additions & 0 deletions soundstorm_pytorch/attend.py
@@ -150,6 +150,13 @@ def forward(self, q, k, v, mask = None, attn_bias = None):
         causal_mask = self.get_mask(q_len, k_len, device)
         sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

+        # key padding mask
+
+        if exists(mask):
+            if mask.ndim != 4:
+                mask = rearrange(mask, 'b j -> b 1 1 j')
+            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
+
         # attention

         attn = sim.softmax(dim=-1)
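The fix applies the key padding mask in the non-flash attention path: a boolean mask of shape (batch, key_len) is broadcast to (batch, 1, 1, key_len) with rearrange, and padded key positions are filled with the most negative finite value before the softmax, so they receive no attention weight. That is what makes padded, variable length batches safe in this path. Below is a minimal standalone sketch of that masking step, with illustrative tensor shapes rather than the module's actual internals:

import torch
from einops import rearrange

# toy attention scores: (batch, heads, query_len, key_len)
b, h, q_len, k_len = 2, 4, 6, 6
sim = torch.randn(b, h, q_len, k_len)

# key padding mask: True = real token, False = padding
mask = torch.ones(b, k_len, dtype = torch.bool)
mask[:, -2:] = False  # pretend the last two keys are padding

# broadcast (b, j) -> (b, 1, 1, j) so one mask covers every head and query position
if mask.ndim != 4:
    mask = rearrange(mask, 'b j -> b 1 1 j')

# padded keys get the largest negative representable value, so softmax gives them ~0 weight
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

attn = sim.softmax(dim = -1)
print(attn[..., -2:].max())  # ~0: no attention lands on the padded keys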
