forked from facebookresearch/fairseq
Fixing key padding mask during transformer generation
Summary: facebookresearch#1097 added key padding mask history in TransformerDecoderLayer, but in the edge case where only the current or only the previous key_padding_mask exists, the resulting key_padding_mask has the wrong size. This diff appends empty (unmasked) columns in that case so the key_padding_mask keeps a usable size.

Reviewed By: myleott

Differential Revision: D18224313

fbshipit-source-id: c9fb7266baf0a2d79a66704e00a5ea8bd2987ff6
1 parent a0f7599 · commit 68dd3e1

Showing 3 changed files with 99 additions and 9 deletions.
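As a rough illustration of the fix described in the Summary, here is a minimal sketch of the mask-appending behaviour that the new test below exercises. The helper is illustrative only, not the fairseq implementation (the real method is MultiheadAttention._append_prev_key_padding_mask and also takes a static_kv argument, omitted here); masks are assumed to be boolean tensors of shape (batch_size, mask_len) where True marks padded keys.

import torch


def append_prev_key_padding_mask(key_padding_mask, prev_key_padding_mask,
                                 batch_size, src_len):
    # Illustrative sketch, not the fairseq implementation.
    if key_padding_mask is not None and prev_key_padding_mask is not None:
        # Both masks exist: cached (previous) keys come first, then current keys.
        return torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
    if prev_key_padding_mask is not None:
        # Only the previous mask exists: append False ("not padded") columns
        # for the current keys so the mask still spans all src_len keys.
        filler = prev_key_padding_mask.new_zeros(
            batch_size, src_len - prev_key_padding_mask.size(1))
        return torch.cat([prev_key_padding_mask, filler], dim=1)
    if key_padding_mask is not None:
        # Only the current mask exists: prepend False columns for the cached
        # keys that carry no mask.
        filler = key_padding_mask.new_zeros(
            batch_size, src_len - key_padding_mask.size(1))
        return torch.cat([filler, key_padding_mask], dim=1)
    # Neither mask exists: nothing to append.
    return None

With batch_size=1 and src_len=4 this reproduces the expected masks in the test cases below: for example, a current mask [[1]] with no previous mask becomes [[0, 0, 0, 1]].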
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import unittest
from fairseq.modules.multihead_attention import MultiheadAttention


class TestMultiheadAttention(unittest.TestCase):
    def test_append_prev_key_padding_mask(self):
        bsz = 1
        src_len = 4

        cases = [
            # no padding mask
            (None, None, None),
            # current padding mask only
            (
                torch.tensor([[1]]).bool(),
                None,
                torch.tensor([[0, 0, 0, 1]]).bool(),
            ),
            # previous padding mask only
            (
                None,
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 0]]).bool(),
            ),
            # both padding masks
            (
                torch.tensor([[1]]).bool(),
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
        ]
        for c in cases:
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                c[0],
                c[1],
                batch_size=bsz,
                src_len=src_len,
                static_kv=False,
            )

            if key_padding_mask is not None:
                self.assertTrue(
                    torch.all(torch.eq(key_padding_mask, c[2])),
                    f'Unexpected resultant key padding mask: {key_padding_mask}'
                    f' given current: {c[0]} and previous: {c[1]}',
                )
                self.assertEqual(key_padding_mask.size(0), bsz)
                self.assertEqual(key_padding_mask.size(1), src_len)
            else:
                self.assertIsNone(c[2])


if __name__ == '__main__':
    unittest.main()
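Thanks to the unittest.main() guard, the new test can be run on its own with the standard unittest runner (or through fairseq's regular test suite); all four cases above are expected to pass once this commit's change to MultiheadAttention._append_prev_key_padding_mask is in place.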