diff --git a/examples/roberta/README.md b/examples/roberta/README.md index 1b8d637ccb..68dc6701ea 100644 --- a/examples/roberta/README.md +++ b/examples/roberta/README.md @@ -146,11 +146,26 @@ logprobs = roberta.predict('new_task', tokens) # tensor([[-1.1050, -1.0672, -1. ##### Batched prediction: ```python +import torch from fairseq.data.data_utils import collate_tokens -sentences = ['Hello world.', 'Another unrelated sentence.'] -batch = collate_tokens([roberta.encode(sent) for sent in sentences], pad_idx=1) -logprobs = roberta.predict('new_task', batch) -assert logprobs.size() == torch.Size([2, 3]) + +roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli') +roberta.eval() + +batch_of_pairs = [ + ['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'], + ['Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.'], + ['potatoes are awesome.', 'I like to run.'], + ['Mars is very far from earth.', 'Mars is very close.'], +] + +batch = collate_tokens( + [roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1 +) + +logprobs = roberta.predict('mnli', batch) +print(logprobs.argmax(dim=1)) +# tensor([0, 2, 1, 0]) ``` ##### Using the GPU: