Skip to content

Commit

Permalink
Add Commonsense QA task
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookresearch#1014

Differential Revision: D16784120

Pulled By: myleott

fbshipit-source-id: 946c0e33b594f8378e4ab6482ce49efcb36e1743
  • Loading branch information
Myle Ott authored and facebook-github-bot committed Aug 13, 2019
1 parent a171c2d commit a33ac06
Show file tree
Hide file tree
Showing 14 changed files with 387 additions and 59 deletions.
46 changes: 23 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ modeling and other text generation tasks.

Fairseq provides reference implementations of various sequence-to-sequence models, including:
- **Convolutional Neural Networks (CNN)**
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/language_model/conv_lm/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- **_New_** [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
- [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
- [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
- [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
- [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
- **LightConv and DynamicConv models**
- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
- **Long Short-Term Memory (LSTM) networks**
- Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation
- Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
- **Transformer (self-attention) networks**
- Vaswani et al. (2017): Attention Is All You Need
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](examples/backtranslation/README.md)
- [Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](examples/language_model/transformer_lm/README.md)
- [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- **_New_** [Liu et al. (2019): RoBERTa: A Robustly Optimized BERT Pretraining Approach](examples/roberta/README.md)
- Attention Is All You Need (Vaswani et al., 2017)
- [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
- [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
- [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/transformer_lm/README.md)
- [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)

**Additionally:**
- multi-GPU (distributed) training on one machine or across multiple machines
Expand Down Expand Up @@ -96,16 +96,16 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional models are available

We also have more detailed READMEs to reproduce results from specific papers:
- [Liu et al. (2019): RoBERTa: A Robustly Optimized BERT Pretraining Approach](examples/roberta/README.md)
- [Schneider et al. (2019): wav2vec: Unsupervised Pre-training for Speech Recognition](examples/wav2vec/README.md)
- [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](examples/backtranslation/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/language_model/conv_lm/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
- [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
- [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
- [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
- [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
- [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
- [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
- [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
- [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)

# Join the fairseq community

Expand Down
99 changes: 99 additions & 0 deletions examples/roberta/README.cqa.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Finetuning RoBERTa on Commonsense QA

We follow a similar approach to [finetuning RACE](README.race.md). Specifically
for each question we construct five inputs, one for each of the five candidate
answer choices. Each input is constructed by concatenating the question and
candidate answer. We then encode each input and pass the resulting "[CLS]"
representations through a fully-connected layer to predict the correct answer.
We train with a standard cross-entropy loss.

We also found it helpful to prepend a prefix of `Q:` to the question and `A:` to
the input. The complete input format is:
```
<s> Q: Where would I not want a fox? </s> A: hen house </s>
```

Our final submission is based on a hyperparameter search over the learning rate
(1e-5, 2e-5, 3e-5), batch size (8, 16), number of training steps (2000, 3000,
4000) and random seed. We selected the model with the best performance on the
development set after 100 trials.

### 1) Download the data from Commonsense QA website (https://www.tau-nlp.org/commonsenseqa)
```bash
bash examples/roberta/commonsense_qa/download_cqa_data.sh
```

### 2) Finetune

```bash
MAX_UPDATES=3000 # Number of training steps.
WARMUP_UPDATES=150 # Linearly increase LR over this many steps.
LR=1e-05 # Peak LR for polynomial LR scheduler.
MAX_SENTENCES=16 # Batch size.
SEED=1 # Random seed.
ROBERTA_PATH=/path/to/roberta/model.pt
DATA_DIR=data/CommonsenseQA

# we use the --user-dir option to load the task from
# the examples/roberta/commonsense_qa directory:
FAIRSEQ_PATH=/path/to/fairseq
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/commonsense_qa

CUDA_VISIBLE_DEVICES=0 fairseq-train --fp16 \
$DATA_DIR \
--user-dir $FAIRSEQ_USER_DIR \
--restore-file $ROBERTA_PATH \
--reset-optimizer --reset-dataloader --reset-meters \
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--task commonsense_qa --init-token 0 --bpe gpt2 \
--arch roberta_large --max-positions 512 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--criterion sentence_ranking --num-classes 5 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 --clip-norm 0.0 \
--lr-scheduler polynomial_decay --lr $LR
--warmup-updates $WARMUP_UPDATES --total-num-update $MAX_UPDATES \
--max-sentences $MAX_SENTENCES \
--max-update $MAX_UPDATES \
--log-format simple --log-interval 25 \
--seed $SEED
```

The above command assumes training on 1 GPU with 32GB of RAM. For GPUs with
less memory, decrease `--max-sentences` and increase `--update-freq`
accordingly to compensate.

### 3) Evaluate
```python
import json
import torch
from fairseq.models.roberta import RobertaModel
from examples.roberta import commonsense_qa # load the Commonsense QA task
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'data/CommonsenseQA')
roberta.eval() # disable dropout
roberta.cuda() # use the GPU (optional)
nsamples, ncorrect = 0, 0
with open('data/CommonsenseQA/valid.jsonl') as h:
for line in h:
example = json.loads(line)
scores = []
for choice in example['question']['choices']:
input = roberta.encode(
'Q: ' + example['question']['stem'],
'A: ' + choice['text'],
no_separator=True
)
score = roberta.predict('sentence_classification_head', input, return_logits=True)
scores.append(score)
pred = torch.cat(scores).argmax()
answer = ord(example['answerKey']) - ord('A')
nsamples += 1
if pred == answer:
ncorrect += 1

print('Accuracy: ' + str(ncorrect / float(nsamples)))
# Accuracy: 0.7846027846027847
```

The above snippet is not batched, which makes it quite slow. See [instructions
for batched prediction with RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta#batched-prediction).
File renamed without changes.
5 changes: 3 additions & 2 deletions examples/roberta/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,10 @@ print('| Accuracy: ', float(ncorrect)/float(nsamples))

## Finetuning

- [Finetuning on GLUE](README.finetune_glue.md)
- [Finetuning on custom classification tasks (e.g., IMDB)](README.finetune_custom_classification.md)
- [Finetuning on GLUE](README.glue.md)
- [Finetuning on custom classification tasks (e.g., IMDB)](README.custom_classification.md)
- [Finetuning on Winograd Schema Challenge (WSC)](README.wsc.md)
- [Finetuning on Commonsense QA (CQA)](README.cqa.md)
- Finetuning on SQuAD: coming soon

## Pretraining using your own data
Expand Down
File renamed without changes.
34 changes: 17 additions & 17 deletions examples/roberta/README.wsc.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,24 @@ ROBERTA_PATH=/path/to/roberta/model.pt
FAIRSEQ_PATH=/path/to/fairseq
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc

cd fairseq
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
--restore-file $ROBERTA_PATH \
--reset-optimizer --reset-dataloader --reset-meters \
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--valid-subset val \
--fp16 --ddp-backend no_c10d \
--user-dir $FAIRSEQ_USER_DIR \
--task wsc --criterion wsc --wsc-cross-entropy \
--arch roberta_large --bpe gpt2 --max-positions 512 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
--lr-scheduler polynomial_decay --lr $LR \
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
--max-sentences $MAX_SENTENCES \
--max-update $TOTAL_NUM_UPDATES \
--log-format simple --log-interval 100
--restore-file $ROBERTA_PATH \
--reset-optimizer --reset-dataloader --reset-meters \
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--valid-subset val \
--fp16 --ddp-backend no_c10d \
--user-dir $FAIRSEQ_USER_DIR \
--task wsc --criterion wsc --wsc-cross-entropy \
--arch roberta_large --bpe gpt2 --max-positions 512 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
--lr-scheduler polynomial_decay --lr $LR \
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
--max-sentences $MAX_SENTENCES \
--max-update $TOTAL_NUM_UPDATES \
--log-format simple --log-interval 100 \
--seed $SEED
```

The above command assumes training on 4 GPUs, but you can achieve the same
Expand Down
6 changes: 6 additions & 0 deletions examples/roberta/commonsense_qa/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import commonsense_qa_task # noqa
174 changes: 174 additions & 0 deletions examples/roberta/commonsense_qa/commonsense_qa_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

import numpy as np
import torch

from fairseq.data import (
data_utils,
Dictionary,
encoders,
IdDataset,
ListDataset,
NestedDictionaryDataset,
NumSamplesDataset,
NumelDataset,
RawLabelDataset,
RightPadDataset,
SortDataset,
)
from fairseq.tasks import FairseqTask, register_task


@register_task('commonsense_qa')
class CommonsenseQATask(FairseqTask):
"""Task to finetune RoBERTa for Commonsense QA."""

@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR',
help='path to data directory; we load <split>.jsonl')
parser.add_argument('--init-token', type=int, default=None,
help='add token at the beginning of each batch item')
parser.add_argument('--num-classes', type=int, default=5)

def __init__(self, args, vocab):
super().__init__(args)
self.vocab = vocab
self.mask = vocab.add_symbol('<mask>')

self.bpe = encoders.build_bpe(args)

@classmethod
def load_dictionary(cls, filename):
"""Load the dictionary from the filename
Args:
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol('<mask>')
return dictionary

@classmethod
def setup_task(cls, args, **kwargs):
assert args.criterion == 'sentence_ranking', 'Must set --criterion=sentence_ranking'

# load data and label dictionaries
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(vocab)))

return cls(args, vocab)

def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""

def binarize(s, append_bos=False):
if self.bpe is not None:
s = self.bpe.encode(s)
tokens = self.vocab.encode_line(
s, append_eos=True, add_if_not_exist=False,
).long()
if append_bos and self.args.init_token is not None:
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
return tokens

if data_path is None:
data_path = os.path.join(self.args.data, split + '.jsonl')
if not os.path.exists(data_path):
raise FileNotFoundError('Cannot find data: {}'.format(data_path))

src_tokens = [[] for i in range(self.args.num_classes)]
src_lengths = [[] for i in range(self.args.num_classes)]
labels = []

with open(data_path) as h:
for line in h:
example = json.loads(line.strip())
if 'answerKey' in example:
label = ord(example['answerKey']) - ord('A')
labels.append(label)
question = example['question']['stem']
assert len(example['question']['choices']) == self.args.num_classes
# format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
question = 'Q: ' + question
question_toks = binarize(question, append_bos=True)
for i, choice in enumerate(example['question']['choices']):
src = 'A: ' + choice['text']
src_bin = torch.cat([question_toks, binarize(src)])
src_tokens[i].append(src_bin)
src_lengths[i].append(len(src_bin))
assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes))
assert len(src_tokens[0]) == len(src_lengths[0])
assert len(labels) == 0 or len(labels) == len(src_tokens[0])

for i in range(self.args.num_classes):
src_lengths[i] = np.array(src_lengths[i])
src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
src_lengths[i] = ListDataset(src_lengths[i])

dataset = {
'id': IdDataset(),
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(src_tokens[0], reduce=True),
}

for i in range(self.args.num_classes):
dataset.update({
'net_input{}'.format(i + 1): {
'src_tokens': RightPadDataset(
src_tokens[i],
pad_idx=self.source_dictionary.pad(),
),
'src_lengths': src_lengths[i],
}
})

if len(labels) > 0:
dataset.update({'target': RawLabelDataset(labels)})

dataset = NestedDictionaryDataset(
dataset,
sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
)

with data_utils.numpy_seed(self.args.seed):
dataset = SortDataset(
dataset,
# shuffle
sort_order=[np.random.permutation(len(dataset))],
)

print('| Loaded {} with {} samples'.format(split, len(dataset)))

self.datasets[split] = dataset
return self.datasets[split]

def build_model(self, args):
from fairseq import models
model = models.build_model(args, self)

model.register_classification_head(
'sentence_classification_head',
num_classes=1,
)

return model

@property
def source_dictionary(self):
return self.vocab

@property
def target_dictionary(self):
return self.vocab
Loading

0 comments on commit a33ac06

Please sign in to comment.