Fixing key padding mask during transformer generation
Summary:
facebookresearch#1097 added key padding mask history to TransformerDecoderLayer, but in the edge case where only the current or only the previous key_padding_mask exists, the resulting key_padding_mask ends up the wrong size.

This diff adds empty (all-False) filler columns in that case, so the combined key_padding_mask always has the expected (batch, source length) shape.
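
To make the edge case concrete, here is a minimal sketch of the "current padding mask only" situation, using the _append_prev_key_padding_mask helper introduced below (it mirrors the new unit test; the shapes and values are illustrative):

import torch
from fairseq.modules.multihead_attention import MultiheadAttention

# Incremental decoding: keys/values have accumulated over 4 steps (src_len=4),
# no padding mask was saved for earlier steps, but the current step is padding.
current_mask = torch.tensor([[1]]).bool()  # shape (bsz=1, 1): current step only

merged = MultiheadAttention._append_prev_key_padding_mask(
    key_padding_mask=current_mask,
    prev_key_padding_mask=None,  # no mask history saved yet
    batch_size=1,
    src_len=4,
    static_kv=False,
)
print(merged)  # tensor([[False, False, False, True]]) -- padded out to (bsz, src_len)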

Reviewed By: myleott

Differential Revision: D18224313

fbshipit-source-id: c9fb7266baf0a2d79a66704e00a5ea8bd2987ff6
Spencer Poff authored and facebook-github-bot committed Nov 5, 2019
1 parent a0f7599 commit 68dd3e1
Showing 3 changed files with 99 additions and 9 deletions.
6 changes: 3 additions & 3 deletions fairseq/models/transformer.py
@@ -595,9 +595,9 @@ def extract_features(
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)

-        self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
-        if not self_attn_padding_mask.any() and not self.cross_self_attention:
-            self_attn_padding_mask = None
+        self_attn_padding_mask = None
+        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
+            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

         # decoder layers
         attn = None
42 changes: 36 additions & 6 deletions fairseq/modules/multihead_attention.py
@@ -195,12 +195,14 @@ def forward(
                     v = prev_value
                 else:
                     v = torch.cat((prev_value, v), dim=1)
-            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
-                prev_key_padding_mask = saved_state['prev_key_padding_mask']
-                if static_kv:
-                    key_padding_mask = prev_key_padding_mask
-                else:
-                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+            key_padding_mask = self._append_prev_key_padding_mask(
+                key_padding_mask=key_padding_mask,
+                prev_key_padding_mask=saved_state.get('prev_key_padding_mask', None),
+                batch_size=bsz,
+                src_len=k.size(1),
+                static_kv=static_kv,
+            )

             saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
             saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
             saved_state['prev_key_padding_mask'] = key_padding_mask
@@ -275,6 +277,34 @@ def forward(

         return attn, attn_weights

+    @staticmethod
+    def _append_prev_key_padding_mask(
+        key_padding_mask,
+        prev_key_padding_mask,
+        batch_size,
+        src_len,
+        static_kv,
+    ):
+        # saved key padding masks have shape (bsz, seq_len)
+        if prev_key_padding_mask is not None and static_kv:
+            key_padding_mask = prev_key_padding_mask
+        elif prev_key_padding_mask is not None and key_padding_mask is not None:
+            key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+        # During incremental decoding, as the padding token enters and
+        # leaves the frame, there will be a time when prev or current
+        # is None
+        elif prev_key_padding_mask is not None:
+            filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1)).bool()
+            if prev_key_padding_mask.is_cuda:
+                filler = filler.cuda()
+            key_padding_mask = torch.cat((prev_key_padding_mask, filler), dim=1)
+        elif key_padding_mask is not None:
+            filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1)).bool()
+            if key_padding_mask.is_cuda:
+                filler = filler.cuda()
+            key_padding_mask = torch.cat((filler, key_padding_mask), dim=1)
+        return key_padding_mask
+
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
60 changes: 60 additions & 0 deletions tests/test_multihead_attention.py
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import unittest
from fairseq.modules.multihead_attention import MultiheadAttention


class TestMultiheadAttention(unittest.TestCase):
    def test_append_prev_key_padding_mask(self):
        bsz = 1
        src_len = 4

        cases = [
            # no padding mask
            (None, None, None),
            # current padding mask only
            (
                torch.tensor([[1]]).bool(),
                None,
                torch.tensor([[0, 0, 0, 1]]).bool(),
            ),
            # previous padding mask only
            (
                None,
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 0]]).bool(),
            ),
            # both padding masks
            (
                torch.tensor([[1]]).bool(),
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
        ]
        for c in cases:
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                c[0],
                c[1],
                batch_size=bsz,
                src_len=src_len,
                static_kv=False,
            )

            if key_padding_mask is not None:
                self.assertTrue(
                    torch.all(torch.eq(key_padding_mask, c[2])),
                    f'Unexpected resultant key padding mask: {key_padding_mask}'
                    f' given current: {c[0]} and previous: {c[1]}',
                )
                self.assertEqual(key_padding_mask.size(0), bsz)
                self.assertEqual(key_padding_mask.size(1), src_len)
            else:
                self.assertIsNone(c[2])


if __name__ == '__main__':
    unittest.main()
