diff --git a/examples/roberta/README.md b/examples/roberta/README.md index 9006e4f193..1b8d637ccb 100644 --- a/examples/roberta/README.md +++ b/examples/roberta/README.md @@ -167,13 +167,13 @@ RoBERTa can be used to fill `` tokens in the input. Some examples from the [Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/): ```python roberta.fill_mask('The first Star wars movie came out in ', topk=3) -# [('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)] +# [('The first Star wars movie came out in 1977', 0.9504708051681519, ' 1977'), ('The first Star wars movie came out in 1978', 0.009986862540245056, ' 1978'), ('The first Star wars movie came out in 1979', 0.009574787691235542, ' 1979')] roberta.fill_mask('Vikram samvat calender is official in ', topk=3) -# [('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)] +# [('Vikram samvat calender is official in India', 0.21878819167613983, ' India'), ('Vikram samvat calender is official in Delhi', 0.08547237515449524, ' Delhi'), ('Vikram samvat calender is official in Gujarat', 0.07556215673685074, ' Gujarat')] roberta.fill_mask(' is the common currency of the European Union', topk=3) -# [('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)] +# [('Euro is the common currency of the European Union', 0.9456493854522705, 'Euro'), ('euro is the common currency of the European Union', 0.025748178362846375, 'euro'), ('€ is the common currency of the European Union', 0.011183084920048714, '€')] ``` #### Pronoun disambiguation (Winograd Schema Challenge): diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py index e40e4ab92a..216b6fd90f 100644 --- a/fairseq/models/roberta/hub_interface.py +++ b/fairseq/models/roberta/hub_interface.py @@ -174,11 +174,13 @@ def fill_mask(self, masked_input: str, topk: int = 5): ' {0}'.format(masked_token), predicted_token ), values[index].item(), + predicted_token, )) else: topk_filled_outputs.append(( masked_input.replace(masked_token, predicted_token), values[index].item(), + predicted_token, )) return topk_filled_outputs