
Commit

added efficient wsc task/criterion for winogrande (facebookresearch#825)
Summary:
1) So far getting `78%` on the WinoGrande validation set, compared to `63.5%` in the paper.
2) Will update the README once everything is finalized.

Questions:

1) Should I call the task `binary_wsc_task` instead of `winogrande`, to keep it generic rather than tied to one dataset?
Pull Request resolved: fairinternal/fairseq-py#825

Differential Revision: D16810159

fbshipit-source-id: cfde73561fa4caaaa63a4773c0aecd12ce1fa518
ngoyal2707 authored and facebook-github-bot committed Aug 15, 2019
1 parent f840564 commit 1d44cc8
Showing 6 changed files with 265 additions and 58 deletions.
8 changes: 4 additions & 4 deletions examples/roberta/README.md
@@ -17,7 +17,7 @@ Model | Description | # params | Download
`roberta.base` | RoBERTa using the BERT-base architecture | 125M | [roberta.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz)
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
`roberta.large.mnli` | `roberta.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
`roberta.large.wsc` | `roberta.large` finetuned on [WSC](README.wsc.md) | 355M | [roberta.large.wsc.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz)
`roberta.large.wsc` | `roberta.large` finetuned on [WSC](wsc/README.md) | 355M | [roberta.large.wsc.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz)

### Results

@@ -168,7 +168,7 @@ roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a pe
# 'demonstrators'
```

See the [RoBERTA Winograd Schema Challenge (WSC) README](README.wsc.md) for more details on how to train this model.
See the [RoBERTA Winograd Schema Challenge (WSC) README](wsc/README.md) for more details on how to train this model.

#### Extract features aligned to words:

@@ -220,8 +220,8 @@ print('| Accuracy: ', float(ncorrect)/float(nsamples))

- [Finetuning on GLUE](README.glue.md)
- [Finetuning on custom classification tasks (e.g., IMDB)](README.custom_classification.md)
- [Finetuning on Winograd Schema Challenge (WSC)](README.wsc.md)
- [Finetuning on Commonsense QA (CQA)](README.cqa.md)
- [Finetuning on Winograd Schema Challenge (WSC)](wsc/README.md)
- [Finetuning on Commonsense QA (CQA)](commonsense_qa/README.md)
- Finetuning on SQuAD: coming soon

### Pretraining using your own data
File renamed without changes.
40 changes: 40 additions & 0 deletions examples/roberta/README.wsc.md → examples/roberta/wsc/README.md
@@ -83,3 +83,43 @@ for sentence, label in wsc_utils.jsonl_iterator('WSC/val.jsonl', eval=True):
print('Accuracy: ' + str(ncorrect / float(nsamples)))
# Accuracy: 0.9230769230769231
```

## RoBERTa training on the WinoGrande dataset
We have also provided a `winogrande` task and criterion for finetuning on
[WinoGrande](https://mosaic.allenai.org/projects/winogrande)-like datasets,
where there are always exactly two candidates and one of them is correct.
This implementation is more efficient for that special case.

```bash
TOTAL_NUM_UPDATES=23750 # Total number of training steps.
WARMUP_UPDATES=2375 # Linearly increase LR over this many steps.
LR=1e-05 # Peak LR for polynomial LR scheduler.
MAX_SENTENCES=32 # Batch size per GPU.
SEED=1 # Random seed.
ROBERTA_PATH=/path/to/roberta/model.pt

# we use the --user-dir option to load the task and criterion
# from the examples/roberta/wsc directory:
FAIRSEQ_PATH=/path/to/fairseq
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc

cd fairseq
CUDA_VISIBLE_DEVICES=0 fairseq-train winogrande_1.0/ \
--restore-file $ROBERTA_PATH \
--reset-optimizer --reset-dataloader --reset-meters \
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--valid-subset val \
--fp16 --ddp-backend no_c10d \
--user-dir $FAIRSEQ_USER_DIR \
--task winogrande --criterion winogrande \
--wsc-margin-alpha 5.0 --wsc-margin-beta 0.4 \
--arch roberta_large --bpe gpt2 --max-positions 512 \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
--lr-scheduler polynomial_decay --lr $LR \
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
--max-sentences $MAX_SENTENCES \
--max-update $TOTAL_NUM_UPDATES \
--log-format simple --log-interval 100
```
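At prediction time the `winogrande` criterion compares the masked-LM log-probability of the correct filler against the single alternative and picks the higher one. Below is a minimal standalone sketch of that scoring rule (mirroring, not reusing, the criterion code in this commit); `model` is assumed to be a RoBERTa-style masked LM that returns `(logits, extra)`, and `mask_idx` is the id of the `<mask>` symbol.

```python
# Sketch of the winogrande scoring rule: mask the filled-in span, then
# average the masked-LM log-probs of the original tokens over that span.
import torch
import torch.nn.functional as F


def span_lprob(model, tokens, mask, mask_idx):
    """Average masked-LM log-prob of `tokens` over positions where `mask` is 1."""
    masked_tokens = tokens.clone()
    masked_tokens[mask.bool()] = mask_idx          # hide the candidate span
    logits, _ = model(src_tokens=masked_tokens)
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
    scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
    mask = mask.type_as(scores)
    return (scores * mask).sum(dim=-1) / mask.sum(dim=-1)


def predict(model, query_tokens, query_mask, cand_tokens, cand_mask, mask_idx):
    # True when the correct filler scores at least as high as the alternative,
    # matching `pred = query_lprobs >= cand_lprobs` in the criterion.
    return (
        span_lprob(model, query_tokens, query_mask, mask_idx)
        >= span_lprob(model, cand_tokens, cand_mask, mask_idx)
    )
```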
91 changes: 63 additions & 28 deletions examples/roberta/wsc/wsc_criterion.py
@@ -39,30 +39,46 @@ def add_args(parser):
parser.add_argument('--save-predictions', metavar='FILE',
help='file to save predictions to')

def forward(self, model, sample, reduce=True):

def get_masked_input(tokens, mask):
masked_tokens = tokens.clone()
masked_tokens[mask] = self.task.mask
return masked_tokens

def get_lprobs(tokens, mask):
logits, _ = model(src_tokens=get_masked_input(tokens, mask))
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
mask = mask.type_as(scores)
scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
return scores
def get_masked_input(self, tokens, mask):
masked_tokens = tokens.clone()
masked_tokens[mask] = self.task.mask
return masked_tokens

def get_lprobs(self, model, tokens, mask):
logits, _ = model(src_tokens=self.get_masked_input(tokens, mask))
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
mask = mask.type_as(scores)
scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
return scores

def get_loss(self, query_lprobs, cand_lprobs):
if self.args.wsc_cross_entropy:
return F.cross_entropy(
torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
query_lprobs.new([0]).long(),
)
else:
return (
- query_lprobs
+ self.args.wsc_margin_alpha * (
cand_lprobs - query_lprobs + self.args.wsc_margin_beta
).clamp(min=0)
).sum()

def forward(self, model, sample, reduce=True):
# compute loss and accuracy
loss, nloss = 0., 0
ncorrect, nqueries = 0, 0

for i, label in enumerate(sample['labels']):
query_lprobs = get_lprobs(
query_lprobs = self.get_lprobs(
model,
sample['query_tokens'][i].unsqueeze(0),
sample['query_masks'][i].unsqueeze(0),
)
cand_lprobs = get_lprobs(
cand_lprobs = self.get_lprobs(
model,
sample['candidate_tokens'][i],
sample['candidate_masks'][i],
)
@@ -77,18 +93,7 @@ def get_lprobs(tokens, mask):
if label:
# only compute a loss for positive instances
nloss += 1
if self.args.wsc_cross_entropy:
loss += F.cross_entropy(
torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
query_lprobs.new([0]).long(),
)
else:
loss += (
- query_lprobs
+ self.args.wsc_margin_alpha * (
cand_lprobs - query_lprobs + self.args.wsc_margin_beta
).clamp(min=0)
).sum()
loss += self.get_loss(query_lprobs, cand_lprobs)

id = sample['id'][i].item()
if self.prediction_h is not None:
@@ -129,3 +134,33 @@ def aggregate_logging_outputs(logging_outputs):
agg_output['accuracy'] = ncorrect / float(nqueries)

return agg_output


@register_criterion('winogrande')
class WinograndeCriterion(WSCCriterion):
def forward(self, model, sample, reduce=True):
# compute loss and accuracy
query_lprobs = self.get_lprobs(
model,
sample['query_tokens'],
sample['query_masks'],
)
cand_lprobs = self.get_lprobs(
model,
sample['candidate_tokens'],
sample['candidate_masks'],
)
pred = query_lprobs >= cand_lprobs
loss = self.get_loss(query_lprobs, cand_lprobs)

sample_size = sample['query_tokens'].size(0)
ncorrect = pred.sum().item()
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['nsentences'],
'sample_size': sample_size,
'ncorrect': ncorrect,
'nqueries': sample_size,
}
return loss, sample_size, logging_output
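For reference, the two loss options implemented in `WSCCriterion.get_loss` above reduce to the following standalone computation. This is a minimal sketch, assuming per-example log-probabilities for the correct query and the alternative candidate; the default `alpha`/`beta` values here mirror the `--wsc-margin-alpha 5.0 --wsc-margin-beta 0.4` settings used in the training command.

```python
# Minimal sketch of the two loss options in WSCCriterion.get_loss above;
# not part of this commit.
import torch
import torch.nn.functional as F


def wsc_loss(query_lprobs, cand_lprobs, alpha=5.0, beta=0.4, cross_entropy=False):
    if cross_entropy:
        # Treat the (query, candidate) pair as a 2-way classification
        # problem whose correct class is the query (index 0).
        return F.cross_entropy(
            torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
            query_lprobs.new([0]).long(),
        )
    # Margin ranking loss: reward the query's log-prob and penalize any
    # candidate whose log-prob comes within `beta` of it, scaled by `alpha`.
    return (
        -query_lprobs
        + alpha * (cand_lprobs - query_lprobs + beta).clamp(min=0)
    ).sum()
```

With the defaults above, a candidate incurs a penalty whenever its log-probability is within 0.4 of the query's, which pushes the model to separate the correct filler from the distractor by a fixed margin.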
167 changes: 141 additions & 26 deletions examples/roberta/wsc/wsc_task.py
@@ -21,6 +21,7 @@
NestedDictionaryDataset,
NumSamplesDataset,
NumelDataset,
PadDataset,
SortDataset,
)
from fairseq.tasks import FairseqTask, register_task
@@ -77,25 +78,35 @@ def setup_task(cls, args, **kwargs):

return cls(args, vocab)

def binarize(self, s: str, append_eos: bool = False):
if self.tokenizer is not None:
s = self.tokenizer.encode(s)
if self.bpe is not None:
s = self.bpe.encode(s)
tokens = self.vocab.encode_line(
s, append_eos=append_eos, add_if_not_exist=False,
).long()
if self.args.init_token is not None:
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
return tokens

def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space):
toks = self.binarize(
prefix + leading_space + txt + trailing_space + suffix,
append_eos=True,
)
mask = torch.zeros_like(toks, dtype=torch.uint8)
mask_start = len(self.binarize(prefix))
mask_size = len(self.binarize(leading_space + txt))
mask[mask_start:mask_start + mask_size] = 1
return toks, mask

def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""

def binarize(s: str, append_eos: bool = False):
if self.tokenizer is not None:
s = self.tokenizer.encode(s)
if self.bpe is not None:
s = self.bpe.encode(s)
tokens = self.vocab.encode_line(
s, append_eos=append_eos, add_if_not_exist=False,
).long()
if self.args.init_token is not None:
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
return tokens

if data_path is None:
data_path = os.path.join(self.args.data, split + '.jsonl')
if not os.path.exists(data_path):
@@ -126,19 +137,10 @@ def binarize(s: str, append_eos: bool = False):
exact_match=False,
)

def binarize_with_mask(txt):
toks = binarize(
prefix + leading_space + txt + trailing_space + suffix,
append_eos=True,
)
mask = torch.zeros_like(toks, dtype=torch.uint8)
mask_start = len(binarize(prefix))
mask_size = len(binarize(leading_space + txt))
mask[mask_start:mask_start + mask_size] = 1
return toks, mask

if query is not None:
query_toks, query_mask = binarize_with_mask(query)
query_toks, query_mask = self.binarize_with_mask(
query, prefix, suffix, leading_space, trailing_space
)
query_len = len(query_toks)
else:
query_toks, query_mask, query_len = None, None, 0
@@ -149,7 +151,9 @@

cand_toks, cand_masks = [], []
for cand_span in cand_spans:
toks, mask = binarize_with_mask(cand_span.text)
toks, mask = self.binarize_with_mask(
cand_span.text, prefix, suffix, leading_space, trailing_space,
)
cand_toks.append(toks)
cand_masks.append(mask)

@@ -258,3 +262,114 @@ def source_dictionary(self):
@property
def target_dictionary(self):
return self.vocab


@register_task('winogrande')
class WinograndeTask(WSCTask):
"""
Task for WinoGrande dataset. Efficient implementation for Winograd schema
tasks with exactly two candidates, one of which is correct.
"""
@classmethod
def setup_task(cls, args, **kwargs):
assert args.criterion == 'winogrande', 'Must set --criterion=winogrande'

# load data and label dictionaries
vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(vocab)))

return cls(args, vocab)


def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if data_path is None:
data_path = os.path.join(self.args.data, split + '.jsonl')
if not os.path.exists(data_path):
raise FileNotFoundError('Cannot find data: {}'.format(data_path))

query_tokens = []
query_masks = []
query_lengths = []
candidate_tokens = []
candidate_masks = []
candidate_lengths = []

itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=split=='test')

for sample in itr:
sentence, pronoun_span, query, cand_text = sample
prefix = sentence[:pronoun_span[0]].rstrip()
suffix = sentence[pronoun_span[1]:]

leading_space = ' ' if sentence[:pronoun_span[0]].endswith(' ') else ''
trailing_space = ''

if query is not None:
query_toks, query_mask = self.binarize_with_mask(
query, prefix, suffix, leading_space, trailing_space,
)
query_len = len(query_toks)
else:
query_toks, query_mask, query_len = None, None, 0

query_tokens.append(query_toks)
query_masks.append(query_mask)
query_lengths.append(query_len)

cand_toks, cand_mask = self.binarize_with_mask(
cand_text, prefix, suffix, leading_space, trailing_space,
)

candidate_tokens.append(cand_toks)
candidate_masks.append(cand_mask)
candidate_lengths.append(cand_toks.size(0))

query_lengths = np.array(query_lengths)

def get_pad_dataset_fn(tokens, length, pad_idx):
return PadDataset(
ListDataset(tokens, length),
pad_idx=pad_idx,
left_pad=False,
)

query_tokens = get_pad_dataset_fn(query_tokens, query_lengths, self.vocab.pad())
query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)

candidate_lengths = np.array(candidate_lengths)
candidate_tokens = get_pad_dataset_fn(candidate_tokens, candidate_lengths, self.vocab.pad())
candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)

dataset = {
'id': IdDataset(),
'query_tokens': query_tokens,
'query_masks': query_masks,
'candidate_tokens': candidate_tokens,
'candidate_masks': candidate_masks,
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(query_tokens, reduce=True),
}

nested_dataset = NestedDictionaryDataset(
dataset,
sizes=[query_lengths],
)

with data_utils.numpy_seed(self.args.seed):
shuffle = np.random.permutation(len(query_tokens))
dataset = SortDataset(
nested_dataset,
# shuffle
sort_order=[shuffle],
)

if return_only:
return dataset

self.datasets[split] = dataset
return self.datasets[split]
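To make the span masking in `binarize_with_mask` above concrete, here is a toy illustration; it is a sketch with a hypothetical whitespace tokenizer standing in for the real GPT-2 BPE + dictionary pipeline. The mask marks exactly the tokens of the filled-in candidate, which is the span the criterion later replaces with `<mask>` and averages log-probabilities over.

```python
# Toy illustration of the span mask built by binarize_with_mask above;
# a hypothetical whitespace "tokenizer" stands in for the real pipeline.
# Not part of this commit.
import torch


def toy_binarize(s):
    # Stand-in for WSCTask.binarize: one integer id per whitespace token.
    return torch.tensor([hash(tok) % 1000 for tok in s.split()], dtype=torch.long)


def toy_binarize_with_mask(txt, prefix, suffix, leading_space=' '):
    toks = toy_binarize(prefix + leading_space + txt + suffix)
    mask = torch.zeros_like(toks, dtype=torch.uint8)
    mask_start = len(toy_binarize(prefix))
    mask_size = len(toy_binarize(leading_space + txt))
    mask[mask_start:mask_start + mask_size] = 1  # mark the candidate span
    return toks, mask


toks, mask = toy_binarize_with_mask(
    'the trophy', prefix='He could not fit', suffix=' into the suitcase.'
)
print(mask.tolist())  # [0, 0, 0, 0, 1, 1, 0, 0, 0] -> only "the trophy" is in the span
```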