diff --git a/examples/roberta/README.md b/examples/roberta/README.md
index 21b04c845a..537c55f3fa 100644
--- a/examples/roberta/README.md
+++ b/examples/roberta/README.md
@@ -134,6 +134,19 @@ roberta.cuda()
 roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
 ```
 
+##### Filling mask:
+Some examples from the [Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/).
+```python
+>>> roberta.fill_mask("The first Star wars movie came out in <mask>", topk=3)
+[('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)]
+
+>>> roberta.fill_mask("Vikram samvat calender is official in <mask>", topk=3)
+[('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]
+
+>>> roberta.fill_mask("<mask> is the common currency of the European Union", topk=3)
+[('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)]
+```
+
 ##### Evaluating the `roberta.large.mnli` model
 
 Example python code snippet to evaluate accuracy on the MNLI dev_matched set.
diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py
index f7ba231e3b..22ce96e89f 100644
--- a/fairseq/models/roberta/hub_interface.py
+++ b/fairseq/models/roberta/hub_interface.py
@@ -133,3 +133,48 @@ def extract_features_aligned_to_words(self, sentence: str, return_all_hiddens: b
         assert len(doc) == aligned_feats.size(0)
         doc.user_token_hooks['vector'] = lambda token: aligned_feats[token.i]
         return doc
+
+    def fill_mask(self, masked_input: str, topk: int = 5):
+        masked_token = '<mask>'
+        assert masked_token in masked_input and masked_input.count(masked_token) == 1, \
+            "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(masked_token)
+
+        text_spans = masked_input.split(masked_token)
+        text_spans_bpe = (' {0} '.format(masked_token)).join(
+            [self.bpe.encode(text_span.rstrip()) for text_span in text_spans]
+        ).strip()
+        tokens = self.task.source_dictionary.encode_line(
+            '<s> ' + text_spans_bpe,
+            append_eos=True,
+        )
+
+        masked_index = (tokens == self.task.mask_idx).nonzero()
+        if tokens.dim() == 1:
+            tokens = tokens.unsqueeze(0)
+
+        features, extra = self.model(
+            tokens.long().to(device=self.device),
+            features_only=False,
+            return_all_hiddens=False,
+        )
+        logits = features[0, masked_index, :].squeeze()
+        prob = logits.softmax(dim=0)
+        values, index = prob.topk(k=topk, dim=0)
+        topk_predicted_token_bpe = self.task.source_dictionary.string(index)
+
+        topk_filled_outputs = []
+        for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')):
+            predicted_token = self.bpe.decode(predicted_token_bpe)
+            if " {0}".format(masked_token) in masked_input:
+                topk_filled_outputs.append((
+                    masked_input.replace(
+                        ' {0}'.format(masked_token), predicted_token
+                    ),
+                    values[index].item(),
+                ))
+            else:
+                topk_filled_outputs.append((
+                    masked_input.replace(masked_token, predicted_token),
+                    values[index].item(),
+                ))
+        return topk_filled_outputs
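
For completeness, here is a minimal sketch of how the new `fill_mask` API can be exercised end to end. It assumes the `torch.hub` loading path used elsewhere in the fairseq RoBERTa README; the example sentence and printed formatting are illustrations, not output captured from the model:

```python
import torch

# Load a pretrained RoBERTa model through torch.hub, as in the rest of the
# README (downloads the checkpoint on first use).
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout so predictions are deterministic

# fill_mask expects exactly one '<mask>' token in the input and returns a
# list of (filled_sentence, probability) tuples, highest probability first.
for sentence, prob in roberta.fill_mask('The capital of France is <mask>.', topk=3):
    print('{:.4f}\t{}'.format(prob, sentence))
```

Note that `fill_mask` BPE-encodes the text spans on either side of `<mask>` separately, so the mask symbol itself is never passed through the BPE encoder and is looked up directly in the source dictionary as `self.task.mask_idx`.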