Add instructions to load RoBERTa models on PyTorch 1.0
Summary: Pull Request resolved: facebookresearch#921

Differential Revision: D16541025

Pulled By: myleott

fbshipit-source-id: bb78d30fe285da2adfc7c4e5897ee01fa413b2e4
Myle Ott authored and facebook-github-bot committed Jul 29, 2019
1 parent 8d036c2 commit ce7f044
Showing 1 changed file with 35 additions and 25 deletions.
examples/roberta/README.md

@@ -14,15 +14,47 @@

Model | Description | # params | Download
---|---|---|---
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
`roberta.large.mnli` | `roberta.large` finetuned on MNLI | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)

## Results

##### Results on GLUE tasks (dev set, single model, single-task finetuning)

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.base` | 87.6 | 92.8 | 91.9 | 78.7 | 94.8 | 90.2 | 63.6 | 91.2
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -

##### Results on SQuAD (dev set)

Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4

##### Results on Reading Comprehension (RACE, test set)

Model | Accuracy | Middle | High
---|---|---|---
`roberta.large` | 83.2 | 86.5 | 81.3

## Example usage

##### Load RoBERTa from torch.hub (PyTorch >= 1.1):
```
>>> import torch
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
>>> roberta.eval() # disable dropout (or leave in train mode to finetune)
```
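
If you are not sure which checkpoints are published through torch.hub, the available entry points can be listed first. A minimal sketch, again assuming PyTorch >= 1.1 (where `torch.hub.list` is available); the exact entry names depend on the fairseq hubconf:
```
>>> import torch
>>> # list the models exposed by the pytorch/fairseq hubconf (names may change between releases)
>>> torch.hub.list('pytorch/fairseq')
```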

##### Load RoBERTa (for PyTorch 1.0):
```
$ wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
$ tar -xzvf roberta.large.tar.gz
>>> from fairseq.models.roberta import RobertaModel
>>> roberta = RobertaModel.from_pretrained('/path/to/roberta.large')
>>> roberta.eval() # disable dropout (or leave in train mode to finetune)
```
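
If `wget`/`tar` are not available, the same download-and-extract step can be scripted with the Python standard library. A rough sketch mirroring the shell commands above; the target directory `/path/to/` is a placeholder, as before:
```
>>> import tarfile, urllib.request
>>> url = 'https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz'
>>> # download the archive, then unpack it under the placeholder directory used above
>>> urllib.request.urlretrieve(url, 'roberta.large.tar.gz')
>>> with tarfile.open('roberta.large.tar.gz') as archive:
...     archive.extractall('/path/to/')
```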

##### Apply Byte-Pair Encoding (BPE) to input text:
```
>>> tokens = roberta.encode('Hello world!')
```

@@ -80,29 +112,7 @@ tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)

```
tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```
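
The `device='cuda:0'` in the output fragment above comes from running predictions on GPU, in a part of the README that this diff does not show. As a rough, hypothetical sketch of that flow, assuming the `roberta.large.mnli` checkpoint is loaded (it registers an `mnli` classification head) and the usual contradiction/neutral/entailment label order:
```
>>> # sketch only: assumes roberta = RobertaModel.from_pretrained('/path/to/roberta.large.mnli')
>>> roberta.cuda()  # move the model to GPU; outputs then come back on cuda:0
>>> tokens = roberta.encode('RoBERTa is a robustly optimized BERT.', 'RoBERTa was not optimized at all.')
>>> log_probs = roberta.predict('mnli', tokens)  # log-probabilities over the three MNLI classes
>>> log_probs.argmax(dim=1)  # assumed order: 0 = contradiction, 1 = neutral, 2 = entailment
```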

##### Evaluating the `roberta.large.mnli` model

Example Python code snippet to evaluate accuracy on the MNLI dev_matched set:
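
The snippet itself is not included in this view. A rough sketch of such an evaluation loop follows; it is not the README's code, and the `glue_data/MNLI/dev_matched.tsv` path, the column indices, and the label order are assumptions about the standard GLUE data layout. It also assumes `roberta` is the `roberta.large.mnli` model, loaded and put in eval mode as shown above.
```
# Sketch of an MNLI dev_matched evaluation loop (assumed TSV layout and label order).
# Assumes: roberta = RobertaModel.from_pretrained('/path/to/roberta.large.mnli'); roberta.eval()
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()  # skip the header row
    for line in fin:
        cols = line.strip().split('\t')
        sent1, sent2, target = cols[8], cols[9], cols[15]  # assumed column positions
        tokens = roberta.encode(sent1, sent2)
        prediction = roberta.predict('mnli', tokens).argmax().item()
        ncorrect += int(label_map[prediction] == target)
        nsamples += 1
print('Accuracy:', float(ncorrect) / nsamples)
```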
