Added fill_mask API and some examples in README (facebookresearch#807)
Summary:
1) This currently works only for a single `<mask>` token; for multi-mask inputs we may have to look more into the order of factorization.
2) This currently supports only masks that correspond to a single BPE token.
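For illustration, a hypothetical REPL session showing the single-mask constraint (assumes a loaded `roberta` hub interface):

```python
>>> roberta.fill_mask('The <mask> sat on the <mask>', topk=3)
Traceback (most recent call last):
  ...
AssertionError: Please add one <mask> token for the input, eg: 'He is a <mask> guy'
```
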
Pull Request resolved: fairinternal/fairseq-py#807

Differential Revision: D16674509

fbshipit-source-id: 0a020030ee5df6a5115e5f85d5a9ef52b1ad9e1c
Naman Goyal authored and facebook-github-bot committed Aug 7, 2019
1 parent 1e55bbd commit a9eda73
Showing 2 changed files with 58 additions and 0 deletions.
13 changes: 13 additions & 0 deletions examples/roberta/README.md
@@ -134,6 +134,19 @@ roberta.cuda()
roberta.predict('new_task', tokens) # tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```

##### Filling masks:
Some examples from the [Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/).
```python
>>> roberta.fill_mask("The first Star wars movie came out in <mask>", topk=3)
[('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)]

>>> roberta.fill_mask("Vikram samvat calender is official in <mask>", topk=3)
[('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]

>>> roberta.fill_mask("<mask> is the common currency of the European Union", topk=3)
[('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)]
```
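
The examples above assume `roberta` has already been loaded as shown earlier in this README; a minimal sketch, assuming the `roberta.large` weights download cleanly via torch.hub:

```python
import torch

# Load the pretrained model through torch.hub and disable dropout
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()
```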

##### Evaluating the `roberta.large.mnli` model

Example Python code snippet to evaluate accuracy on the MNLI dev_matched set.
45 changes: 45 additions & 0 deletions fairseq/models/roberta/hub_interface.py
@@ -133,3 +133,48 @@ def extract_features_aligned_to_words(self, sentence: str, return_all_hiddens: b
        assert len(doc) == aligned_feats.size(0)
        doc.user_token_hooks['vector'] = lambda token: aligned_feats[token.i]
        return doc

    def fill_mask(self, masked_input: str, topk: int = 5):
        masked_token = '<mask>'
        assert masked_token in masked_input and masked_input.count(masked_token) == 1, \
            "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(masked_token)

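        # BPE-encode the text on either side of the mask, keeping '<mask>' itself
        # as a literal token so the dictionary maps it to mask_idx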
        text_spans = masked_input.split(masked_token)
        text_spans_bpe = (' {0} '.format(masked_token)).join(
            [self.bpe.encode(text_span.rstrip()) for text_span in text_spans]
        ).strip()
        tokens = self.task.source_dictionary.encode_line(
            '<s> ' + text_spans_bpe,
            append_eos=True,
        )

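        # Locate the mask position, then add a batch dimension for the model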
        masked_index = (tokens == self.task.mask_idx).nonzero()
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)

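        # Forward pass; with features_only=False the LM head returns vocabulary
        # logits, from which we take the top-k tokens at the mask position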
        features, extra = self.model(
            tokens.long().to(device=self.device),
            features_only=False,
            return_all_hiddens=False,
        )
        logits = features[0, masked_index, :].squeeze()
        prob = logits.softmax(dim=0)
        values, index = prob.topk(k=topk, dim=0)
        topk_predicted_token_bpe = self.task.source_dictionary.string(index)

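        # Splice each decoded prediction back into the original input; BPE decode
        # restores any leading space, so ' <mask>' is replaced as a whole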
        topk_filled_outputs = []
        for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')):
            predicted_token = self.bpe.decode(predicted_token_bpe)
            if " {0}".format(masked_token) in masked_input:
                topk_filled_outputs.append((
                    masked_input.replace(
                        ' {0}'.format(masked_token), predicted_token
                    ),
                    values[index].item(),
                ))
            else:
                topk_filled_outputs.append((
                    masked_input.replace(masked_token, predicted_token),
                    values[index].item(),
                ))
        return topk_filled_outputs
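
As a quick usage sketch (hypothetical prompt; assumes a loaded `roberta` hub interface as in the README), the method returns `(filled_sentence, probability)` tuples:

```python
# Print the top-k completions with their probabilities
for sentence, prob in roberta.fill_mask('Paris is the capital of <mask>', topk=3):
    print('{:.4f}\t{}'.format(prob, sentence))
```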
