Commit 9991c8d

fixed masked flash attention (#2589)
1 parent 63f07fc

1 file changed (+10, -20 lines)

onmt/modules/multi_headed_attn.py

```diff
@@ -439,7 +439,6 @@ def forward(
         """
         # 1) Project key, value, and query.
         # as a reminder at training layer_cache[0] remains False
-        key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
         if self.layer_cache[0]:
             # Retrieve keys and values from the KV cache (decoding mode only).
             if self.attn_type == "self":
```
```diff
@@ -484,6 +483,16 @@ def forward(
                         key = key[:, :, 1:, :]
                         value = value[:, :, 1:, :]
 
+                if step == 0:
+                    key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
+                    if key_pad_mask is not None:
+                        x = key_pad_mask.expand(
+                            -1, self.head_count // self.parallel_gpu, -1
+                        )
+                        x = x.unsqueeze(3)
+                        x = x.expand(-1, -1, -1, value.size(3))
+                        value = value.masked_fill(x, 0)
+
                 self.layer_cache[1]["keys"] = key
                 self.layer_cache[1]["values"] = value
```

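The added `step == 0` block zeroes the value rows of padded key positions once, when the KV cache is first filled. A minimal standalone sketch of the same shape gymnastics, with made-up sizes (batch 2, 4 heads, 5 keys, head dim 8) standing in for the module's real `head_count // parallel_gpu` and cache shapes:

```python
import torch

batch, heads, kv_len, dim_per_head = 2, 4, 5, 8  # hypothetical sizes

# True marks padded key positions, shaped (batch, 1, kv_len) like the
# mask stored under layer_cache[1]["key_pad_mask"].
key_pad_mask = torch.zeros(batch, 1, kv_len, dtype=torch.bool)
key_pad_mask[0, 0, 3:] = True  # pretend example 0 has two padded steps

value = torch.randn(batch, heads, kv_len, dim_per_head)

# Same expansion chain as in the diff:
x = key_pad_mask.expand(-1, heads, -1)       # (batch, heads, kv_len)
x = x.unsqueeze(3)                           # (batch, heads, kv_len, 1)
x = x.expand(-1, -1, -1, value.size(3))      # (batch, heads, kv_len, dim)
value = value.masked_fill(x, 0)

assert value[0, :, 3:, :].abs().sum() == 0   # padded value rows are all zeros
```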
```diff
@@ -565,19 +574,6 @@ def forward(
             self.layer_cache[1]["keys"] = key
             self.layer_cache[1]["values"] = value
 
-            if key_pad_mask is not None:
-                # Increase the cached key pad mask by concatenation.
-                # For decoding only.
-                if step > 0:
-                    y = torch.zeros(
-                        (key_pad_mask.size(0), key_pad_mask.size(1), 1),
-                        dtype=torch.bool,
-                        device=key_pad_mask.device,
-                    )
-                    self.layer_cache[1]["key_pad_mask"] = torch.cat(
-                        (key_pad_mask, y), 2
-                    )
-                    key_pad_mask = self.layer_cache[1]["key_pad_mask"]
         else:
             # Retrieve keys and values from linear layers (training mode).
             key = self.maybe_ckpt(self.linear_keys, key)
```
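The block removed here was the other half of the old scheme: on every decoding step after the first, it appended one `False` (non-padded) column to the cached pad mask so the mask kept pace with the growing key length. With the values zeroed once at `step == 0`, that per-step bookkeeping is no longer needed. In isolation, the removed pattern did this:

```python
import torch

# One decoding step of the old mask growth: append a non-padded column.
key_pad_mask = torch.tensor([[[False, False, True]]])  # (batch=1, 1, kv_len=3)

y = torch.zeros(
    (key_pad_mask.size(0), key_pad_mask.size(1), 1),
    dtype=torch.bool,
    device=key_pad_mask.device,
)
key_pad_mask = torch.cat((key_pad_mask, y), 2)  # (1, 1, 4): new token unpadded
```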
```diff
@@ -706,8 +702,6 @@ def forward(
             scores = self.alibi(scores)
 
         scores = scores.float()
-        if key_pad_mask is not None and mask is None:
-            mask = key_pad_mask.unsqueeze(1)
 
         if mask is not None:
             # not 100% necessary but expand to nb of heads
```
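With the `key_pad_mask`-to-`mask` fallback removed, only masks passed in by the caller reach the score masking below. A rough sketch of what that surviving path does, under assumed shapes and a placeholder fill value (not the module's exact code):

```python
import torch

batch, heads, q_len, k_len = 2, 4, 1, 5       # hypothetical sizes
scores = torch.randn(batch, heads, q_len, k_len).float()
mask = torch.zeros(batch, 1, q_len, k_len, dtype=torch.bool)
mask[0, :, :, 3:] = True                      # masked key positions

mask = mask.expand(-1, heads, -1, -1)         # broadcast over heads
scores = scores.masked_fill(mask, float("-inf"))
attn = torch.softmax(scores, dim=-1)          # ~0 weight on masked keys
```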
```diff
@@ -727,10 +721,6 @@ def forward(
             attn_output.add_(relative_matmul(drop_attn, relations_values, False))
 
         context = unshape(attn_output)
-        if key_pad_mask is not None:
-            if key_pad_mask.size(0) > 1 and context.size(1) > 1:
-                x = key_pad_mask.squeeze(1).unsqueeze(2).expand(-1, -1, context.size(2))
-                context = context.masked_fill(x, 0)
 
         if self.layer_cache[0]:
             attn_output = self.final_linear(context)
```
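The removed block zeroed rows of `context` at padded positions after attention. One property the new scheme relies on, checkable at the tensor level: any softmax weight that lands on a zeroed value row contributes nothing to the attention output. A quick check with plain softmax attention and hypothetical sizes:

```python
import torch

q = torch.randn(1, 4, 1, 8)         # (batch, heads, q_len, dim)
k = torch.randn(1, 4, 5, 8)         # 5 cached keys; last two are padding
v = torch.randn(1, 4, 5, 8)
v[:, :, 3:, :] = 0                  # zeroed as in the step == 0 branch above

attn = torch.softmax(q @ k.transpose(-2, -1) / 8 ** 0.5, dim=-1)
out = attn @ v                      # padded keys add only zeros to `out`
```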
