Skip to content

Commit

Permalink
Update beam search code to support torch.bool change
Browse files · Browse the repository at this point in the history
Summary: Pull Request resolved: fairinternal/fairseq-py#797

Differential Revision: D16617067

Pulled By: myleott

fbshipit-source-id: 52e3aeb98d6e3b55ff9154b784028bf13eabfe38
  • Branch information:
Authored by Myle Ott; committed by facebook-github-bot on Aug 2, 2019
1 parent ccb5dea commit 5f34252
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion fairseq/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _sample_topp(self, lprobs):

# trim the words that are not in top-P by setting their probabilities
# to 0, so that they would not be sampled later.
trim_mask = truncated_mask.bitwise_not()
trim_mask = (~truncated_mask)
trimed_probs = truncated_probs.masked_fill_(trim_mask, 0)
return trimed_probs, truncated_indices

Expand Down
2 changes: 1 addition & 1 deletion fairseq/sequence_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def generate(
# For example, suppose we're sampling and have already finalized 2/5
# samples. Then the blacklist would mark 2 positions as being ignored,
# so that we only finalize the remaining 3 samples.
blacklist = src_tokens.new(bsz, beam_size).byte().fill_(0)
blacklist = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask

# list of completed sentences
finalized = [[] for i in range(bsz)]
Expand Down

0 comments on commit 5f34252

Please sign in to comment.