Commit
Summary: Pull Request resolved: facebookresearch#1004

Differential Revision: D16751443

Pulled By: myleott

fbshipit-source-id: f70acd6c7be6d69da45b5b32fe4c4eff021539ab
1 parent a00ce13, commit 8324919
Showing 17 changed files with 848 additions and 63 deletions.
examples/roberta/wsc/README.md
@@ -0,0 +1,83 @@
# Finetuning RoBERTa on Winograd Schema Challenge (WSC) data

The following instructions can be used to finetune RoBERTa on the WSC training
data provided by [SuperGLUE](https://super.gluebenchmark.com/).

Note that there is high variance in the results. For our GLUE/SuperGLUE
submission we swept over the learning rate, batch size and total number of
updates, as well as the random seed. Out of ~100 runs we chose the best 7 models
and ensembled them.

**Note:** The instructions below use a slightly different loss function than
what's described in the original RoBERTa arXiv paper. In particular,
[Kocijan et al. (2019)](https://arxiv.org/abs/1905.06290) introduce a margin
ranking loss between `(query, candidate)` pairs with tunable hyperparameters
alpha and beta. This is supported in our code as well with the
`--wsc-margin-alpha` and `--wsc-margin-beta` arguments. However, we achieved
slightly better (and more robust) results on the development set by instead
using a single cross entropy loss term over the log-probabilities for the query
and all candidates. This reduces the number of hyperparameters and our best
model achieved 92.3% development set accuracy, compared to ~90% accuracy for
the margin loss. Later versions of the RoBERTa arXiv paper will describe this
updated formulation.

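For reference, the two formulations look like this. The sketch below is only
illustrative (the `wsc_loss` helper is hypothetical); the actual implementation
lives in `wsc_criterion.py` in this commit and the cross entropy variant is
selected with the `--wsc-cross-entropy` flag:

```python
import torch
import torch.nn.functional as F

def wsc_loss(query_lprobs, cand_lprobs, use_cross_entropy=False, alpha=1.0, beta=0.0):
    # query_lprobs: shape (1,), averaged log-probability of the correct (query) span
    # cand_lprobs: shape (k,), averaged log-probabilities of the k other candidate spans
    if use_cross_entropy:
        # single cross entropy term: the query acts as "class 0" among
        # [query, candidate_1, ..., candidate_k]
        logits = torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0)
        return F.cross_entropy(logits, logits.new_zeros(1).long())
    # margin ranking loss of Kocijan et al. (2019): penalize any candidate whose
    # score comes within beta of (or exceeds) the query's score
    return (
        -query_lprobs
        + alpha * (cand_lprobs - query_lprobs + beta).clamp(min=0)
    ).sum()
```

With the cross entropy formulation, alpha and beta drop out entirely, which is
the reduction in hyperparameters mentioned above.
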
### 1) Download the WSC data from the SuperGLUE website:
```bash
wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/WSC.zip
unzip WSC.zip

# we also need to copy the RoBERTa dictionary into the same directory
wget -O WSC/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
```

### 2) Finetune over the provided training data:
```bash
TOTAL_NUM_UPDATES=2000  # Total number of training steps.
WARMUP_UPDATES=250      # Linearly increase LR over this many steps.
LR=2e-05                # Peak LR for polynomial LR scheduler.
MAX_SENTENCES=16        # Batch size per GPU.
SEED=1                  # Random seed.
ROBERTA_PATH=/path/to/roberta/model.pt

# we use the --user-dir option to load the task and criterion
# from the examples/roberta/wsc directory:
FAIRSEQ_PATH=/path/to/fairseq
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc

cd fairseq
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
    --restore-file $ROBERTA_PATH \
    --reset-optimizer --reset-dataloader --reset-meters \
    --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
    --valid-subset val \
    --fp16 --ddp-backend no_c10d \
    --user-dir $FAIRSEQ_USER_DIR \
    --task wsc --criterion wsc --wsc-cross-entropy \
    --arch roberta_large --bpe gpt2 --max-positions 512 \
    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
    --lr-scheduler polynomial_decay --lr $LR \
    --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
    --max-sentences $MAX_SENTENCES --seed $SEED \
    --max-update $TOTAL_NUM_UPDATES \
    --log-format simple --log-interval 100
```

The above command assumes training on 4 GPUs, but you can achieve the same
results on a single GPU by adding `--update-freq=4` to accumulate gradients
over 4 batches, keeping the effective batch size at 4 x 16 = 64 sentences.

### 3) Evaluate
```python
from fairseq.models.roberta import RobertaModel
from examples.roberta.wsc import wsc_utils  # also loads WSC task and criterion
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'WSC/')
roberta.cuda()
nsamples, ncorrect = 0, 0
for sentence, label in wsc_utils.jsonl_iterator('WSC/val.jsonl', eval=True):
    pred = roberta.disambiguate_pronoun(sentence)
    nsamples += 1
    if pred == label:
        ncorrect += 1
print('Accuracy: ' + str(ncorrect / float(nsamples)))
# Accuracy: 0.9230769230769231
```
examples/roberta/wsc/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import wsc_criterion  # noqa
from . import wsc_task  # noqa
examples/roberta/wsc/wsc_criterion.py
@@ -0,0 +1,131 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.data import encoders
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('wsc')
class WSCCriterion(FairseqCriterion):

    def __init__(self, args, task):
        super().__init__(args, task)
        if self.args.save_predictions is not None:
            self.prediction_h = open(self.args.save_predictions, 'w')
        else:
            self.prediction_h = None
        self.bpe = encoders.build_bpe(args)
        self.tokenizer = encoders.build_tokenizer(args)

    def __del__(self):
        if self.prediction_h is not None:
            self.prediction_h.close()

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0)
        parser.add_argument('--wsc-margin-beta', type=float, metavar='B', default=0.0)
        parser.add_argument('--wsc-cross-entropy', action='store_true',
                            help='use cross entropy formulation instead of margin loss')
        parser.add_argument('--save-predictions', metavar='FILE',
                            help='file to save predictions to')

    def forward(self, model, sample, reduce=True):

        def get_masked_input(tokens, mask):
            masked_tokens = tokens.clone()
            masked_tokens[mask] = self.task.mask
            return masked_tokens

        def get_lprobs(tokens, mask):
            logits, _ = model(src_tokens=get_masked_input(tokens, mask))
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
            scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
            mask = mask.type_as(scores)
            scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
            return scores

        # compute loss and accuracy
        loss, nloss = 0., 0
        ncorrect, nqueries = 0, 0
        for i, label in enumerate(sample['labels']):
            query_lprobs = get_lprobs(
                sample['query_tokens'][i].unsqueeze(0),
                sample['query_masks'][i].unsqueeze(0),
            )
            cand_lprobs = get_lprobs(
                sample['candidate_tokens'][i],
                sample['candidate_masks'][i],
            )

            pred = (query_lprobs >= cand_lprobs).all().item()

            if label is not None:
                label = 1 if label else 0
                ncorrect += 1 if pred == label else 0
                nqueries += 1

                if label:
                    # only compute a loss for positive instances
                    nloss += 1
                    if self.args.wsc_cross_entropy:
                        loss += F.cross_entropy(
                            torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
                            query_lprobs.new([0]).long(),
                        )
                    else:
                        loss += (
                            - query_lprobs
                            + self.args.wsc_margin_alpha * (
                                cand_lprobs - query_lprobs + self.args.wsc_margin_beta
                            ).clamp(min=0)
                        ).sum()

            id = sample['id'][i].item()
            if self.prediction_h is not None:
                print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h)

        if nloss == 0:
            loss = torch.tensor(0.0, requires_grad=True)

        sample_size = nqueries if nqueries > 0 else 1
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
            'ncorrect': ncorrect,
            'nqueries': nqueries,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        agg_output = {
            'loss': loss_sum / sample_size / math.log(2),
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }

        ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
        nqueries = sum(log.get('nqueries', 0) for log in logging_outputs)
        if nqueries > 0:
            agg_output['accuracy'] = ncorrect / float(nqueries)

        return agg_output
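
Aside (not part of the diff): the core of the criterion above is the
masked-span scoring in `get_lprobs`: mask out a candidate span, run the masked
language model, and average the log-probabilities it assigns to the original
tokens at the masked positions. A minimal, self-contained sketch of that idea,
using a toy stand-in for RoBERTa (all names below are illustrative):

```python
import torch
import torch.nn.functional as F

def score_span(masked_lm, tokens, span_mask, mask_idx):
    """Average log-probability of the original tokens in the masked span."""
    masked = tokens.clone()
    masked[span_mask] = mask_idx                                 # replace the span with <mask>
    lprobs = F.log_softmax(masked_lm(masked), dim=-1, dtype=torch.float)
    scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)  # log P(original token | context)
    span_mask = span_mask.type_as(scores)
    return (scores * span_mask).sum(dim=-1) / span_mask.sum(dim=-1)

# toy stand-in for a masked LM: random logits over a 10-token vocabulary
toy_lm = lambda x: torch.randn(x.size(0), x.size(1), 10)
tokens = torch.randint(0, 10, (1, 6))                            # one sentence of 6 token ids
span_mask = torch.tensor([[False, False, True, True, False, False]])
print(score_span(toy_lm, tokens, span_mask, mask_idx=9))
```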