diff --git a/examples/roberta/README.md b/examples/roberta/README.md
index 5f3be7941d..5b80fe94cf 100644
--- a/examples/roberta/README.md
+++ b/examples/roberta/README.md
@@ -12,7 +12,8 @@ Model | Description | # params | Download
 ---|---|---|---
 `roberta.base` | RoBERTa using the BERT-base architecture | 125M | [roberta.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz)
 `roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
-`roberta.large.mnli` | `roberta.large` finetuned on MNLI | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
+`roberta.large.mnli` | `roberta.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
+`roberta.large.wsc` | `roberta.large` finetuned on [WSC](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html) | 355M | [roberta.large.wsc.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz)
 
 ## Results
 
@@ -24,12 +25,12 @@ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
 `roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
 `roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -
-
 ##### Results on SuperGLUE tasks (dev set, single model, single-task finetuning)
 
 Model | BoolQ | CB | COPA | MultiRC | RTE | WiC | WSC
 ---|---|---|---|---|---|---|---
-`roberta.large` | 86.9 | 98.2 | 94.0 | 85.7 | 89.5 | 75.6 | 91.3
+`roberta.large` | 86.9 | 98.2 | 94.0 | 85.7 | 89.5 | 75.6 | -
+`roberta.large.wsc` | - | - | - | - | - | - | 91.3
 
 ##### Results on SQuAD (dev set)
 
@@ -83,28 +84,6 @@ assert len(all_layers) == 25
 assert torch.all(all_layers[-1] == last_layer_features)
 ```
 
-By default RoBERTa outputs one feature vector per BPE token. You can instead
-realign the features to match [spaCy's word-level tokenization](https://spacy.io/usage/linguistic-features#tokenization)
-with the `extract_features_aligned_to_words` method. This will compute a
-weighted average of the BPE-level features for each word and expose them in
-spaCy's `Token.vector` attribute:
-```python
-doc = roberta.extract_features_aligned_to_words('I said, "hello RoBERTa."')
-assert len(doc) == 10
-for tok in doc:
-    print('{:10}{} (...)'.format(str(tok), tok.vector[:5]))
-# <s>       tensor([-0.1316, -0.0386, -0.0832, -0.0477,  0.1943], grad_fn=<SliceBackward>) (...)
-# I         tensor([ 0.0559,  0.1541, -0.4832,  0.0880,  0.0120], grad_fn=<SliceBackward>) (...)
-# said      tensor([-0.1565, -0.0069, -0.8915,  0.0501, -0.0647], grad_fn=<SliceBackward>) (...)
-# ,         tensor([-0.1318, -0.0387, -0.0834, -0.0477,  0.1944], grad_fn=<SliceBackward>) (...)
-# "         tensor([-0.0486,  0.1818, -0.3946, -0.0553,  0.0981], grad_fn=<SliceBackward>) (...)
-# hello     tensor([ 0.0079,  0.1799, -0.6204, -0.0777, -0.0923], grad_fn=<SliceBackward>) (...)
-# RoBERTa   tensor([-0.2339, -0.1184, -0.7343, -0.0492,  0.5829], grad_fn=<SliceBackward>) (...)
-# .         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
-# "         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
-# </s>      tensor([-0.0930, -0.0392, -0.0821,  0.0158,  0.0649], grad_fn=<SliceBackward>) (...)
-```
-
 ##### Use RoBERTa for sentence-pair classification tasks:
 ```python
 # Download RoBERTa already finetuned for MNLI
@@ -141,22 +120,79 @@ roberta.cuda()
 roberta.predict('new_task', tokens)
 # tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
 ```
 
-##### Filling mask:
-Some examples from the [Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/).
+## Advanced usage
+
+#### Filling masks:
+
+RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the
+[Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/):
 ```python
->>> roberta.fill_mask("The first Star wars movie came out in <mask>", topk=3)
-[('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)]
+roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3)
+# [('The first Star wars movie came out in 1977', 0.9504712224006653), ('The first Star wars movie came out in 1978', 0.009986752644181252), ('The first Star wars movie came out in 1979', 0.00957468245178461)]
+
+roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3)
+# [('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]
+
+roberta.fill_mask('<mask> is the common currency of the European Union', topk=3)
+# [('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)]
+```
 
->>> roberta.fill_mask("Vikram samvat calender is official in <mask>", topk=3)
-[('Vikram samvat calender is official in India', 0.21878768503665924), ('Vikram samvat calender is official in Delhi', 0.08547217398881912), ('Vikram samvat calender is official in Gujarat', 0.07556255906820297)]
+#### Pronoun disambiguation (Winograd Schema Challenge):
 
->>> roberta.fill_mask("<mask> is the common currency of the European Union", topk=3)
-[('Euro is the common currency of the European Union', 0.945650577545166), ('euro is the common currency of the European Union', 0.025747718289494514), ('€ is the common currency of the European Union', 0.011183015070855618)]
+RoBERTa can be used to disambiguate pronouns. First install spaCy and download the English-language model:
+```bash
+pip install spacy
+python -m spacy download en_core_web_lg
+```
+
+Next load the `roberta.large.wsc` model and call the `disambiguate_pronoun`
+function. The pronoun should be surrounded by square brackets (`[]`) and the
+query referent surrounded by underscores (`_`), or left blank to return the
+predicted candidate text directly:
+```python
+roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.wsc', user_dir='examples/roberta/wsc')
+roberta.cuda()  # use the GPU (optional)
+
+roberta.disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
+# True
+roberta.disambiguate_pronoun('The trophy would not fit in the brown _suitcase_ because [it] was too big.')
+# False
+
+roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] feared violence.')
+# 'The city councilmen'
+roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] advocated violence.')
+# 'demonstrators'
+```
+
+See the [RoBERTa Winograd Schema Challenge (WSC) README](README.wsc.md) for more details on how to train this model.
+
+#### Extract features aligned to words:
+
+By default RoBERTa outputs one feature vector per BPE token. You can instead
+realign the features to match [spaCy's word-level tokenization](https://spacy.io/usage/linguistic-features#tokenization)
+with the `extract_features_aligned_to_words` method. This will compute a
+weighted average of the BPE-level features for each word and expose them in
+spaCy's `Token.vector` attribute:
+```python
+doc = roberta.extract_features_aligned_to_words('I said, "hello RoBERTa."')
+assert len(doc) == 10
+for tok in doc:
+    print('{:10}{} (...)'.format(str(tok), tok.vector[:5]))
+# <s>       tensor([-0.1316, -0.0386, -0.0832, -0.0477,  0.1943], grad_fn=<SliceBackward>) (...)
+# I         tensor([ 0.0559,  0.1541, -0.4832,  0.0880,  0.0120], grad_fn=<SliceBackward>) (...)
+# said      tensor([-0.1565, -0.0069, -0.8915,  0.0501, -0.0647], grad_fn=<SliceBackward>) (...)
+# ,         tensor([-0.1318, -0.0387, -0.0834, -0.0477,  0.1944], grad_fn=<SliceBackward>) (...)
+# "         tensor([-0.0486,  0.1818, -0.3946, -0.0553,  0.0981], grad_fn=<SliceBackward>) (...)
+# hello     tensor([ 0.0079,  0.1799, -0.6204, -0.0777, -0.0923], grad_fn=<SliceBackward>) (...)
+# RoBERTa   tensor([-0.2339, -0.1184, -0.7343, -0.0492,  0.5829], grad_fn=<SliceBackward>) (...)
+# .         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
+# "         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
+# </s>      tensor([-0.0930, -0.0392, -0.0821,  0.0158,  0.0649], grad_fn=<SliceBackward>) (...)
 ```
 
-##### Evaluating the `roberta.large.mnli` model
+#### Evaluating the `roberta.large.mnli` model:
 
-Example python code snippet to evaluate accuracy on the MNLI dev_matched set.
+Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
 ```python
 label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
 ncorrect, nsamples = 0, 0
@@ -181,6 +217,7 @@ print('| Accuracy: ', float(ncorrect)/float(nsamples))
 
 - [Finetuning on GLUE](README.finetune_glue.md)
 - [Finetuning on custom classification tasks (e.g., IMDB)](README.finetune_custom_classification.md)
+- [Finetuning on Winograd Schema Challenge (WSC)](README.wsc.md)
 - Finetuning on SQuAD: coming soon
 
 ## Pretraining using your own data
diff --git a/examples/roberta/README.wsc.md b/examples/roberta/README.wsc.md
new file mode 100644
index 0000000000..b1437d1de7
--- /dev/null
+++ b/examples/roberta/README.wsc.md
@@ -0,0 +1,83 @@
+# Finetuning RoBERTa on Winograd Schema Challenge (WSC) data
+
+The following instructions can be used to finetune RoBERTa on the WSC training
+data provided by [SuperGLUE](https://super.gluebenchmark.com/).
+
+Note that there is high variance in the results. For our GLUE/SuperGLUE
+submission we swept over the learning rate, batch size and total number of
+updates, as well as the random seed. Out of ~100 runs we chose the best 7 models
+and ensembled them.
+
+**Note:** The instructions below use a slightly different loss function than
+what's described in the original RoBERTa arXiv paper. In particular,
+[Kocijan et al. (2019)](https://arxiv.org/abs/1905.06290) introduce a margin
+ranking loss between `(query, candidate)` pairs with tunable hyperparameters
+alpha and beta. This is supported in our code as well with the
+`--wsc-margin-alpha` and `--wsc-margin-beta` arguments. However, we achieved
+slightly better (and more robust) results on the development set by instead
+using a single cross entropy loss term over the log-probabilities for the
+query and all candidates (see the sketch after the training command below).
+This reduces the number of hyperparameters and our best model achieved 92.3%
+development set accuracy, compared to ~90% accuracy for the margin loss. Later
+versions of the RoBERTa arXiv paper will describe this updated formulation.
+
+### 1) Download the WSC data from the SuperGLUE website:
+```bash
+wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/WSC.zip
+unzip WSC.zip
+
+# we also need to copy the RoBERTa dictionary into the same directory
+wget -O WSC/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
+```
+
+### 2) Finetune over the provided training data:
+```bash
+TOTAL_NUM_UPDATES=2000  # Total number of training steps.
+WARMUP_UPDATES=250      # Linearly increase LR over this many steps.
+LR=2e-05                # Peak LR for polynomial LR scheduler.
+MAX_SENTENCES=16        # Batch size per GPU.
+SEED=1                  # Random seed.
+ROBERTA_PATH=/path/to/roberta/model.pt
+
+# we use the --user-dir option to load the task and criterion
+# from the examples/roberta/wsc directory:
+FAIRSEQ_PATH=/path/to/fairseq
+FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc
+
+cd fairseq
+CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
+    --restore-file $ROBERTA_PATH \
+    --reset-optimizer --reset-dataloader --reset-meters \
+    --no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
+    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
+    --valid-subset val \
+    --fp16 --ddp-backend no_c10d \
+    --user-dir $FAIRSEQ_USER_DIR \
+    --task wsc --criterion wsc --wsc-cross-entropy \
+    --arch roberta_large --bpe gpt2 --max-positions 512 \
+    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
+    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
+    --lr-scheduler polynomial_decay --lr $LR \
+    --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
+    --max-sentences $MAX_SENTENCES \
+    --max-update $TOTAL_NUM_UPDATES \
+    --seed $SEED \
+    --log-format simple --log-interval 100
+```
+
+The above command assumes training on 4 GPUs, but you can achieve the same
+results on a single GPU by adding `--update-freq=4`.
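+
+For illustration, here is a minimal sketch of the cross entropy formulation
+mentioned in the note above (this is what the `--wsc-cross-entropy` flag
+selects in `wsc_criterion.py`; the log-probability values below are
+hypothetical):
+```python
+import torch
+import torch.nn.functional as F
+
+# average masked log-probs: the query span first, then the mined candidates
+query_lprobs = torch.tensor([-0.2])       # hypothetical value
+cand_lprobs = torch.tensor([-0.9, -1.4])  # hypothetical values
+
+# treat them as logits of a single softmax whose target is the query (index 0)
+logits = torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0)
+loss = F.cross_entropy(logits, torch.tensor([0]))
+```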
+
+### 3) Evaluate
+```python
+from fairseq.models.roberta import RobertaModel
+from examples.roberta.wsc import wsc_utils  # also loads WSC task and criterion
+roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'WSC/')
+roberta.cuda()
+nsamples, ncorrect = 0, 0
+for sentence, label in wsc_utils.jsonl_iterator('WSC/val.jsonl', eval=True):
+    pred = roberta.disambiguate_pronoun(sentence)
+    nsamples += 1
+    if pred == label:
+        ncorrect += 1
+print('Accuracy: ' + str(ncorrect / float(nsamples)))
+# Accuracy: 0.9230769230769231
+```
diff --git a/examples/roberta/wsc/__init__.py b/examples/roberta/wsc/__init__.py
new file mode 100644
index 0000000000..78afa4728e
--- /dev/null
+++ b/examples/roberta/wsc/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import wsc_criterion  # noqa
+from . import wsc_task  # noqa
diff --git a/examples/roberta/wsc/wsc_criterion.py b/examples/roberta/wsc/wsc_criterion.py
new file mode 100644
index 0000000000..c5b6507f9a
--- /dev/null
+++ b/examples/roberta/wsc/wsc_criterion.py
@@ -0,0 +1,131 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+from fairseq import utils
+from fairseq.data import encoders
+from fairseq.criterions import FairseqCriterion, register_criterion
+
+
+@register_criterion('wsc')
+class WSCCriterion(FairseqCriterion):
+
+    def __init__(self, args, task):
+        super().__init__(args, task)
+        if self.args.save_predictions is not None:
+            self.prediction_h = open(self.args.save_predictions, 'w')
+        else:
+            self.prediction_h = None
+        self.bpe = encoders.build_bpe(args)
+        self.tokenizer = encoders.build_tokenizer(args)
+
+    def __del__(self):
+        if self.prediction_h is not None:
+            self.prediction_h.close()
+
+    @staticmethod
+    def add_args(parser):
+        """Add criterion-specific arguments to the parser."""
+        parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0)
+        parser.add_argument('--wsc-margin-beta', type=float, metavar='B', default=0.0)
+        parser.add_argument('--wsc-cross-entropy', action='store_true',
+                            help='use cross entropy formulation instead of margin loss')
+        parser.add_argument('--save-predictions', metavar='FILE',
+                            help='file to save predictions to')
+
+    def forward(self, model, sample, reduce=True):
+
+        def get_masked_input(tokens, mask):
+            masked_tokens = tokens.clone()
+            masked_tokens[mask] = self.task.mask
+            return masked_tokens
+
+        def get_lprobs(tokens, mask):
+            logits, _ = model(src_tokens=get_masked_input(tokens, mask))
+            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
+            scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
+            mask = mask.type_as(scores)
+            scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
+            return scores
+
+        # compute loss and accuracy
+        loss, nloss = 0., 0
+        ncorrect, nqueries = 0, 0
+        for i, label in enumerate(sample['labels']):
+            query_lprobs = get_lprobs(
+                sample['query_tokens'][i].unsqueeze(0),
+                sample['query_masks'][i].unsqueeze(0),
+            )
+            cand_lprobs = get_lprobs(
+                sample['candidate_tokens'][i],
+                sample['candidate_masks'][i],
+            )
+
+            pred = (query_lprobs >= cand_lprobs).all().item()
+
+            if label is not None:
+                label = 1 if label else 0
+                ncorrect += 1 if pred == label else 0
+                nqueries += 1
+
+                if label:
+                    # only compute a loss for positive instances
+                    nloss += 1
+                    if self.args.wsc_cross_entropy:
+                        loss += F.cross_entropy(
+                            torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
+                            query_lprobs.new([0]).long(),
+                        )
+                    else:
+                        loss += (
+                            - query_lprobs
+                            + self.args.wsc_margin_alpha * (
+                                cand_lprobs - query_lprobs + self.args.wsc_margin_beta
+                            ).clamp(min=0)
+                        ).sum()
+
+            id = sample['id'][i].item()
+            if self.prediction_h is not None:
+                print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h)
+
+        if nloss == 0:
+            loss = torch.tensor(0.0, requires_grad=True)
+
+        sample_size = nqueries if nqueries > 0 else 1
+        logging_output = {
+            'loss': utils.item(loss.data) if reduce else loss.data,
+            'ntokens': sample['ntokens'],
+            'nsentences': sample['nsentences'],
+            'sample_size': sample_size,
+            'ncorrect': ncorrect,
+            'nqueries': nqueries,
+        }
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
+        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
+        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+
+        agg_output = {
+            'loss': loss_sum / sample_size / math.log(2),
+            'ntokens': ntokens,
+            'nsentences': nsentences,
+            'sample_size': sample_size,
+        }
+
+        ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
+        nqueries = sum(log.get('nqueries', 0) for log in logging_outputs)
+        if nqueries > 0:
+            agg_output['accuracy'] = ncorrect / float(nqueries)
+
+        return agg_output
diff --git a/examples/roberta/wsc/wsc_task.py b/examples/roberta/wsc/wsc_task.py
new file mode 100644
index 0000000000..7fd09fc77c
--- /dev/null
+++ b/examples/roberta/wsc/wsc_task.py
@@ -0,0 +1,260 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import tempfile
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from fairseq import utils
+from fairseq.data import (
+    data_utils,
+    Dictionary,
+    encoders,
+    IdDataset,
+    ListDataset,
+    NestedDictionaryDataset,
+    NumSamplesDataset,
+    NumelDataset,
+    SortDataset,
+)
+from fairseq.tasks import FairseqTask, register_task
+
+from . import wsc_utils
+
+
+@register_task('wsc')
+class WSCTask(FairseqTask):
+    """Task to finetune RoBERTa for Winograd Schemas."""
+
+    @staticmethod
+    def add_args(parser):
+        """Add task-specific arguments to the parser."""
+        parser.add_argument('data', metavar='DIR',
+                            help='path to data directory; we load <split>.jsonl')
+        parser.add_argument('--init-token', type=int, default=None,
+                            help='add token at the beginning of each batch item')
+
+    def __init__(self, args, vocab):
+        super().__init__(args)
+        self.vocab = vocab
+        self.mask = vocab.add_symbol('<mask>')
+
+        self.bpe = encoders.build_bpe(args)
+        self.tokenizer = encoders.build_tokenizer(args)
+
+        # hack to handle GPT-2 BPE, which includes leading spaces
+        if args.bpe == 'gpt2':
+            self.leading_space = True
+            self.trailing_space = False
+        else:
+            self.leading_space = False
+            self.trailing_space = True
+
+    @classmethod
+    def load_dictionary(cls, filename):
+        """Load the dictionary from the filename
+
+        Args:
+            filename (str): the filename
+        """
+        dictionary = Dictionary.load(filename)
+        dictionary.add_symbol('<mask>')
+        return dictionary
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        assert args.criterion == 'wsc', 'Must set --criterion=wsc'
+
+        # load data and label dictionaries
+        vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
+        print('| dictionary: {} types'.format(len(vocab)))
+
+        return cls(args, vocab)
+
+    def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
+        """Load a given dataset split.
+
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+
+        def binarize(s: str, append_eos: bool = False):
+            if self.tokenizer is not None:
+                s = self.tokenizer.encode(s)
+            if self.bpe is not None:
+                s = self.bpe.encode(s)
+            tokens = self.vocab.encode_line(
+                s, append_eos=append_eos, add_if_not_exist=False,
+            ).long()
+            if self.args.init_token is not None:
+                tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
+            return tokens
+
+        if data_path is None:
+            data_path = os.path.join(self.args.data, split + '.jsonl')
+        if not os.path.exists(data_path):
+            raise FileNotFoundError('Cannot find data: {}'.format(data_path))
+
+        query_tokens = []
+        query_masks = []
+        query_lengths = []
+        candidate_tokens = []
+        candidate_masks = []
+        candidate_lengths = []
+        labels = []
+
+        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
+            prefix = sentence[:pronoun_span.start].text
+            suffix = sentence[pronoun_span.end:].text_with_ws
+
+            # spaCy spans include trailing spaces, but we need to know about
+            # leading spaces for the GPT-2 BPE
+            leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
+            trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''
+
+            # get noun phrases, excluding pronouns and anything overlapping with the query
+            cand_spans = wsc_utils.filter_noun_chunks(
+                wsc_utils.extended_noun_chunks(sentence),
+                exclude_pronouns=True,
+                exclude_query=query,
+                exact_match=False,
+            )
+
+            def binarize_with_mask(txt):
+                toks = binarize(
+                    prefix + leading_space + txt + trailing_space + suffix,
+                    append_eos=True,
+                )
+                mask = torch.zeros_like(toks, dtype=torch.uint8)
+                mask_start = len(binarize(prefix))
+                mask_size = len(binarize(leading_space + txt))
+                mask[mask_start:mask_start + mask_size] = 1
+                return toks, mask
+
+            if query is not None:
+                query_toks, query_mask = binarize_with_mask(query)
+                query_len = len(query_toks)
+            else:
+                query_toks, query_mask, query_len = None, None, 0
+
+            query_tokens.append(query_toks)
+            query_masks.append(query_mask)
+            query_lengths.append(query_len)
+
+            cand_toks, cand_masks = [], []
+            for cand_span in cand_spans:
+                toks, mask = binarize_with_mask(cand_span.text)
+                cand_toks.append(toks)
+                cand_masks.append(mask)
+
+            # collate candidates
+            cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
+            cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
+            assert cand_toks.size() == cand_masks.size()
+
+            candidate_tokens.append(cand_toks)
+            candidate_masks.append(cand_masks)
+            candidate_lengths.append(cand_toks.size(1))
+
+            labels.append(label)
+
+        query_lengths = np.array(query_lengths)
+        query_tokens = ListDataset(query_tokens, query_lengths)
+        query_masks = ListDataset(query_masks, query_lengths)
+
+        candidate_lengths = np.array(candidate_lengths)
+        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
+        candidate_masks = ListDataset(candidate_masks, candidate_lengths)
+
+        labels = ListDataset(labels, [1] * len(labels))
+
+        dataset = {
+            'id': IdDataset(),
+            'query_tokens': query_tokens,
+            'query_masks': query_masks,
+            'candidate_tokens': candidate_tokens,
+            'candidate_masks': candidate_masks,
+            'labels': labels,
+            'nsentences': NumSamplesDataset(),
+            'ntokens': NumelDataset(query_tokens, reduce=True),
+        }
+
+        nested_dataset = NestedDictionaryDataset(
+            dataset,
+            sizes=[query_lengths],
+        )
+
+        with data_utils.numpy_seed(self.args.seed):
+            shuffle = np.random.permutation(len(query_tokens))
+        dataset = SortDataset(
+            nested_dataset,
+            # shuffle
+            sort_order=[shuffle],
+        )
+
+        if return_only:
+            return dataset
+
+        self.datasets[split] = dataset
+        return self.datasets[split]
+
+    def build_dataset_for_inference(self, sample_json):
+        with tempfile.NamedTemporaryFile(buffering=0) as h:
+            h.write((json.dumps(sample_json) + '\n').encode('utf-8'))
+            dataset = self.load_dataset(
+                'disambiguate_pronoun',
+                data_path=h.name,
+                return_only=True,
+            )
+        return dataset
+
+    def disambiguate_pronoun(self, model, sentence, use_cuda=False):
+        sample_json = wsc_utils.convert_sentence_to_json(sentence)
+        dataset = self.build_dataset_for_inference(sample_json)
+        sample = dataset.collater([dataset[0]])
+        if use_cuda:
+            sample = utils.move_to_cuda(sample)
+
+        def get_masked_input(tokens, mask):
+            masked_tokens = tokens.clone()
+            masked_tokens[mask] = self.mask
+            return masked_tokens
+
+        def get_lprobs(tokens, mask):
+            logits, _ = model(src_tokens=get_masked_input(tokens, mask))
+            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
+            scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
+            mask = mask.type_as(scores)
+            scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
+            return scores
+
+        cand_lprobs = get_lprobs(
+            sample['candidate_tokens'][0],
+            sample['candidate_masks'][0],
+        )
+        if sample['query_tokens'][0] is not None:
+            query_lprobs = get_lprobs(
+                sample['query_tokens'][0].unsqueeze(0),
+                sample['query_masks'][0].unsqueeze(0),
+            )
+            return (query_lprobs >= cand_lprobs).all().item() == 1
+        else:
+            best_idx = cand_lprobs.argmax().item()
+            full_cand = sample['candidate_tokens'][0][best_idx]
+            mask = sample['candidate_masks'][0][best_idx]
+            toks = full_cand[mask]
+            return self.bpe.decode(self.source_dictionary.string(toks)).strip()
+
+    @property
+    def source_dictionary(self):
+        return self.vocab
+
+    @property
+    def target_dictionary(self):
+        return self.vocab
diff --git a/examples/roberta/wsc/wsc_utils.py b/examples/roberta/wsc/wsc_utils.py
new file mode 100644
index 0000000000..ef388665fd
--- /dev/null
+++ b/examples/roberta/wsc/wsc_utils.py
@@ -0,0 +1,219 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import lru_cache
+import json
+
+
+def convert_sentence_to_json(sentence):
+    if '_' in sentence:
+        prefix, rest = sentence.split('_', 1)
+        query, rest = rest.split('_', 1)
+        query_index = len(prefix.rstrip().split(' '))
+    else:
+        query, query_index = None, None
+
+    prefix, rest = sentence.split('[', 1)
+    pronoun, rest = rest.split(']', 1)
+    pronoun_index = len(prefix.rstrip().split(' '))
+
+    sentence = sentence.replace('_', '').replace('[', '').replace(']', '')
+
+    return {
+        'idx': 0,
+        'text': sentence,
+        'target': {
+            'span1_index': query_index,
+            'span1_text': query,
+            'span2_index': pronoun_index,
+            'span2_text': pronoun,
+        },
+    }
+
+
+def extended_noun_chunks(sentence):
+    noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
+    np_start, cur_np = 0, 'NONE'
+    for i, token in enumerate(sentence):
+        np_type = token.pos_ if token.pos_ in {'NOUN', 'PROPN'} else 'NONE'
+        if np_type != cur_np:
+            if cur_np != 'NONE':
+                noun_chunks.add((np_start, i))
+            if np_type != 'NONE':
+                np_start = i
+            cur_np = np_type
+    if cur_np != 'NONE':
+        noun_chunks.add((np_start, len(sentence)))
+    return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
+
+
+def find_token(sentence, start_pos):
+    found_tok = None
+    for tok in sentence:
+        if tok.idx == start_pos:
+            found_tok = tok
+            break
+    return found_tok
+
+
+def find_span(sentence, search_text, start=0):
+    search_text = search_text.lower()
+    for tok in sentence[start:]:
+        remainder = sentence[tok.i:].text.lower()
+        if remainder.startswith(search_text):
+            len_to_consume = len(search_text)
+            start_idx = tok.idx
+            for next_tok in sentence[tok.i:]:
+                end_idx = next_tok.idx + len(next_tok.text)
+                if end_idx - start_idx == len_to_consume:
+                    span = sentence[tok.i:next_tok.i + 1]
+                    return span
+    return None
+
+
+@lru_cache(maxsize=1)
+def get_detokenizer():
+    from sacremoses import MosesDetokenizer
+    detok = MosesDetokenizer(lang='en')
+    return detok
+
+
+@lru_cache(maxsize=1)
+def get_spacy_nlp():
+    import en_core_web_lg
+    nlp = en_core_web_lg.load()
+    return nlp
+
+
+def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
+    detok = get_detokenizer()
+    nlp = get_spacy_nlp()
+
+    with open(input_fname) as fin:
+        for line in fin:
+            sample = json.loads(line.strip())
+
+            if positive_only and 'label' in sample and not sample['label']:
+                # only consider examples where the query is correct
+                continue
+
+            target = sample['target']
+
+            # clean up the query
+            query = target['span1_text']
+            if query is not None:
+                if '\n' in query:
+                    continue
+                if query.endswith('.') or query.endswith(','):
+                    query = query[:-1]
+
+            # split tokens
+            tokens = sample['text'].split(' ')
+
+            def strip_pronoun(x):
+                return x.rstrip('.,"')
+
+            # find the pronoun
+            pronoun_idx = target['span2_index']
+            pronoun = strip_pronoun(target['span2_text'])
+            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
+                # hack: sometimes the index is misaligned
+                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
+                    pronoun_idx += 1
+                else:
+                    raise Exception('Misaligned pronoun!')
+            assert strip_pronoun(tokens[pronoun_idx]) == pronoun
+
+            # split tokens before and after the pronoun
+            before = tokens[:pronoun_idx]
+            after = tokens[pronoun_idx + 1:]
+
+            # the GPT BPE attaches leading spaces to tokens, so we keep track
+            # of whether we need spaces before or after the pronoun
+            leading_space = ' ' if pronoun_idx > 0 else ''
+            trailing_space = ' ' if len(after) > 0 else ''
+
+            # detokenize
+            before = detok.detokenize(before, return_str=True)
+            pronoun = detok.detokenize([pronoun], return_str=True)
+            after = detok.detokenize(after, return_str=True)
+
+            # hack: when the pronoun ends in a period (or comma), move the
+            # punctuation to the "after" part
+            if pronoun.endswith('.') or pronoun.endswith(','):
+                after = pronoun[-1] + trailing_space + after
+                pronoun = pronoun[:-1]
+
+            # hack: when the "after" part begins with a comma or period, remove
+            # the trailing space
+            if after.startswith('.') or after.startswith(','):
+                trailing_space = ''
+
+            # parse sentence with spacy
+            sentence = nlp(before + leading_space + pronoun + trailing_space + after)
+
+            # find pronoun span
+            start = len(before + leading_space)
+            first_pronoun_tok = find_token(sentence, start_pos=start)
+            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
+            assert pronoun_span.text == pronoun
+
+            if eval:
+                # convert to format where pronoun is surrounded by "[]" and
+                # query is surrounded by "_"
+                query_span = find_span(sentence, query)
+                query_with_ws = '_{}_{}'.format(
+                    query_span.text,
+                    (' ' if query_span.text_with_ws.endswith(' ') else '')
+                )
+                pronoun_with_ws = '[{}]{}'.format(
+                    pronoun_span.text,
+                    (' ' if pronoun_span.text_with_ws.endswith(' ') else '')
+                )
+                if query_span.start < pronoun_span.start:
+                    first = (query_span, query_with_ws)
+                    second = (pronoun_span, pronoun_with_ws)
+                else:
+                    first = (pronoun_span, pronoun_with_ws)
+                    second = (query_span, query_with_ws)
+                sentence = (
+                    sentence[:first[0].start].text_with_ws
+                    + first[1]
+                    + sentence[first[0].end:second[0].start].text_with_ws
+                    + second[1]
+                    + sentence[second[0].end:].text
+                )
+                yield sentence, sample.get('label', None)
+            else:
+                yield sentence, pronoun_span, query, sample.get('label', None)
+
+
+def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact_match=False):
+    if exclude_pronouns:
+        chunks = [
+            np for np in chunks if (
+                np.lemma_ != '-PRON-'
+                and not all(tok.pos_ == 'PRON' for tok in np)
+            )
+        ]
+
+    if exclude_query is not None:
+        excl_txt = [exclude_query.lower()]
+        filtered_chunks = []
+        for chunk in chunks:
+            lower_chunk = chunk.text.lower()
+            found = False
+            for excl in excl_txt:
+                if (
+                    (not exact_match and (lower_chunk in excl or excl in lower_chunk))
+                    or lower_chunk == excl
+                ):
+                    found = True
+                    break
+            if not found:
+                filtered_chunks.append(chunk)
+        chunks = filtered_chunks
+
+    return chunks
diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index 70cb948270..0e080c9ae3 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -345,6 +345,8 @@ def load_pretrained_component_from_model(
 
 
 def verify_checkpoint_directory(save_dir: str) -> None:
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir, exist_ok=True)
     temp_file_path = os.path.join(save_dir, 'dummy')
     try:
         with open(temp_file_path, 'w'):
diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py
index d400a9b034..f97eaa9fab 100644
--- a/fairseq/data/__init__.py
+++ b/fairseq/data/__init__.py
@@ -16,6 +16,7 @@ from .id_dataset import IdDataset
 from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
 from .language_pair_dataset import LanguagePairDataset
+from .list_dataset import ListDataset
 from .lm_context_window_dataset import LMContextWindowDataset
 from .lru_cache_dataset import LRUCacheDataset
 from .mask_tokens_dataset import MaskTokensDataset
@@ -59,6 +60,7 @@
     'IndexedRawTextDataset',
     'LanguagePairDataset',
     'LeftPadDataset',
+    'ListDataset',
     'LMContextWindowDataset',
     'LRUCacheDataset',
     'MaskTokensDataset',
diff --git a/fairseq/data/list_dataset.py b/fairseq/data/list_dataset.py
new file mode 100644
index 0000000000..f753727abf
--- /dev/null
+++ b/fairseq/data/list_dataset.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import BaseWrapperDataset
+
+
+class ListDataset(BaseWrapperDataset):
+
+    def __init__(self, dataset, sizes):
+        super().__init__(dataset)
+        self._sizes = sizes
+
+    def collater(self, samples):
+        return samples
+
+    @property
+    def sizes(self):
+        return self._sizes
+
+    def num_tokens(self, index):
+        return self.sizes[index]
+
+    def size(self, index):
+        return self.sizes[index]
+
+    def set_epoch(self, epoch):
+        pass
diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py
index 297338d02e..a1d37cd25b 100644
--- a/fairseq/hub_utils.py
+++ b/fairseq/hub_utils.py
@@ -47,6 +47,9 @@ def from_pretrained(
         if os.path.exists(path):
             kwargs[arg] = path
 
+    if 'user_dir' in kwargs:
+        utils.import_user_module(argparse.Namespace(user_dir=kwargs['user_dir']))
+
     models, args, task = checkpoint_utils.load_model_ensemble_and_task(
         [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
         arg_overrides=kwargs,
diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py
index 22ce96e89f..fc53384062 100644
--- a/fairseq/models/roberta/hub_interface.py
+++ b/fairseq/models/roberta/hub_interface.py
@@ -10,6 +10,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
+from fairseq import utils
 from fairseq.data import encoders
 
@@ -152,11 +153,12 @@ def fill_mask(self, masked_input: str, topk: int = 5):
         if tokens.dim() == 1:
             tokens = tokens.unsqueeze(0)
 
-        features, extra = self.model(
-            tokens.long().to(device=self.device),
-            features_only=False,
-            return_all_hiddens=False,
-        )
+        with utils.eval(self.model):
+            features, extra = self.model(
+                tokens.long().to(device=self.device),
+                features_only=False,
+                return_all_hiddens=False,
+            )
         logits = features[0, masked_index, :].squeeze()
         prob = logits.softmax(dim=0)
         values, index = prob.topk(k=topk, dim=0)
@@ -178,3 +180,18 @@ def fill_mask(self, masked_input: str, topk: int = 5):
                 values[index].item(),
             ))
         return topk_filled_outputs
+
+    def disambiguate_pronoun(self, sentence: str) -> bool:
+        """
+        Usage::
+
+            >>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
+            True
+
+            >>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
+            'The trophy'
+        """
+        assert hasattr(self.task, 'disambiguate_pronoun'), \
+            'roberta.disambiguate_pronoun() requires a model trained with the WSC task.'
+        with utils.eval(self.model):
+            return self.task.disambiguate_pronoun(self.model, sentence, use_cuda=self.device.type == 'cuda')
diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py
index eb7e03f764..8ae3f51f37 100644
--- a/fairseq/models/roberta/model.py
+++ b/fairseq/models/roberta/model.py
@@ -35,6 +35,7 @@ def hub_models(cls):
             'roberta.base': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz',
             'roberta.large': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz',
             'roberta.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz',
+            'roberta.large.wsc': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz',
         }
 
     def __init__(self, args, encoder):
diff --git a/fairseq/progress_bar.py b/fairseq/progress_bar.py
index a9bb5f97b4..59715f548d 100644
--- a/fairseq/progress_bar.py
+++ b/fairseq/progress_bar.py
@@ -14,8 +14,6 @@ import re
 import sys
 
-from tqdm import tqdm
-
 from fairseq import distributed_utils
 from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 
@@ -208,6 +206,7 @@ class tqdm_progress_bar(progress_bar):
 
     def __init__(self, iterable, epoch=None, prefix=None):
         super().__init__(iterable, epoch, prefix)
+        from tqdm import tqdm
         self.tqdm = tqdm(iterable, self.prefix, leave=False)
 
     def __iter__(self):
diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py
index f1686258fb..240ec0a3b5 100644
--- a/fairseq/tasks/masked_lm.py
+++ b/fairseq/tasks/masked_lm.py
@@ -104,6 +104,7 @@ def load_dataset(self, split, epoch=0, combine=False):
             eos=self.source_dictionary.eos(),
             break_mode=self.args.sample_break_mode,
         )
+        print('| loaded {} batches from: {}'.format(len(dataset), split_path))
 
         # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
         dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
@@ -210,14 +211,3 @@ def source_dictionary(self):
     @property
     def target_dictionary(self):
         return self.dictionary
-
-    def get_average_masked_score(self, model, src_tokens, mask, **net_input):
-        """Mask a set of tokens and return their average score."""
-        masked_tokens = src_tokens.clone()
-        masked_tokens[mask.byte()] = self.mask_idx
-        net_output = model(src_tokens=masked_tokens, **net_input, last_state_only=True)
-        lprobs = F.log_softmax(net_output[0], dim=-1, dtype=torch.float32)
-        lprobs = lprobs.gather(-1, src_tokens.unsqueeze(-1)).squeeze(-1)
-        mask = mask.type_as(lprobs)
-        score = (lprobs * mask).sum(dim=-1) / mask.sum(dim=-1)
-        return score
diff --git a/fairseq/tasks/translation_moe.py b/fairseq/tasks/translation_moe.py
index 35d44e47cb..cd8b985bb1 100644
--- a/fairseq/tasks/translation_moe.py
+++ b/fairseq/tasks/translation_moe.py
@@ -12,14 +12,6 @@ from fairseq.tasks.translation import TranslationTask
 
 
-@contextlib.contextmanager
-def eval(model):
-    is_training = model.training
-    model.eval()
-    yield
-    model.train(is_training)
-
-
 @register_task('translation_moe')
 class TranslationMoETask(TranslationTask):
     """
@@ -163,7 +155,7 @@ def get_lprob_yz(winners=None):
             return lprob_yz
 
         # compute responsibilities without dropout
-        with eval(model):  # disable dropout
+        with utils.eval(model):  # disable dropout
             with torch.no_grad():  # disable autograd
                 lprob_yz = get_lprob_yz()  # B x K
 
         prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 76473837aa..4d9da12d62 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from collections import defaultdict
+import contextlib
 import copy
 import importlib.util
 import math
@@ -277,6 +278,10 @@ def import_user_module(args):
     module_path = getattr(args, 'user_dir', None)
     if module_path is not None:
         module_path = os.path.abspath(args.user_dir)
+        if not os.path.exists(module_path):
+            fairseq_rel_path = os.path.join(os.path.dirname(__file__), '..', args.user_dir)
+            if os.path.exists(fairseq_rel_path):
+                module_path = fairseq_rel_path
         module_parent, module_name = os.path.split(module_path)
 
         if module_name not in sys.modules:
@@ -339,3 +344,11 @@ def get_available_activation_fns() -> List:
         'tanh',
         'linear',
     ]
+
+
+@contextlib.contextmanager
+def eval(model):
+    is_training = model.training
+    model.eval()
+    yield
+    model.train(is_training)
diff --git a/hubconf.py b/hubconf.py
index d8f252ad7b..34179c9dba 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -21,7 +21,7 @@
 for model_name in _cls.hub_models().keys():
     globals()[model_name] = functools.partial(
         _cls.from_pretrained,
-        model_name_or_path=model_name,
+        model_name,
     )
 # to simplify the interface we only expose named models
 # globals()[_model_type] = _cls.from_pretrained
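
For reference, the `utils.eval` helper that this patch moves into `fairseq/utils.py` is a context manager that temporarily puts a module in eval mode (disabling dropout) and restores the previous training state on exit. A minimal usage sketch (the toy module below is hypothetical):

```python
import torch.nn as nn
from fairseq import utils

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
model.train()
with utils.eval(model):
    assert not model.training  # dropout is disabled inside the block
assert model.training          # the original training mode is restored
```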