forked from facebookresearch/fairseq
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: facebookresearch#1014 Differential Revision: D16784120 Pulled By: myleott fbshipit-source-id: 946c0e33b594f8378e4ab6482ce49efcb36e1743
- Loading branch information
1 parent
a171c2d
commit a33ac06
Showing
14 changed files
with
387 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Finetuning RoBERTa on Commonsense QA | ||
|
||
We follow a similar approach to [finetuning RACE](README.race.md). Specifically | ||
for each question we construct five inputs, one for each of the five candidate | ||
answer choices. Each input is constructed by concatenating the question and | ||
candidate answer. We then encode each input and pass the resulting "[CLS]" | ||
representations through a fully-connected layer to predict the correct answer. | ||
We train with a standard cross-entropy loss. | ||
|
||
We also found it helpful to prepend a prefix of `Q:` to the question and `A:` to | ||
the input. The complete input format is: | ||
``` | ||
<s> Q: Where would I not want a fox? </s> A: hen house </s> | ||
``` | ||
|
||
Our final submission is based on a hyperparameter search over the learning rate | ||
(1e-5, 2e-5, 3e-5), batch size (8, 16), number of training steps (2000, 3000, | ||
4000) and random seed. We selected the model with the best performance on the | ||
development set after 100 trials. | ||
|
||
### 1) Download the data from Commonsense QA website (https://www.tau-nlp.org/commonsenseqa) | ||
```bash | ||
bash examples/roberta/commonsense_qa/download_cqa_data.sh | ||
``` | ||
|
||
### 2) Finetune | ||
|
||
```bash | ||
MAX_UPDATES=3000 # Number of training steps. | ||
WARMUP_UPDATES=150 # Linearly increase LR over this many steps. | ||
LR=1e-05 # Peak LR for polynomial LR scheduler. | ||
MAX_SENTENCES=16 # Batch size. | ||
SEED=1 # Random seed. | ||
ROBERTA_PATH=/path/to/roberta/model.pt | ||
DATA_DIR=data/CommonsenseQA | ||
|
||
# we use the --user-dir option to load the task from | ||
# the examples/roberta/commonsense_qa directory: | ||
FAIRSEQ_PATH=/path/to/fairseq | ||
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/commonsense_qa | ||
|
||
CUDA_VISIBLE_DEVICES=0 fairseq-train --fp16 \ | ||
$DATA_DIR \ | ||
--user-dir $FAIRSEQ_USER_DIR \ | ||
--restore-file $ROBERTA_PATH \ | ||
--reset-optimizer --reset-dataloader --reset-meters \ | ||
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \ | ||
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ | ||
--task commonsense_qa --init-token 0 --bpe gpt2 \ | ||
--arch roberta_large --max-positions 512 \ | ||
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \ | ||
--criterion sentence_ranking --num-classes 5 \ | ||
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 --clip-norm 0.0 \ | ||
--lr-scheduler polynomial_decay --lr $LR | ||
--warmup-updates $WARMUP_UPDATES --total-num-update $MAX_UPDATES \ | ||
--max-sentences $MAX_SENTENCES \ | ||
--max-update $MAX_UPDATES \ | ||
--log-format simple --log-interval 25 \ | ||
--seed $SEED | ||
``` | ||
|
||
The above command assumes training on 1 GPU with 32GB of RAM. For GPUs with | ||
less memory, decrease `--max-sentences` and increase `--update-freq` | ||
accordingly to compensate. | ||
|
||
### 3) Evaluate | ||
```python | ||
import json | ||
import torch | ||
from fairseq.models.roberta import RobertaModel | ||
from examples.roberta import commonsense_qa # load the Commonsense QA task | ||
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'data/CommonsenseQA') | ||
roberta.eval() # disable dropout | ||
roberta.cuda() # use the GPU (optional) | ||
nsamples, ncorrect = 0, 0 | ||
with open('data/CommonsenseQA/valid.jsonl') as h: | ||
for line in h: | ||
example = json.loads(line) | ||
scores = [] | ||
for choice in example['question']['choices']: | ||
input = roberta.encode( | ||
'Q: ' + example['question']['stem'], | ||
'A: ' + choice['text'], | ||
no_separator=True | ||
) | ||
score = roberta.predict('sentence_classification_head', input, return_logits=True) | ||
scores.append(score) | ||
pred = torch.cat(scores).argmax() | ||
answer = ord(example['answerKey']) - ord('A') | ||
nsamples += 1 | ||
if pred == answer: | ||
ncorrect += 1 | ||
|
||
print('Accuracy: ' + str(ncorrect / float(nsamples))) | ||
# Accuracy: 0.7846027846027847 | ||
``` | ||
|
||
The above snippet is not batched, which makes it quite slow. See [instructions | ||
for batched prediction with RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta#batched-prediction). |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from . import commonsense_qa_task # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import json | ||
import os | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from fairseq.data import ( | ||
data_utils, | ||
Dictionary, | ||
encoders, | ||
IdDataset, | ||
ListDataset, | ||
NestedDictionaryDataset, | ||
NumSamplesDataset, | ||
NumelDataset, | ||
RawLabelDataset, | ||
RightPadDataset, | ||
SortDataset, | ||
) | ||
from fairseq.tasks import FairseqTask, register_task | ||
|
||
|
||
@register_task('commonsense_qa') | ||
class CommonsenseQATask(FairseqTask): | ||
"""Task to finetune RoBERTa for Commonsense QA.""" | ||
|
||
@staticmethod | ||
def add_args(parser): | ||
"""Add task-specific arguments to the parser.""" | ||
parser.add_argument('data', metavar='DIR', | ||
help='path to data directory; we load <split>.jsonl') | ||
parser.add_argument('--init-token', type=int, default=None, | ||
help='add token at the beginning of each batch item') | ||
parser.add_argument('--num-classes', type=int, default=5) | ||
|
||
def __init__(self, args, vocab): | ||
super().__init__(args) | ||
self.vocab = vocab | ||
self.mask = vocab.add_symbol('<mask>') | ||
|
||
self.bpe = encoders.build_bpe(args) | ||
|
||
@classmethod | ||
def load_dictionary(cls, filename): | ||
"""Load the dictionary from the filename | ||
Args: | ||
filename (str): the filename | ||
""" | ||
dictionary = Dictionary.load(filename) | ||
dictionary.add_symbol('<mask>') | ||
return dictionary | ||
|
||
@classmethod | ||
def setup_task(cls, args, **kwargs): | ||
assert args.criterion == 'sentence_ranking', 'Must set --criterion=sentence_ranking' | ||
|
||
# load data and label dictionaries | ||
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt')) | ||
print('| dictionary: {} types'.format(len(vocab))) | ||
|
||
return cls(args, vocab) | ||
|
||
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs): | ||
"""Load a given dataset split. | ||
Args: | ||
split (str): name of the split (e.g., train, valid, test) | ||
""" | ||
|
||
def binarize(s, append_bos=False): | ||
if self.bpe is not None: | ||
s = self.bpe.encode(s) | ||
tokens = self.vocab.encode_line( | ||
s, append_eos=True, add_if_not_exist=False, | ||
).long() | ||
if append_bos and self.args.init_token is not None: | ||
tokens = torch.cat([tokens.new([self.args.init_token]), tokens]) | ||
return tokens | ||
|
||
if data_path is None: | ||
data_path = os.path.join(self.args.data, split + '.jsonl') | ||
if not os.path.exists(data_path): | ||
raise FileNotFoundError('Cannot find data: {}'.format(data_path)) | ||
|
||
src_tokens = [[] for i in range(self.args.num_classes)] | ||
src_lengths = [[] for i in range(self.args.num_classes)] | ||
labels = [] | ||
|
||
with open(data_path) as h: | ||
for line in h: | ||
example = json.loads(line.strip()) | ||
if 'answerKey' in example: | ||
label = ord(example['answerKey']) - ord('A') | ||
labels.append(label) | ||
question = example['question']['stem'] | ||
assert len(example['question']['choices']) == self.args.num_classes | ||
# format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>` | ||
question = 'Q: ' + question | ||
question_toks = binarize(question, append_bos=True) | ||
for i, choice in enumerate(example['question']['choices']): | ||
src = 'A: ' + choice['text'] | ||
src_bin = torch.cat([question_toks, binarize(src)]) | ||
src_tokens[i].append(src_bin) | ||
src_lengths[i].append(len(src_bin)) | ||
assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes)) | ||
assert len(src_tokens[0]) == len(src_lengths[0]) | ||
assert len(labels) == 0 or len(labels) == len(src_tokens[0]) | ||
|
||
for i in range(self.args.num_classes): | ||
src_lengths[i] = np.array(src_lengths[i]) | ||
src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i]) | ||
src_lengths[i] = ListDataset(src_lengths[i]) | ||
|
||
dataset = { | ||
'id': IdDataset(), | ||
'nsentences': NumSamplesDataset(), | ||
'ntokens': NumelDataset(src_tokens[0], reduce=True), | ||
} | ||
|
||
for i in range(self.args.num_classes): | ||
dataset.update({ | ||
'net_input{}'.format(i + 1): { | ||
'src_tokens': RightPadDataset( | ||
src_tokens[i], | ||
pad_idx=self.source_dictionary.pad(), | ||
), | ||
'src_lengths': src_lengths[i], | ||
} | ||
}) | ||
|
||
if len(labels) > 0: | ||
dataset.update({'target': RawLabelDataset(labels)}) | ||
|
||
dataset = NestedDictionaryDataset( | ||
dataset, | ||
sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])], | ||
) | ||
|
||
with data_utils.numpy_seed(self.args.seed): | ||
dataset = SortDataset( | ||
dataset, | ||
# shuffle | ||
sort_order=[np.random.permutation(len(dataset))], | ||
) | ||
|
||
print('| Loaded {} with {} samples'.format(split, len(dataset))) | ||
|
||
self.datasets[split] = dataset | ||
return self.datasets[split] | ||
|
||
def build_model(self, args): | ||
from fairseq import models | ||
model = models.build_model(args, self) | ||
|
||
model.register_classification_head( | ||
'sentence_classification_head', | ||
num_classes=1, | ||
) | ||
|
||
return model | ||
|
||
@property | ||
def source_dictionary(self): | ||
return self.vocab | ||
|
||
@property | ||
def target_dictionary(self): | ||
return self.vocab |
Oops, something went wrong.