diff --git a/examples/roberta/README.md b/examples/roberta/README.md
index 5f3be7941d..5b80fe94cf 100644
--- a/examples/roberta/README.md
+++ b/examples/roberta/README.md
@@ -12,7 +12,8 @@ Model | Description | # params | Download
---|---|---|---
`roberta.base` | RoBERTa using the BERT-base architecture | 125M | [roberta.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz)
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
-`roberta.large.mnli` | `roberta.large` finetuned on MNLI | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
+`roberta.large.mnli` | `roberta.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
+`roberta.large.wsc` | `roberta.large` finetuned on [WSC](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html) | 355M | [roberta.large.wsc.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz)
## Results
@@ -24,12 +25,12 @@ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -
-
##### Results on SuperGLUE tasks (dev set, single model, single-task finetuning)
Model | BoolQ | CB | COPA | MultiRC | RTE | WiC | WSC
---|---|---|---|---|---|---|---
-`roberta.large` | 86.9 | 98.2 | 94.0 | 85.7 | 89.5 | 75.6 | 91.3
+`roberta.large` | 86.9 | 98.2 | 94.0 | 85.7 | 89.5 | 75.6 | -
+`roberta.large.wsc` | - | - | - | - | - | - | 91.3
##### Results on SQuAD (dev set)
@@ -83,28 +84,6 @@ assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
```
-By default RoBERTa outputs one feature vector per BPE token. You can instead
-realign the features to match [spaCy's word-level tokenization](https://spacy.io/usage/linguistic-features#tokenization)
-with the `extract_features_aligned_to_words` method. This will compute a
-weighted average of the BPE-level features for each word and expose them in
-spaCy's `Token.vector` attribute:
-```python
-doc = roberta.extract_features_aligned_to_words('I said, "hello RoBERTa."')
-assert len(doc) == 10
-for tok in doc:
-    print('{:10}{} (...)'.format(str(tok), tok.vector[:5]))
-# <s>       tensor([-0.1316, -0.0386, -0.0832, -0.0477,  0.1943], grad_fn=<SliceBackward>) (...)
-# I         tensor([ 0.0559,  0.1541, -0.4832,  0.0880,  0.0120], grad_fn=<SliceBackward>) (...)
-# said      tensor([-0.1565, -0.0069, -0.8915,  0.0501, -0.0647], grad_fn=<SliceBackward>) (...)
-# ,         tensor([-0.1318, -0.0387, -0.0834, -0.0477,  0.1944], grad_fn=<SliceBackward>) (...)
-# "         tensor([-0.0486,  0.1818, -0.3946, -0.0553,  0.0981], grad_fn=<SliceBackward>) (...)
-# hello     tensor([ 0.0079,  0.1799, -0.6204, -0.0777, -0.0923], grad_fn=<SliceBackward>) (...)
-# RoBERTa   tensor([-0.2339, -0.1184, -0.7343, -0.0492,  0.5829], grad_fn=<SliceBackward>) (...)
-# .         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
-# "         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
-# </s>      tensor([-0.0930, -0.0392, -0.0821,  0.0158,  0.0649], grad_fn=<SliceBackward>) (...)
-```
-
##### Use RoBERTa for sentence-pair classification tasks:
```python
# Download RoBERTa already finetuned for MNLI
@@ -141,22 +120,79 @@ roberta.cuda()
roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```
-##### Filling mask:
-Some examples from the [Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/).
+## Advanced usage
+
+#### Filling masks:
+
+RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the
+[Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/):
```python
->>> roberta.fill_mask("The first Star wars movie came out in ", topk=3)
-[('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)]
+roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3)
+# [('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)]
+
+roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3)
+# [('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]
+
+roberta.fill_mask('<mask> is the common currency of the European Union', topk=3)
+# [('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)]
+```
->>> roberta.fill_mask("Vikram samvat calender is official in ", topk=3)
-[('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]
+#### Pronoun disambiguation (Winograd Schema Challenge):
->>> roberta.fill_mask(" is the common currency of the European Union", topk=3)
-[('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)]
+RoBERTa can be used to disambiguate pronouns. First install spaCy and download the English-language model:
+```bash
+pip install spacy
+python -m spacy download en_core_web_lg
+```
+
+Next load the `roberta.large.wsc` model and call the `disambiguate_pronoun`
+function. The pronoun should be surrounded by square brackets (`[]`) and the
+query referent surrounded by underscores (`_`), or left blank to return the
+predicted candidate text directly:
+```python
+roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.wsc', user_dir='examples/roberta/wsc')
+roberta.cuda() # use the GPU (optional)
+
+roberta.disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
+# True
+roberta.disambiguate_pronoun('The trophy would not fit in the brown _suitcase_ because [it] was too big.')
+# False
+
+roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] feared violence.')
+# 'The city councilmen'
+roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] advocated violence.')
+# 'demonstrators'
+```
+
+See the [RoBERTa Winograd Schema Challenge (WSC) README](README.wsc.md) for more details on how to train this model.
+
+#### Extract features aligned to words:
+
+By default RoBERTa outputs one feature vector per BPE token. You can instead
+realign the features to match [spaCy's word-level tokenization](https://spacy.io/usage/linguistic-features#tokenization)
+with the `extract_features_aligned_to_words` method. This will compute a
+weighted average of the BPE-level features for each word and expose them in
+spaCy's `Token.vector` attribute:
+```python
+doc = roberta.extract_features_aligned_to_words('I said, "hello RoBERTa."')
+assert len(doc) == 10
+for tok in doc:
+    print('{:10}{} (...)'.format(str(tok), tok.vector[:5]))
+# <s>       tensor([-0.1316, -0.0386, -0.0832, -0.0477,  0.1943], grad_fn=<SliceBackward>) (...)
+# I         tensor([ 0.0559,  0.1541, -0.4832,  0.0880,  0.0120], grad_fn=<SliceBackward>) (...)
+# said      tensor([-0.1565, -0.0069, -0.8915,  0.0501, -0.0647], grad_fn=<SliceBackward>) (...)
+# ,         tensor([-0.1318, -0.0387, -0.0834, -0.0477,  0.1944], grad_fn=<SliceBackward>) (...)
+# "         tensor([-0.0486,  0.1818, -0.3946, -0.0553,  0.0981], grad_fn=<SliceBackward>) (...)
+# hello     tensor([ 0.0079,  0.1799, -0.6204, -0.0777, -0.0923], grad_fn=<SliceBackward>) (...)
+# RoBERTa   tensor([-0.2339, -0.1184, -0.7343, -0.0492,  0.5829], grad_fn=<SliceBackward>) (...)
+# .         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
+# "         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
+# </s>      tensor([-0.0930, -0.0392, -0.0821,  0.0158,  0.0649], grad_fn=<SliceBackward>) (...)
```
-##### Evaluating the `roberta.large.mnli` model
+#### Evaluating the `roberta.large.mnli` model:
-Example python code snippet to evaluate accuracy on the MNLI dev_matched set.
+Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
@@ -181,6 +217,7 @@ print('| Accuracy: ', float(ncorrect)/float(nsamples))
- [Finetuning on GLUE](README.finetune_glue.md)
- [Finetuning on custom classification tasks (e.g., IMDB)](README.finetune_custom_classification.md)
+- [Finetuning on Winograd Schema Challenge (WSC)](README.wsc.md)
- Finetuning on SQuAD: coming soon
## Pretraining using your own data
diff --git a/examples/roberta/README.wsc.md b/examples/roberta/README.wsc.md
new file mode 100644
index 0000000000..b1437d1de7
--- /dev/null
+++ b/examples/roberta/README.wsc.md
@@ -0,0 +1,83 @@
+# Finetuning RoBERTa on Winograd Schema Challenge (WSC) data
+
+The following instructions can be used to finetune RoBERTa on the WSC training
+data provided by [SuperGLUE](https://super.gluebenchmark.com/).
+
+Note that there is high variance in the results. For our GLUE/SuperGLUE
+submission we swept over the learning rate, batch size and total number of
+updates, as well as the random seed. Out of ~100 runs we chose the best 7 models
+and ensembled them.
+
+**Note:** The instructions below use a slightly different loss function than
+what's described in the original RoBERTa arXiv paper. In particular,
+[Kocijan et al. (2019)](https://arxiv.org/abs/1905.06290) introduce a margin
+ranking loss between `(query, candidate)` pairs with tunable hyperparameters
+alpha and beta. This is supported in our code as well with the `--wsc-alpha` and
+`--wsc-beta` arguments. However, we achieved slightly better (and more robust)
+results on the development set by instead using a single cross entropy loss term
+over the log-probabilities for the query and all candidates. This reduces the
+number of hyperparameters and our best model achieved 92.3% development set
+accuracy, compared to ~90% accuracy for the margin loss. Later versions of the
+RoBERTa arXiv paper will describe this updated formulation.
+
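+For concreteness, here is a minimal sketch of the two loss formulations
+(simplified from `wsc_criterion.py`), where `q` is a 1-element tensor holding
+the average log-probability of the query span and `c` holds one
+log-probability per candidate span:
+```python
+import torch
+import torch.nn.functional as F
+
+def margin_loss(q, c, alpha=1.0, beta=0.0):
+    # Kocijan et al. (2019): rank the query above every candidate by a margin
+    return (-q + alpha * (c - q + beta).clamp(min=0)).sum()
+
+def cross_entropy_loss(q, c):
+    # single cross entropy term over the query and all candidates,
+    # with the query (index 0) as the correct "class"
+    scores = torch.cat([q, c]).unsqueeze(0)
+    return F.cross_entropy(scores, scores.new_zeros(1).long())
+```
+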
+### 1) Download the WSC data from the SuperGLUE website:
+```bash
+wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/WSC.zip
+unzip WSC.zip
+
+# we also need to copy the RoBERTa dictionary into the same directory
+wget -O WSC/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
+```
+
+### 2) Finetune over the provided training data:
+```bash
+TOTAL_NUM_UPDATES=2000  # Total number of training steps.
+WARMUP_UPDATES=250      # Linearly increase LR over this many steps.
+LR=2e-05                # Peak LR for polynomial LR scheduler.
+MAX_SENTENCES=16        # Batch size per GPU.
+SEED=1                  # Random seed.
+ROBERTA_PATH=/path/to/roberta/model.pt
+
+# we use the --user-dir option to load the task and criterion
+# from the examples/roberta/wsc directory:
+FAIRSEQ_PATH=/path/to/fairseq
+FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc
+
+cd fairseq
+CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
+    --restore-file $ROBERTA_PATH \
+    --reset-optimizer --reset-dataloader --reset-meters \
+    --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
+    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
+    --valid-subset val \
+    --fp16 --ddp-backend no_c10d \
+    --user-dir $FAIRSEQ_USER_DIR \
+    --task wsc --criterion wsc --wsc-cross-entropy \
+    --arch roberta_large --bpe gpt2 --max-positions 512 \
+    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
+    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
+    --lr-scheduler polynomial_decay --lr $LR \
+    --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
+    --max-sentences $MAX_SENTENCES \
+    --max-update $TOTAL_NUM_UPDATES \
+    --log-format simple --log-interval 100
+```
+
+The above command assumes training on 4 GPUs, but you can achieve the same
+results on a single GPU by adding `--update-freq=4`.
+
+### 3) Evaluate
+```python
+from fairseq.models.roberta import RobertaModel
+from examples.roberta.wsc import wsc_utils # also loads WSC task and criterion
+roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'WSC/')
+roberta.cuda()
+nsamples, ncorrect = 0, 0
+for sentence, label in wsc_utils.jsonl_iterator('WSC/val.jsonl', eval=True):
+    pred = roberta.disambiguate_pronoun(sentence)
+    nsamples += 1
+    if pred == label:
+        ncorrect += 1
+print('Accuracy: ' + str(ncorrect / float(nsamples)))
+# Accuracy: 0.9230769230769231
+```
diff --git a/examples/roberta/wsc/__init__.py b/examples/roberta/wsc/__init__.py
new file mode 100644
index 0000000000..78afa4728e
--- /dev/null
+++ b/examples/roberta/wsc/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import wsc_criterion # noqa
+from . import wsc_task # noqa
diff --git a/examples/roberta/wsc/wsc_criterion.py b/examples/roberta/wsc/wsc_criterion.py
new file mode 100644
index 0000000000..c5b6507f9a
--- /dev/null
+++ b/examples/roberta/wsc/wsc_criterion.py
@@ -0,0 +1,131 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+from fairseq import utils
+from fairseq.data import encoders
+from fairseq.criterions import FairseqCriterion, register_criterion
+
+
+@register_criterion('wsc')
+class WSCCriterion(FairseqCriterion):
+
+    def __init__(self, args, task):
+        super().__init__(args, task)
+        if self.args.save_predictions is not None:
+            self.prediction_h = open(self.args.save_predictions, 'w')
+        else:
+            self.prediction_h = None
+        self.bpe = encoders.build_bpe(args)
+        self.tokenizer = encoders.build_tokenizer(args)
+
+    def __del__(self):
+        if self.prediction_h is not None:
+            self.prediction_h.close()
+
+    @staticmethod
+    def add_args(parser):
+        """Add criterion-specific arguments to the parser."""
+        parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0)
+        parser.add_argument('--wsc-margin-beta', type=float, metavar='B', default=0.0)
+        parser.add_argument('--wsc-cross-entropy', action='store_true',
+                            help='use cross entropy formulation instead of margin loss')
+        parser.add_argument('--save-predictions', metavar='FILE',
+                            help='file to save predictions to')
+
+    def forward(self, model, sample, reduce=True):
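+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """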
+
+        def get_masked_input(tokens, mask):
+            masked_tokens = tokens.clone()
+            masked_tokens[mask] = self.task.mask
+            return masked_tokens
+
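+        # helper to score a span: mask it out, re-score the original tokens
+        # with the masked LM, and average their log-probabilities over the span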
+        def get_lprobs(tokens, mask):
+            logits, _ = model(src_tokens=get_masked_input(tokens, mask))
+            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
+            scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
+            mask = mask.type_as(scores)
+            scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
+            return scores
+
+        # compute loss and accuracy
+        loss, nloss = 0., 0
+        ncorrect, nqueries = 0, 0
+        for i, label in enumerate(sample['labels']):
+            query_lprobs = get_lprobs(
+                sample['query_tokens'][i].unsqueeze(0),
+                sample['query_masks'][i].unsqueeze(0),
+            )
+            cand_lprobs = get_lprobs(
+                sample['candidate_tokens'][i],
+                sample['candidate_masks'][i],
+            )
+
+            pred = (query_lprobs >= cand_lprobs).all().item()
+
+            if label is not None:
+                label = 1 if label else 0
+                ncorrect += 1 if pred == label else 0
+                nqueries += 1
+
+            if label:
+                # only compute a loss for positive instances
+                nloss += 1
+                if self.args.wsc_cross_entropy:
+                    loss += F.cross_entropy(
+                        torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
+                        query_lprobs.new([0]).long(),
+                    )
+                else:
+                    loss += (
+                        - query_lprobs
+                        + self.args.wsc_margin_alpha * (
+                            cand_lprobs - query_lprobs + self.args.wsc_margin_beta
+                        ).clamp(min=0)
+                    ).sum()
+
+            id = sample['id'][i].item()
+            if self.prediction_h is not None:
+                print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h)
+
+        if nloss == 0:
+            loss = torch.tensor(0.0, requires_grad=True)
+
+        sample_size = nqueries if nqueries > 0 else 1
+        logging_output = {
+            'loss': utils.item(loss.data) if reduce else loss.data,
+            'ntokens': sample['ntokens'],
+            'nsentences': sample['nsentences'],
+            'sample_size': sample_size,
+            'ncorrect': ncorrect,
+            'nqueries': nqueries,
+        }
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
+        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
+        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+
+        agg_output = {
+            'loss': loss_sum / sample_size / math.log(2),
+            'ntokens': ntokens,
+            'nsentences': nsentences,
+            'sample_size': sample_size,
+        }
+
+        ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
+        nqueries = sum(log.get('nqueries', 0) for log in logging_outputs)
+        if nqueries > 0:
+            agg_output['accuracy'] = ncorrect / float(nqueries)
+
+        return agg_output
diff --git a/examples/roberta/wsc/wsc_task.py b/examples/roberta/wsc/wsc_task.py
new file mode 100644
index 0000000000..7fd09fc77c
--- /dev/null
+++ b/examples/roberta/wsc/wsc_task.py
@@ -0,0 +1,260 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import tempfile
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from fairseq import utils
+from fairseq.data import (
+    data_utils,
+    Dictionary,
+    encoders,
+    IdDataset,
+    ListDataset,
+    NestedDictionaryDataset,
+    NumSamplesDataset,
+    NumelDataset,
+    SortDataset,
+)
+from fairseq.tasks import FairseqTask, register_task
+
+from . import wsc_utils
+
+
+@register_task('wsc')
+class WSCTask(FairseqTask):
+ """Task to finetune RoBERTa for Winograd Schemas."""
+
+ @staticmethod
+ def add_args(parser):
+ """Add task-specific arguments to the parser."""
+ parser.add_argument('data', metavar='DIR',
+ help='path to data directory; we load .jsonl')
+ parser.add_argument('--init-token', type=int, default=None,
+ help='add token at the beginning of each batch item')
+
+ def __init__(self, args, vocab):
+ super().__init__(args)
+ self.vocab = vocab
+ self.mask = vocab.add_symbol('')
+
+ self.bpe = encoders.build_bpe(args)
+ self.tokenizer = encoders.build_tokenizer(args)
+
+ # hack to handle GPT-2 BPE, which includes leading spaces
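+        # (e.g., GPT-2 BPE tokenizes "a cat" as ['a', ' cat'], attaching the
+        # space to the following token)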
+        if args.bpe == 'gpt2':
+            self.leading_space = True
+            self.trailing_space = False
+        else:
+            self.leading_space = False
+            self.trailing_space = True
+
+    @classmethod
+    def load_dictionary(cls, filename):
+        """Load the dictionary from the filename
+
+        Args:
+            filename (str): the filename
+        """
+        dictionary = Dictionary.load(filename)
+        dictionary.add_symbol('<mask>')
+        return dictionary
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        assert args.criterion == 'wsc', 'Must set --criterion=wsc'
+
+        # load data and label dictionaries
+        vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
+        print('| dictionary: {} types'.format(len(vocab)))
+
+        return cls(args, vocab)
+
+    def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
+        """Load a given dataset split.
+
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+
+        def binarize(s: str, append_eos: bool = False):
+            if self.tokenizer is not None:
+                s = self.tokenizer.encode(s)
+            if self.bpe is not None:
+                s = self.bpe.encode(s)
+            tokens = self.vocab.encode_line(
+                s, append_eos=append_eos, add_if_not_exist=False,
+            ).long()
+            if self.args.init_token is not None:
+                tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
+            return tokens
+
+        if data_path is None:
+            data_path = os.path.join(self.args.data, split + '.jsonl')
+        if not os.path.exists(data_path):
+            raise FileNotFoundError('Cannot find data: {}'.format(data_path))
+
+        query_tokens = []
+        query_masks = []
+        query_lengths = []
+        candidate_tokens = []
+        candidate_masks = []
+        candidate_lengths = []
+        labels = []
+
+        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
+            prefix = sentence[:pronoun_span.start].text
+            suffix = sentence[pronoun_span.end:].text_with_ws
+
+            # spaCy spans include trailing spaces, but we need to know about
+            # leading spaces for the GPT-2 BPE
+            leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
+            trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''
+
+            # get noun phrases, excluding pronouns and anything overlapping with the query
+            cand_spans = wsc_utils.filter_noun_chunks(
+                wsc_utils.extended_noun_chunks(sentence),
+                exclude_pronouns=True,
+                exclude_query=query,
+                exact_match=False,
+            )
+
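+            # helper to binarize the full sentence with `txt` substituted for
+            # the pronoun, plus a mask over the BPE positions of `txt`; the
+            # offsets assume BPE(prefix + x) begins with BPE(prefix)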
+            def binarize_with_mask(txt):
+                toks = binarize(
+                    prefix + leading_space + txt + trailing_space + suffix,
+                    append_eos=True,
+                )
+                mask = torch.zeros_like(toks, dtype=torch.uint8)
+                mask_start = len(binarize(prefix))
+                mask_size = len(binarize(leading_space + txt))
+                mask[mask_start:mask_start + mask_size] = 1
+                return toks, mask
+
+            if query is not None:
+                query_toks, query_mask = binarize_with_mask(query)
+                query_len = len(query_toks)
+            else:
+                query_toks, query_mask, query_len = None, None, 0
+
+            query_tokens.append(query_toks)
+            query_masks.append(query_mask)
+            query_lengths.append(query_len)
+
+            cand_toks, cand_masks = [], []
+            for cand_span in cand_spans:
+                toks, mask = binarize_with_mask(cand_span.text)
+                cand_toks.append(toks)
+                cand_masks.append(mask)
+
+            # collate candidates
+            cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
+            cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
+            assert cand_toks.size() == cand_masks.size()
+
+            candidate_tokens.append(cand_toks)
+            candidate_masks.append(cand_masks)
+            candidate_lengths.append(cand_toks.size(1))
+
+            labels.append(label)
+
+        query_lengths = np.array(query_lengths)
+        query_tokens = ListDataset(query_tokens, query_lengths)
+        query_masks = ListDataset(query_masks, query_lengths)
+
+        candidate_lengths = np.array(candidate_lengths)
+        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
+        candidate_masks = ListDataset(candidate_masks, candidate_lengths)
+
+        labels = ListDataset(labels, [1]*len(labels))
+
+        dataset = {
+            'id': IdDataset(),
+            'query_tokens': query_tokens,
+            'query_masks': query_masks,
+            'candidate_tokens': candidate_tokens,
+            'candidate_masks': candidate_masks,
+            'labels': labels,
+            'nsentences': NumSamplesDataset(),
+            'ntokens': NumelDataset(query_tokens, reduce=True),
+        }
+
+        nested_dataset = NestedDictionaryDataset(
+            dataset,
+            sizes=[query_lengths],
+        )
+
+        with data_utils.numpy_seed(self.args.seed):
+            shuffle = np.random.permutation(len(query_tokens))
+        dataset = SortDataset(
+            nested_dataset,
+            # shuffle
+            sort_order=[shuffle],
+        )
+
+        if return_only:
+            return dataset
+
+        self.datasets[split] = dataset
+        return self.datasets[split]
+
+    def build_dataset_for_inference(self, sample_json):
+        with tempfile.NamedTemporaryFile(buffering=0) as h:
+            h.write((json.dumps(sample_json) + '\n').encode('utf-8'))
+            dataset = self.load_dataset(
+                'disambiguate_pronoun',
+                data_path=h.name,
+                return_only=True,
+            )
+        return dataset
+
+    def disambiguate_pronoun(self, model, sentence, use_cuda=False):
+        sample_json = wsc_utils.convert_sentence_to_json(sentence)
+        dataset = self.build_dataset_for_inference(sample_json)
+        sample = dataset.collater([dataset[0]])
+        if use_cuda:
+            sample = utils.move_to_cuda(sample)
+
+        def get_masked_input(tokens, mask):
+            masked_tokens = tokens.clone()
+            masked_tokens[mask] = self.mask
+            return masked_tokens
+
+        def get_lprobs(tokens, mask):
+            logits, _ = model(src_tokens=get_masked_input(tokens, mask))
+            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
+            scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
+            mask = mask.type_as(scores)
+            scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
+            return scores
+
+        cand_lprobs = get_lprobs(
+            sample['candidate_tokens'][0],
+            sample['candidate_masks'][0],
+        )
+        if sample['query_tokens'][0] is not None:
+            query_lprobs = get_lprobs(
+                sample['query_tokens'][0].unsqueeze(0),
+                sample['query_masks'][0].unsqueeze(0),
+            )
+            return (query_lprobs >= cand_lprobs).all().item() == 1
+        else:
+            best_idx = cand_lprobs.argmax().item()
+            full_cand = sample['candidate_tokens'][0][best_idx]
+            mask = sample['candidate_masks'][0][best_idx]
+            toks = full_cand[mask]
+            return self.bpe.decode(self.source_dictionary.string(toks)).strip()
+
+    @property
+    def source_dictionary(self):
+        return self.vocab
+
+    @property
+    def target_dictionary(self):
+        return self.vocab
diff --git a/examples/roberta/wsc/wsc_utils.py b/examples/roberta/wsc/wsc_utils.py
new file mode 100644
index 0000000000..ef388665fd
--- /dev/null
+++ b/examples/roberta/wsc/wsc_utils.py
@@ -0,0 +1,219 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import lru_cache
+import json
+
+
+def convert_sentence_to_json(sentence):
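+    """Convert a sentence with the pronoun in [brackets] and (optionally) the
+    query span in _underscores_ into a SuperGLUE WSC-style JSON sample."""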
+    if '_' in sentence:
+        prefix, rest = sentence.split('_', 1)
+        query, rest = rest.split('_', 1)
+        query_index = len(prefix.rstrip().split(' '))
+    else:
+        query, query_index = None, None
+
+    prefix, rest = sentence.split('[', 1)
+    pronoun, rest = rest.split(']', 1)
+    pronoun_index = len(prefix.rstrip().split(' '))
+
+    sentence = sentence.replace('_', '').replace('[', '').replace(']', '')
+
+    return {
+        'idx': 0,
+        'text': sentence,
+        'target': {
+            'span1_index': query_index,
+            'span1_text': query,
+            'span2_index': pronoun_index,
+            'span2_text': pronoun,
+        },
+    }
+
+
+def extended_noun_chunks(sentence):
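+    """Return spaCy's noun chunks for *sentence*, augmented with all maximal
+    runs of consecutive NOUN/PROPN tokens."""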
+    noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
+    np_start, cur_np = 0, 'NONE'
+    for i, token in enumerate(sentence):
+        np_type = token.pos_ if token.pos_ in {'NOUN', 'PROPN'} else 'NONE'
+        if np_type != cur_np:
+            if cur_np != 'NONE':
+                noun_chunks.add((np_start, i))
+            if np_type != 'NONE':
+                np_start = i
+            cur_np = np_type
+    if cur_np != 'NONE':
+        noun_chunks.add((np_start, len(sentence)))
+    return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
+
+
+def find_token(sentence, start_pos):
+    found_tok = None
+    for tok in sentence:
+        if tok.idx == start_pos:
+            found_tok = tok
+            break
+    return found_tok
+
+
+def find_span(sentence, search_text, start=0):
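+    """Return the first span of *sentence* (searching from token offset
+    *start*) whose text case-insensitively matches *search_text*, or None."""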
+    search_text = search_text.lower()
+    for tok in sentence[start:]:
+        remainder = sentence[tok.i:].text.lower()
+        if remainder.startswith(search_text):
+            len_to_consume = len(search_text)
+            start_idx = tok.idx
+            for next_tok in sentence[tok.i:]:
+                end_idx = next_tok.idx + len(next_tok.text)
+                if end_idx - start_idx == len_to_consume:
+                    span = sentence[tok.i:next_tok.i + 1]
+                    return span
+    return None
+
+
+@lru_cache(maxsize=1)
+def get_detokenizer():
+    from sacremoses import MosesDetokenizer
+    detok = MosesDetokenizer(lang='en')
+    return detok
+
+
+@lru_cache(maxsize=1)
+def get_spacy_nlp():
+    import en_core_web_lg
+    nlp = en_core_web_lg.load()
+    return nlp
+
+
+def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
+    detok = get_detokenizer()
+    nlp = get_spacy_nlp()
+
+    with open(input_fname) as fin:
+        for line in fin:
+            sample = json.loads(line.strip())
+
+            if positive_only and 'label' in sample and not sample['label']:
+                # only consider examples where the query is correct
+                continue
+
+            target = sample['target']
+
+            # clean up the query
+            query = target['span1_text']
+            if query is not None:
+                if '\n' in query:
+                    continue
+                if query.endswith('.') or query.endswith(','):
+                    query = query[:-1]
+
+            # split tokens
+            tokens = sample['text'].split(' ')
+
+            def strip_pronoun(x):
+                return x.rstrip('.,"')
+
+            # find the pronoun
+            pronoun_idx = target['span2_index']
+            pronoun = strip_pronoun(target['span2_text'])
+            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
+                # hack: sometimes the index is misaligned
+                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
+                    pronoun_idx += 1
+                else:
+                    raise Exception('Misaligned pronoun!')
+            assert strip_pronoun(tokens[pronoun_idx]) == pronoun
+
+            # split tokens before and after the pronoun
+            before = tokens[:pronoun_idx]
+            after = tokens[pronoun_idx + 1:]
+
+            # the GPT BPE attaches leading spaces to tokens, so we keep track
+            # of whether we need spaces before or after the pronoun
+            leading_space = ' ' if pronoun_idx > 0 else ''
+            trailing_space = ' ' if len(after) > 0 else ''
+
+            # detokenize
+            before = detok.detokenize(before, return_str=True)
+            pronoun = detok.detokenize([pronoun], return_str=True)
+            after = detok.detokenize(after, return_str=True)
+
+            # hack: when the pronoun ends in a period (or comma), move the
+            # punctuation to the "after" part
+            if pronoun.endswith('.') or pronoun.endswith(','):
+                after = pronoun[-1] + trailing_space + after
+                pronoun = pronoun[:-1]
+
+            # hack: when the "after" part begins with a comma or period, remove
+            # the trailing space
+            if after.startswith('.') or after.startswith(','):
+                trailing_space = ''
+
+            # parse sentence with spacy
+            sentence = nlp(before + leading_space + pronoun + trailing_space + after)
+
+            # find pronoun span
+            start = len(before + leading_space)
+            first_pronoun_tok = find_token(sentence, start_pos=start)
+            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
+            assert pronoun_span.text == pronoun
+
+            if eval:
+                # convert to format where pronoun is surrounded by "[]" and
+                # query is surrounded by "_"
+                query_span = find_span(sentence, query)
+                query_with_ws = '_{}_{}'.format(
+                    query_span.text,
+                    (' ' if query_span.text_with_ws.endswith(' ') else '')
+                )
+                pronoun_with_ws = '[{}]{}'.format(
+                    pronoun_span.text,
+                    (' ' if pronoun_span.text_with_ws.endswith(' ') else '')
+                )
+                if query_span.start < pronoun_span.start:
+                    first = (query_span, query_with_ws)
+                    second = (pronoun_span, pronoun_with_ws)
+                else:
+                    first = (pronoun_span, pronoun_with_ws)
+                    second = (query_span, query_with_ws)
+                sentence = (
+                    sentence[:first[0].start].text_with_ws
+                    + first[1]
+                    + sentence[first[0].end:second[0].start].text_with_ws
+                    + second[1]
+                    + sentence[second[0].end:].text
+                )
+                yield sentence, sample.get('label', None)
+            else:
+                yield sentence, pronoun_span, query, sample.get('label', None)
+
+
+def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact_match=False):
+    if exclude_pronouns:
+        chunks = [
+            np for np in chunks if (
+                np.lemma_ != '-PRON-'
+                and not all(tok.pos_ == 'PRON' for tok in np)
+            )
+        ]
+
+    if exclude_query is not None:
+        excl_txt = [exclude_query.lower()]
+        filtered_chunks = []
+        for chunk in chunks:
+            lower_chunk = chunk.text.lower()
+            found = False
+            for excl in excl_txt:
+                if (
+                    (not exact_match and (lower_chunk in excl or excl in lower_chunk))
+                    or lower_chunk == excl
+                ):
+                    found = True
+                    break
+            if not found:
+                filtered_chunks.append(chunk)
+        chunks = filtered_chunks
+
+    return chunks
diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index 70cb948270..0e080c9ae3 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -345,6 +345,8 @@ def load_pretrained_component_from_model(
def verify_checkpoint_directory(save_dir: str) -> None:
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir, exist_ok=True)
    temp_file_path = os.path.join(save_dir, 'dummy')
    try:
        with open(temp_file_path, 'w'):
diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py
index d400a9b034..f97eaa9fab 100644
--- a/fairseq/data/__init__.py
+++ b/fairseq/data/__init__.py
@@ -16,6 +16,7 @@
from .id_dataset import IdDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
from .language_pair_dataset import LanguagePairDataset
+from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
@@ -59,6 +60,7 @@
    'IndexedRawTextDataset',
    'LanguagePairDataset',
    'LeftPadDataset',
+    'ListDataset',
    'LMContextWindowDataset',
    'LRUCacheDataset',
    'MaskTokensDataset',
diff --git a/fairseq/data/list_dataset.py b/fairseq/data/list_dataset.py
new file mode 100644
index 0000000000..f753727abf
--- /dev/null
+++ b/fairseq/data/list_dataset.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import BaseWrapperDataset
+
+
+class ListDataset(BaseWrapperDataset):
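+    """Dataset wrapper whose collater returns the list of samples as-is
+    (no padding or stacking); sizes are supplied explicitly."""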
+
+    def __init__(self, dataset, sizes):
+        super().__init__(dataset)
+        self._sizes = sizes
+
+    def collater(self, samples):
+        return samples
+
+    @property
+    def sizes(self):
+        return self._sizes
+
+    def num_tokens(self, index):
+        return self.sizes[index]
+
+    def size(self, index):
+        return self.sizes[index]
+
+    def set_epoch(self, epoch):
+        pass
diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py
index 297338d02e..a1d37cd25b 100644
--- a/fairseq/hub_utils.py
+++ b/fairseq/hub_utils.py
@@ -47,6 +47,9 @@ def from_pretrained(
        if os.path.exists(path):
            kwargs[arg] = path

+    if 'user_dir' in kwargs:
+        utils.import_user_module(argparse.Namespace(user_dir=kwargs['user_dir']))
+
    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
        arg_overrides=kwargs,
diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py
index 22ce96e89f..fc53384062 100644
--- a/fairseq/models/roberta/hub_interface.py
+++ b/fairseq/models/roberta/hub_interface.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from fairseq import utils
from fairseq.data import encoders
@@ -152,11 +153,12 @@ def fill_mask(self, masked_input: str, topk: int = 5):
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
-        features, extra = self.model(
-            tokens.long().to(device=self.device),
-            features_only=False,
-            return_all_hiddens=False,
-        )
+        with utils.eval(self.model):
+            features, extra = self.model(
+                tokens.long().to(device=self.device),
+                features_only=False,
+                return_all_hiddens=False,
+            )
        logits = features[0, masked_index, :].squeeze()
        prob = logits.softmax(dim=0)
        values, index = prob.topk(k=topk, dim=0)
@@ -178,3 +180,18 @@ def fill_mask(self, masked_input: str, topk: int = 5):
                values[index].item(),
            ))
        return topk_filled_outputs
+
+    def disambiguate_pronoun(self, sentence: str) -> bool:
+        """
+        Usage::
+
+            >>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
+            True
+
+            >>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
+            'The trophy'
+        """
+        assert hasattr(self.task, 'disambiguate_pronoun'), \
+            'roberta.disambiguate_pronoun() requires a model trained with the WSC task.'
+        with utils.eval(self.model):
+            return self.task.disambiguate_pronoun(self.model, sentence, use_cuda=self.device.type == 'cuda')
diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py
index eb7e03f764..8ae3f51f37 100644
--- a/fairseq/models/roberta/model.py
+++ b/fairseq/models/roberta/model.py
@@ -35,6 +35,7 @@ def hub_models(cls):
            'roberta.base': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz',
            'roberta.large': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz',
            'roberta.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz',
+            'roberta.large.wsc': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz',
        }

    def __init__(self, args, encoder):
diff --git a/fairseq/progress_bar.py b/fairseq/progress_bar.py
index a9bb5f97b4..59715f548d 100644
--- a/fairseq/progress_bar.py
+++ b/fairseq/progress_bar.py
@@ -14,8 +14,6 @@
import re
import sys
-from tqdm import tqdm
-
from fairseq import distributed_utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
@@ -208,6 +206,7 @@ class tqdm_progress_bar(progress_bar):
    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
+        from tqdm import tqdm
        self.tqdm = tqdm(iterable, self.prefix, leave=False)

    def __iter__(self):
diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py
index f1686258fb..240ec0a3b5 100644
--- a/fairseq/tasks/masked_lm.py
+++ b/fairseq/tasks/masked_lm.py
@@ -104,6 +104,7 @@ def load_dataset(self, split, epoch=0, combine=False):
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
+        print('| loaded {} blocks from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
@@ -210,14 +211,3 @@ def source_dictionary(self):
    @property
    def target_dictionary(self):
        return self.dictionary
-
-    def get_average_masked_score(self, model, src_tokens, mask, **net_input):
-        """Mask a set of tokens and return their average score."""
-        masked_tokens = src_tokens.clone()
-        masked_tokens[mask.byte()] = self.mask_idx
-        net_output = model(src_tokens=masked_tokens, **net_input, last_state_only=True)
-        lprobs = F.log_softmax(net_output[0], dim=-1, dtype=torch.float32)
-        lprobs = lprobs.gather(-1, src_tokens.unsqueeze(-1)).squeeze(-1)
-        mask = mask.type_as(lprobs)
-        score = (lprobs * mask).sum(dim=-1) / mask.sum(dim=-1)
-        return score
diff --git a/fairseq/tasks/translation_moe.py b/fairseq/tasks/translation_moe.py
index 35d44e47cb..cd8b985bb1 100644
--- a/fairseq/tasks/translation_moe.py
+++ b/fairseq/tasks/translation_moe.py
@@ -12,14 +12,6 @@
from fairseq.tasks.translation import TranslationTask
-@contextlib.contextmanager
-def eval(model):
-    is_training = model.training
-    model.eval()
-    yield
-    model.train(is_training)
-
-
@register_task('translation_moe')
class TranslationMoETask(TranslationTask):
"""
@@ -163,7 +155,7 @@ def get_lprob_yz(winners=None):
            return lprob_yz

        # compute responsibilities without dropout
-        with eval(model):  # disable dropout
+        with utils.eval(model):  # disable dropout
            with torch.no_grad():  # disable autograd
                lprob_yz = get_lprob_yz()  # B x K
        prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 76473837aa..4d9da12d62 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
+import contextlib
import copy
import importlib.util
import math
@@ -277,6 +278,10 @@ def import_user_module(args):
    module_path = getattr(args, 'user_dir', None)
    if module_path is not None:
        module_path = os.path.abspath(args.user_dir)
+        if not os.path.exists(module_path):
+            fairseq_rel_path = os.path.join(os.path.dirname(__file__), '..', args.user_dir)
+            if os.path.exists(fairseq_rel_path):
+                module_path = fairseq_rel_path

        module_parent, module_name = os.path.split(module_path)
        if module_name not in sys.modules:
@@ -339,3 +344,11 @@ def get_available_activation_fns() -> List:
        'tanh',
        'linear',
    ]
+
+
+@contextlib.contextmanager
+def eval(model):
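+    """Context manager that puts *model* in eval mode (disabling dropout) and
+    restores its previous training state on exit."""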
+    is_training = model.training
+    model.eval()
+    yield
+    model.train(is_training)
diff --git a/hubconf.py b/hubconf.py
index d8f252ad7b..34179c9dba 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -21,7 +21,7 @@
    for model_name in _cls.hub_models().keys():
        globals()[model_name] = functools.partial(
            _cls.from_pretrained,
-            model_name_or_path=model_name,
+            model_name,
        )
    # to simplify the interface we only expose named models
    # globals()[_model_type] = _cls.from_pretrained