Skip to content

Commit

Permalink
added readme code for inference with GLUE finetuned model
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: fairinternal/fairseq-py#820

Differential Revision: D16783469

fbshipit-source-id: d5af8ba6a6685608d67b72d584952b8e43eabf9f
  • Loading branch information
Naman Goyal authored and facebook-github-bot committed Aug 13, 2019
1 parent 577e4fa commit a171c2d
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions examples/roberta/README.finetune_glue.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,35 @@ a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calcul
b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`.

c) All the settings in above table are suggested settings based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with wider hyperparam search.

### Inference on GLUE task
After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet:

```python
from fairseq.models.roberta import RobertaModel

roberta = RobertaModel.from_pretrained(
'checkpoints/',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='RTE-bin'
)

label_fn = lambda label: roberta.task.label_dictionary.string(
[label + roberta.task.target_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('glue_data/RTE/dev.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[1], tokens[2], tokens[3]
tokens = roberta.encode(sent1, sent2)
prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_fn(prediction)
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))

```

0 comments on commit a171c2d

Please sign in to comment.