Skip to content

Commit

Permalink
Return predicted token for RoBERTa filling mask
Browse files Browse the repository at this point in the history
Summary:
Added the `predicted_token` to each `topk` filled output item

Updated RoBERTa filling mask example in README.md

Reviewed By: myleott

Differential Revision: D17188810

fbshipit-source-id: 5fdc57ff2c13239dabf13a8dad43ae9a55e8931c
  • Loading branch information
raedle authored and facebook-github-bot committed Sep 5, 2019
1 parent 1566cfb commit 3e3fe72
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/roberta/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,13 @@ RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the
[Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/):
```python
roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3)
# [('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)]
# [('The first Star wars movie came out in 1977', 0.9504708051681519, ' 1977'), ('The first Star wars movie came out in 1978', 0.009986862540245056, ' 1978'), ('The first Star wars movie came out in 1979', 0.009574787691235542, ' 1979')]

roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3)
# [('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]
# [('Vikram samvat calender is official in India', 0.21878819167613983, ' India'), ('Vikram samvat calender is official in Delhi', 0.08547237515449524, ' Delhi'), ('Vikram samvat calender is official in Gujarat', 0.07556215673685074, ' Gujarat')]

roberta.fill_mask('<mask> is the common currency of the European Union', topk=3)
# [('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)]
# [('Euro is the common currency of the European Union', 0.9456493854522705, 'Euro'), ('euro is the common currency of the European Union', 0.025748178362846375, 'euro'), ('€ is the common currency of the European Union', 0.011183084920048714, '€')]
```

#### Pronoun disambiguation (Winograd Schema Challenge):
Expand Down
2 changes: 2 additions & 0 deletions fairseq/models/roberta/hub_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,13 @@ def fill_mask(self, masked_input: str, topk: int = 5):
' {0}'.format(masked_token), predicted_token
),
values[index].item(),
predicted_token,
))
else:
topk_filled_outputs.append((
masked_input.replace(masked_token, predicted_token),
values[index].item(),
predicted_token,
))
return topk_filled_outputs

Expand Down

0 comments on commit 3e3fe72

Please sign in to comment.