Fixing example of batched predictions for Roberta (facebookresearch#1195)

Summary:
For batched predictions with Roberta, the README gave an example that was pretty unclear. After a thorough discussion with ngoyal2707 in issue facebookresearch#1167, he gave a clear example of how batched predictions are supposed to be done. Since I spent a lot of time on this inconsistency, I thought it might benefit the community if his solution were in the official README 😄!

For further details, see issue facebookresearch#1167.
Pull Request resolved: facebookresearch#1195

Differential Revision: D17639354

Pulled By: myleott

fbshipit-source-id: 3eb60c5804a6481f533b19073da7880dfd0d522d
justachetan authored and facebook-github-bot committed Sep 27, 2019
1 parent 86857a5 commit 1cb267e
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions examples/roberta/README.md
@@ -146,11 +146,26 @@ logprobs = roberta.predict('new_task', tokens) # tensor([[-1.1050, -1.0672, -1.

##### Batched prediction:
```diff
+import torch
 from fairseq.data.data_utils import collate_tokens
-sentences = ['Hello world.', 'Another unrelated sentence.']
-batch = collate_tokens([roberta.encode(sent) for sent in sentences], pad_idx=1)
-logprobs = roberta.predict('new_task', batch)
-assert logprobs.size() == torch.Size([2, 3])
+
+roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
+roberta.eval()
+
+batch_of_pairs = [
+    ['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'],
+    ['Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.'],
+    ['potatoes are awesome.', 'I like to run.'],
+    ['Mars is very far from earth.', 'Mars is very close.'],
+]
+
+batch = collate_tokens(
+    [roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
+)
+
+logprobs = roberta.predict('mnli', batch)
+print(logprobs.argmax(dim=1))
+# tensor([0, 2, 1, 0])
```
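For reference, a standalone runnable sketch of the same batched-prediction flow. It goes slightly beyond the diff: it assumes the hub interface exposes `roberta.task.source_dictionary.pad()` so the padding index can be looked up instead of hard-coding `pad_idx=1`, and it uses the MNLI label order documented in the fairseq README (0: contradiction, 1: neutral, 2: entailment).

```python
import torch
from fairseq.data.data_utils import collate_tokens

# Load the MNLI-finetuned RoBERTa model from the PyTorch Hub and switch to eval mode.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()

batch_of_pairs = [
    ['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'],
    ['Mars is very far from earth.', 'Mars is very close.'],
]

# Assumption: the task's source dictionary exposes pad(), which returns the padding
# index (1 for the released RoBERTa checkpoints), so it need not be hard-coded.
pad_idx = roberta.task.source_dictionary.pad()

# Encode each premise/hypothesis pair and right-pad every sequence to the longest one.
batch = collate_tokens(
    [roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=pad_idx
)

# Predict with the MNLI classification head and take the argmax over the 3 classes
# (0: contradiction, 1: neutral, 2: entailment).
predictions = roberta.predict('mnli', batch).argmax(dim=1)
print(predictions)  # expected: tensor([0, 0]) -- both pairs are contradictions
```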

##### Using the GPU:
