From 138dc8e4fd02f81074842731d1bdd9401aa59489 Mon Sep 17 00:00:00 2001 From: Naman Goyal Date: Mon, 29 Jul 2019 16:03:11 -0700 Subject: [PATCH] adding glue data preprocessing scripts (#771) Summary: 1) Added glue data pre-processing script. 2) updated README with usage. TODO: 1) releasing fairseq dictionary and remove hardcoded path. 2) remove hard-coded path for bpe-encoding, myleott what do you recommend for above TODOs? Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/771 Reviewed By: myleott Differential Revision: D16547679 Pulled By: myleott fbshipit-source-id: 6a6562d9b6215523d048fdf3daee63ffac21e231 --- examples/roberta/README.md | 70 +++++- .../roberta/multiprocessing_bpe_encoder.py | 126 +++++++++++ examples/roberta/preprocess_GLUE_tasks.sh | 187 +++++++++++++++ fairseq/checkpoint_utils.py | 10 +- fairseq/criterions/sentence_prediction.py | 101 +++++++++ fairseq/data/__init__.py | 10 + fairseq/data/concat_sentences_dataset.py | 52 +++++ fairseq/data/offset_tokens_dataset.py | 18 ++ fairseq/data/raw_label_dataset.py | 26 +++ fairseq/data/strip_token_dataset.py | 19 ++ fairseq/data/truncate_dataset.py | 32 +++ fairseq/models/roberta/model.py | 9 + fairseq/tasks/sentence_prediction.py | 212 ++++++++++++++++++ train.py | 31 ++- 14 files changed, 892 insertions(+), 11 deletions(-) create mode 100644 examples/roberta/multiprocessing_bpe_encoder.py create mode 100755 examples/roberta/preprocess_GLUE_tasks.sh create mode 100644 fairseq/criterions/sentence_prediction.py create mode 100644 fairseq/data/concat_sentences_dataset.py create mode 100644 fairseq/data/offset_tokens_dataset.py create mode 100644 fairseq/data/raw_label_dataset.py create mode 100644 fairseq/data/strip_token_dataset.py create mode 100644 fairseq/data/truncate_dataset.py create mode 100644 fairseq/tasks/sentence_prediction.py diff --git a/examples/roberta/README.md b/examples/roberta/README.md index b7661d3784..c01595bfb8 100644 --- a/examples/roberta/README.md +++ b/examples/roberta/README.md @@ -134,9 +134,77 @@ print('| Accuracy: ', float(ncorrect)/float(nsamples)) # Expected output: 0.9060 ``` + ## Finetuning on GLUE tasks -A more detailed tutorial is coming soon. +##### 1) Download the data from GLUE website (https://gluebenchmark.com/tasks) using following commands: +``` +$ wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py +$ python download_glue_data.py --data_dir glue_data --tasks all +``` + +##### 2) Preprocess GLUE task data: +``` +$ ./examples/roberta/preprocess_GLUE_tasks.sh glue_data +``` +`glue_task_name` is one of the following: +`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}` +Use `ALL` for preprocessing all the glue tasks. + +##### 3) Fine-tuning on GLUE task : +Example fine-tuning cmd for `RTE` task +``` +TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16 +WARMUP_UPDATES=122 # 6 percent of the number of updates +LR=2e-05 # Peak LR for polynomial LR scheduler. +NUM_CLASSES=2 +MAX_SENTENCES=16 # Batch size. 
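+# Note: pass the path of the pre-trained roberta.large checkpoint (model.pt)
+# to --restore-file below.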
+ +CUDA_VISIBLE_DEVICES=0 python train.py RTE-bin/ \ +--restore-file \ +--max-positions 512 \ +--max-sentences $MAX_SENTENCES \ +--max-tokens 4400 \ +--task sentence_prediction \ +--reset-optimizer --reset-dataloader --reset-meters \ +--required-batch-size-multiple 1 \ +--init-token 0 --separator-token 2 \ +--arch roberta_large \ +--criterion sentence_prediction \ +--num-classes $NUM_CLASSES \ +--dropout 0.1 --attention-dropout 0.1 \ +--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ +--clip-norm 0.0 \ +--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ +--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ +--max-epoch 10 \ +--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric; +``` + +For each of the GLUE task, you will need to use following cmd-line arguments: + +Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B +---|---|---|---|---|---|---|---|--- +`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1 +`--lr` | 1e-5 | 1e-5 | 1e-5 | 2e-5 | 1e-5 | 1e-5 | 1e-5 | 2e-5 +`--max-sentences` | 32 | 32 | 32 | 16 | 32 | 16 | 16 | 16 +`--total-num-update` | 123873 | 33112 | 113272 | 2036 | 20935 | 2296 | 5336 | 3598 +`--warmup-updates` | 7432 | 1986 | 28318 | 122 | 1256 | 137 | 320 | 214 + +For `STS-B` additionally use following cmd-line argument: +``` +--regression-target +--best-checkpoint-metric loss +``` +and remove `--maximize-best-checkpoint-metric`. + +**Note:** + +a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--max-sentences=16/32` depending on the task. + +b) Above cmd-args and hyperparams are tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`. + +c) All the settings in above table are suggested settings based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with wider hyperparam search. ## Pretraining using your own data diff --git a/examples/roberta/multiprocessing_bpe_encoder.py b/examples/roberta/multiprocessing_bpe_encoder.py new file mode 100644 index 0000000000..48d9cb367e --- /dev/null +++ b/examples/roberta/multiprocessing_bpe_encoder.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import contextlib +import sys + +from collections import Counter +from multiprocessing import Pool + +from fairseq.data.encoders.gpt2_bpe import get_encoder + + +def main(): + """ + Helper script to encode raw text + with the GPT-2 BPE using multiple processes. 
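+
+    Example (input/output paths are placeholders; encoder.json and vocab.bpe
+    are the GPT-2 BPE files downloaded by preprocess_GLUE_tasks.sh):
+
+        python -m examples.roberta.multiprocessing_bpe_encoder \
+            --encoder-json encoder.json \
+            --vocab-bpe vocab.bpe \
+            --inputs train.raw \
+            --outputs train.bpe \
+            --workers 20 \
+            --keep-empty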
+ """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--encoder-json", + help='path to encoder.json', + ) + parser.add_argument( + "--vocab-bpe", + type=str, + help='path to vocab.bpe', + ) + parser.add_argument( + "--inputs", + nargs="+", + default=['-'], + help="input files to filter/encode", + ) + parser.add_argument( + "--outputs", + nargs="+", + default=['-'], + help="path to save encoded outputs", + ) + parser.add_argument( + "--keep-empty", + action="store_true", + help="keep empty lines", + ) + parser.add_argument("--workers", type=int, default=20) + args = parser.parse_args() + + assert len(args.inputs) == len(args.outputs), \ + "number of input and output paths should match" + + with contextlib.ExitStack() as stack: + inputs = [ + stack.enter_context(open(input, "r", encoding="utf-8")) + if input != "-" else sys.stdin + for input in args.inputs + ] + outputs = [ + stack.enter_context(open(output, "w", encoding="utf-8")) + if output != "-" else sys.stdout + for output in args.outputs + ] + + encoder = MultiprocessingEncoder(args) + pool = Pool(args.workers, initializer=encoder.initializer) + encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100) + + stats = Counter() + for i, (filt, enc_lines) in enumerate(encoded_lines, start=1): + if filt == "PASS": + for enc_line, output_h in zip(enc_lines, outputs): + print(enc_line, file=output_h) + else: + stats["num_filtered_" + filt] += 1 + if i % 10000 == 0: + print("processed {} lines".format(i), file=sys.stderr) + + for k, v in stats.most_common(): + print("[{}] filtered {} lines".format(k, v), file=sys.stderr) + + +class MultiprocessingEncoder(object): + + def __init__(self, args): + self.args = args + + def initializer(self): + global bpe + bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe) + + def encode(self, line): + global bpe + ids = bpe.encode(line) + return list(map(str, ids)) + + def decode(self, tokens): + global bpe + return bpe.decode(tokens) + + def encode_lines(self, lines): + """ + Encode a set of lines. All lines will be encoded together. + """ + enc_lines = [] + for line in lines: + line = line.strip() + if len(line) == 0 and not self.args.keep_empty: + return ["EMPTY", None] + tokens = self.encode(line) + enc_lines.append(" ".join(tokens)) + return ["PASS", enc_lines] + + def decode_lines(self, lines): + dec_lines = [] + for line in lines: + tokens = map(int, line.strip().split()) + dec_lines.append(self.decode(tokens)) + return ["PASS", dec_lines] + + +if __name__ == "__main__": + main() diff --git a/examples/roberta/preprocess_GLUE_tasks.sh b/examples/roberta/preprocess_GLUE_tasks.sh new file mode 100755 index 0000000000..33fcd8f4f5 --- /dev/null +++ b/examples/roberta/preprocess_GLUE_tasks.sh @@ -0,0 +1,187 @@ +#!/bin/bash +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
+ + +# raw glue data as downloaded by glue download script (https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) +if [[ $# -ne 2 ]]; then + echo "Run as following:" + echo "./examples/roberta/preprocess_GLUE_tasks.sh " + exit 1 +fi + +GLUE_DATA_FOLDER=$1 + +# download bpe encoder.json, vocabulary and fairseq dictionary +wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json' +wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe' +wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt' + +TASKS=$2 # QQP + +if [ "$TASKS" = "ALL" ] +then + TASKS="QQP MNLI QNLI MRPC RTE STS-B SST-2 CoLA" +fi + +for TASK in $TASKS +do + echo "Preprocessing $TASK" + + TASK_DATA_FOLDER="$GLUE_DATA_FOLDER/$TASK" + echo "Raw data as downloaded from glue website: $TASK_DATA_FOLDER" + + SPLITS="train dev test" + INPUT_COUNT=2 + if [ "$TASK" = "QQP" ] + then + INPUT_COLUMNS=( 4 5 ) + TEST_INPUT_COLUMNS=( 2 3 ) + LABEL_COLUMN=6 + elif [ "$TASK" = "MNLI" ] + then + SPLITS="train dev_matched dev_mismatched test_matched test_mismatched" + INPUT_COLUMNS=( 9 10 ) + TEST_INPUT_COLUMNS=( 9 10 ) + DEV_LABEL_COLUMN=16 + LABEL_COLUMN=12 + elif [ "$TASK" = "QNLI" ] + then + INPUT_COLUMNS=( 2 3 ) + TEST_INPUT_COLUMNS=( 2 3 ) + LABEL_COLUMN=4 + elif [ "$TASK" = "MRPC" ] + then + INPUT_COLUMNS=( 4 5 ) + TEST_INPUT_COLUMNS=( 4 5 ) + LABEL_COLUMN=1 + elif [ "$TASK" = "RTE" ] + then + INPUT_COLUMNS=( 2 3 ) + TEST_INPUT_COLUMNS=( 2 3 ) + LABEL_COLUMN=4 + elif [ "$TASK" = "STS-B" ] + then + INPUT_COLUMNS=( 8 9 ) + TEST_INPUT_COLUMNS=( 8 9 ) + LABEL_COLUMN=10 + # Following are single sentence tasks. + elif [ "$TASK" = "SST-2" ] + then + INPUT_COLUMNS=( 1 ) + TEST_INPUT_COLUMNS=( 2 ) + LABEL_COLUMN=2 + INPUT_COUNT=1 + elif [ "$TASK" = "CoLA" ] + then + INPUT_COLUMNS=( 4 ) + TEST_INPUT_COLUMNS=( 2 ) + LABEL_COLUMN=2 + INPUT_COUNT=1 + fi + + # Strip out header and filter lines that don't have expected number of fields. + rm -rf "$TASK_DATA_FOLDER/processed" + mkdir "$TASK_DATA_FOLDER/processed" + for SPLIT in $SPLITS + do + # CoLA train and dev doesn't have header. + if [[ ( "$TASK" = "CoLA") && ( "$SPLIT" != "test" ) ]] + then + cp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp"; + else + tail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp"; + fi + + # Remove unformatted lines from train and dev files for QQP dataset. 
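+        # (The awk below keeps only rows with exactly the expected 6
+        # tab-separated fields and drops everything else.)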
+ if [[ ( "$TASK" = "QQP") && ( "$SPLIT" != "test" ) ]] + then + awk -F '\t' -v NUM_FIELDS=6 'NF==NUM_FIELDS{print}{}' "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv"; + else + cp "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv"; + fi + rm "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp"; + done + + # Split into input0, input1 and label + for SPLIT in $SPLITS + do + for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1))) + do + if [[ "$SPLIT" != test* ]] + then + COLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]} + else + COLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]} + fi + cut -f"$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.raw.input$INPUT_TYPE"; + done + + if [[ "$SPLIT" != test* ]] + then + if [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ] + then + cut -f"$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label"; + else + cut -f"$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label"; + fi + fi + + # BPE encode. + for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1))) + do + LANG="input$INPUT_TYPE" + echo "BPE encoding $SPLIT/$LANG" + python -m examples.roberta.multiprocessing_bpe_encoder \ + --encoder-json encoder.json \ + --vocab-bpe vocab.bpe \ + --inputs "$TASK_DATA_FOLDER/processed/$SPLIT.raw.$LANG" \ + --outputs "$TASK_DATA_FOLDER/processed/$SPLIT.$LANG" \ + --workers 60 \ + --keep-empty; + done + done + + # Remove output directory. + rm -rf "$TASK-bin" + + DEVPREF="$TASK_DATA_FOLDER/processed/dev.LANG" + TESTPREF="$TASK_DATA_FOLDER/processed/test.LANG" + if [ "$TASK" = "MNLI" ] + then + DEVPREF="$TASK_DATA_FOLDER/processed/dev_matched.LANG,$TASK_DATA_FOLDER/processed/dev_mismatched.LANG" + TESTPREF="$TASK_DATA_FOLDER/processed/test_matched.LANG,$TASK_DATA_FOLDER/processed/test_mismatched.LANG" + fi + + # Run fairseq preprocessing: + for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1))) + do + LANG="input$INPUT_TYPE" + python preprocess.py \ + --only-source \ + --trainpref "$TASK_DATA_FOLDER/processed/train.$LANG" \ + --validpref "${DEVPREF//LANG/$LANG}" \ + --testpref "${TESTPREF//LANG/$LANG}" \ + --destdir "$TASK-bin/$LANG" \ + --workers 60 \ + --srcdict dict.txt; + done + if [[ "$TASK" != "STS-B" ]] + then + python preprocess.py \ + --only-source \ + --trainpref "$TASK_DATA_FOLDER/processed/train.label" \ + --validpref "${DEVPREF//LANG/'label'}" \ + --destdir "$TASK-bin/label" \ + --workers 60; + else + # For STS-B output range is converted to be between: [0.0, 1.0] + mkdir "$TASK-bin/label" + awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train.label" > "$TASK-bin/label/train.label" + awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev.label" > "$TASK-bin/label/valid.label" + fi +done diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 4696875498..3e2fcbda7c 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -24,11 +24,16 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): from fairseq import distributed_utils, meters + prev_best = getattr(save_checkpoint, 'best', val_loss) + if val_loss is not None: + best_function = max if args.maximize_best_checkpoint_metric else min + save_checkpoint.best = best_function(val_loss, prev_best) + if args.no_save or not distributed_utils.is_master(args): return def is_better(a, b): - return a > b if args.maximize_best_checkpoint_metric else a < b + return a >= b if 
args.maximize_best_checkpoint_metric else a <= b write_timer = meters.StopwatchMeter() write_timer.start() @@ -52,9 +57,6 @@ def is_better(a, b): ) checkpoint_conds['checkpoint_last.pt'] = not args.no_last_checkpoints - prev_best = getattr(save_checkpoint, 'best', val_loss) - if val_loss is not None: - save_checkpoint.best = val_loss if is_better(val_loss, prev_best) else prev_best extra_state = { 'train_iterator': epoch_itr.state_dict(), 'val_loss': val_loss, diff --git a/fairseq/criterions/sentence_prediction.py b/fairseq/criterions/sentence_prediction.py new file mode 100644 index 0000000000..9b4a2d1815 --- /dev/null +++ b/fairseq/criterions/sentence_prediction.py @@ -0,0 +1,101 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math + +import torch +import torch.nn.functional as F + +from fairseq import utils + +from . import FairseqCriterion, register_criterion + + +@register_criterion('sentence_prediction') +class SentencePredictionCriterion(FairseqCriterion): + + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--save-predictions', metavar='FILE', + help='file to save predictions to') + # fmt: on + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + features, extra = model(**sample['net_input'], features_only=True) + padding_mask = sample['net_input']['src_tokens'].eq(self.padding_idx) + + assert hasattr(model, 'classification_heads') and \ + 'sentence_classification_head' in model.classification_heads, \ + "model must provide sentence classification head for --criterion=sentence_prediction" + + logits = model.classification_heads['sentence_classification_head']( + features, + padding_mask=padding_mask, + ) + + targets = model.get_targets(sample, [logits]).view(-1) + sample_size = targets.numel() + + if not self.args.regression_target: + loss = F.nll_loss( + F.log_softmax(logits, dim=-1, dtype=torch.float32), + targets, + reduction='sum', + ) + else: + logits = logits.squeeze().float() + targets = targets.float() + loss = F.mse_loss( + logits, + targets, + reduction='sum', + ) + + logging_output = { + 'loss': utils.item(loss.data) if reduce else loss.data, + 'ntokens': sample['ntokens'], + 'nsentences': sample_size, + 'sample_size': sample_size, + } + + if not self.args.regression_target: + preds = logits.max(dim=1)[1] + logging_output.update( + ncorrect=(preds == targets).sum().item() + ) + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get('loss', 0) for log in logging_outputs) + ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) + nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) + sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) + + agg_output = { + 'loss': loss_sum / sample_size / math.log(2), + 'ntokens': ntokens, + 'nsentences': nsentences, + 'sample_size': sample_size, + } + + if len(logging_outputs) > 0 and 'ncorrect' in logging_outputs[0]: + ncorrect = sum(log.get('ncorrect', 0) for 
log in logging_outputs) + agg_output.update(accuracy=ncorrect/nsentences) + + if sample_size != ntokens: + agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) + return agg_output diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index 7ff95af472..14e9770ae9 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -14,6 +14,7 @@ from .audio.raw_audio_dataset import RawAudioDataset from .backtranslation_dataset import BacktranslationDataset from .concat_dataset import ConcatDataset +from .concat_sentences_dataset import ConcatSentencesDataset from .id_dataset import IdDataset from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset from .language_pair_dataset import LanguagePairDataset @@ -25,13 +26,17 @@ from .noising import NoisingDataset from .numel_dataset import NumelDataset from .num_samples_dataset import NumSamplesDataset +from .offset_tokens_dataset import OffsetTokensDataset from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset from .prepend_token_dataset import PrependTokenDataset +from .raw_label_dataset import RawLabelDataset from .round_robin_zip_datasets import RoundRobinZipDatasets from .sort_dataset import SortDataset +from .strip_token_dataset import StripTokenDataset from .token_block_dataset import TokenBlockDataset from .transform_eos_dataset import TransformEosDataset from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset +from .truncate_dataset import TruncateDataset from .iterators import ( CountingIterator, @@ -44,6 +49,7 @@ 'BacktranslationDataset', 'BaseWrapperDataset', 'ConcatDataset', + 'ConcatSentencesDataset', 'CountingIterator', 'Dictionary', 'EpochBatchIterator', @@ -64,15 +70,19 @@ 'NoisingDataset', 'NumelDataset', 'NumSamplesDataset', + "OffsetTokensDataset", 'PadDataset', 'PrependTokenDataset', 'RawAudioDataset', + "RawLabelDataset", 'RightPadDataset', 'RoundRobinZipDatasets', 'ShardedIterator', 'SortDataset', + "StripTokenDataset", 'TokenBlockDataset', 'TransformEosDataset', 'TransformEosLangPairDataset', + "TruncateDataset", 'TruncatedDictionary', ] diff --git a/fairseq/data/concat_sentences_dataset.py b/fairseq/data/concat_sentences_dataset.py new file mode 100644 index 0000000000..342018f096 --- /dev/null +++ b/fairseq/data/concat_sentences_dataset.py @@ -0,0 +1,52 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import torch + +from . 
import FairseqDataset + + +class ConcatSentencesDataset(FairseqDataset): + + def __init__(self, *datasets): + super().__init__() + self.datasets = datasets + assert all(len(ds) == len(datasets[0]) for ds in datasets), \ + 'datasets must have the same length' + + def __getitem__(self, index): + return torch.cat([ds[index] for ds in self.datasets]) + + def __len__(self): + return len(self.datasets[0]) + + def collater(self, samples): + return self.datasets[0].collater(samples) + + @property + def sizes(self): + return sum(ds.sizes for ds in self.datasets) + + def num_tokens(self, index): + return sum(ds.num_tokens(index) for ds in self.datasets) + + def size(self, index): + return sum(ds.size(index) for ds in self.datasets) + + def ordered_indices(self): + return self.datasets[0].ordered_indices() + + @property + def supports_prefetch(self): + return any( + getattr(ds, 'supports_prefetch', False) for ds in self.datasets + ) + + def prefetch(self, indices): + for ds in self.datasets: + if getattr(ds, 'supports_prefetch', False): + ds.prefetch(indices) diff --git a/fairseq/data/offset_tokens_dataset.py b/fairseq/data/offset_tokens_dataset.py new file mode 100644 index 0000000000..7a947f66ed --- /dev/null +++ b/fairseq/data/offset_tokens_dataset.py @@ -0,0 +1,18 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from . import BaseWrapperDataset + + +class OffsetTokensDataset(BaseWrapperDataset): + + def __init__(self, dataset, offset): + super().__init__(dataset) + self.offset = offset + + def __getitem__(self, idx): + return self.dataset[idx] + self.offset diff --git a/fairseq/data/raw_label_dataset.py b/fairseq/data/raw_label_dataset.py new file mode 100644 index 0000000000..5f7cc0e43c --- /dev/null +++ b/fairseq/data/raw_label_dataset.py @@ -0,0 +1,26 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import torch + +from . import FairseqDataset + + +class RawLabelDataset(FairseqDataset): + + def __init__(self, labels): + super().__init__() + self.labels = labels + + def __getitem__(self, index): + return self.labels[index] + + def __len__(self): + return len(self.labels) + + def collater(self, samples): + return torch.tensor(samples) diff --git a/fairseq/data/strip_token_dataset.py b/fairseq/data/strip_token_dataset.py new file mode 100644 index 0000000000..eeb48ae600 --- /dev/null +++ b/fairseq/data/strip_token_dataset.py @@ -0,0 +1,19 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from . 
import BaseWrapperDataset + + +class StripTokenDataset(BaseWrapperDataset): + + def __init__(self, dataset, id_to_strip): + super().__init__(dataset) + self.id_to_strip = id_to_strip + + def __getitem__(self, index): + item = self.dataset[index] + return item[item.ne(self.id_to_strip)] diff --git a/fairseq/data/truncate_dataset.py b/fairseq/data/truncate_dataset.py new file mode 100644 index 0000000000..0e350e407f --- /dev/null +++ b/fairseq/data/truncate_dataset.py @@ -0,0 +1,32 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import numpy as np + +from . import BaseWrapperDataset + + +class TruncateDataset(BaseWrapperDataset): + + def __init__(self, dataset, truncation_length): + super().__init__(dataset) + self.truncation_length = truncation_length + self.dataset = dataset + + def __getitem__(self, index): + item = self.dataset[index] + item_len = item.size(0) + if item_len > self.truncation_length: + item = item[:self.truncation_length] + return item + + @property + def sizes(self): + return np.minimum(self.dataset.sizes, self.truncation_length) + + def __len__(self): + return len(self.dataset) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index e7c6d5b7fc..c8794b2607 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -134,6 +134,15 @@ def upgrade_state_dict_named(self, state_dict, name): ].size(0) self.register_classification_head(head_name, num_classes, inner_dim) + # Copy any newly-added classification heads into the state dict + # with their current weights. + if hasattr(self, 'classification_heads'): + cur_state = self.classification_heads.state_dict() + for k, v in cur_state.items(): + if prefix + 'classification_heads.' + k not in state_dict: + print('Overwriting', prefix + 'classification_heads.' + k) + state_dict[prefix + 'classification_heads.' + k] = v + class RobertaLMHead(nn.Module): """Head for masked language modeling.""" diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py new file mode 100644 index 0000000000..0f54ef81f1 --- /dev/null +++ b/fairseq/tasks/sentence_prediction.py @@ -0,0 +1,212 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import os + +import numpy as np + +from fairseq.data import ( + ConcatSentencesDataset, + data_utils, + Dictionary, + IdDataset, + NestedDictionaryDataset, + NumSamplesDataset, + NumelDataset, + OffsetTokensDataset, + PrependTokenDataset, + RawLabelDataset, + RightPadDataset, + SortDataset, + StripTokenDataset, + TruncateDataset, +) + +from . import FairseqTask, register_task + + +@register_task('sentence_prediction') +class SentencePredictionTask(FairseqTask): + """ + Sentence (or sentence pair) prediction (classification or regression) task. 
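+
+    Expects binarized inputs in the layout produced by
+    examples/roberta/preprocess_GLUE_tasks.sh: an `input0` directory, an
+    optional `input1` directory for sentence pairs, and either a binarized
+    `label` directory (classification) or raw `<split>.label` files
+    (--regression-target).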
+ + Args: + dictionary (Dictionary): the dictionary for the input of the task + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument('data', metavar='FILE', + help='file prefix for data') + parser.add_argument('--max-positions', type=int, default=512, + help='max input length') + parser.add_argument('--num-classes', type=int, default=-1, + help='number of classes') + parser.add_argument('--init-token', type=int, default=None, + help='add token at the beginning of each batch item') + parser.add_argument('--separator-token', type=int, default=None, + help='add separator token between inputs') + parser.add_argument('--regression-target', action='store_true', default=False) + parser.add_argument('--no-shuffle', action='store_true', default=False) + parser.add_argument('--truncate-sequence', action='store_true', default=False, + help='Truncate sequence to max_sequence_length') + + def __init__(self, args, data_dictionary, label_dictionary): + super().__init__(args) + self.dictionary = data_dictionary + self.label_dictionary = label_dictionary + + @classmethod + def load_dictionary(cls, args, filename, source=True): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + dictionary = Dictionary.load(filename) + dictionary.add_symbol('') + return dictionary + + @classmethod + def setup_task(cls, args, **kwargs): + assert args.num_classes > 0, 'Must set --num-classes' + + args.tokens_per_sample = args.max_positions + + # load data dictionary + data_dict = cls.load_dictionary( + args, + os.path.join(args.data, 'input0', 'dict.txt'), + source=True, + ) + print('| [input] dictionary: {} types'.format(len(data_dict))) + + label_dict = None + if not args.regression_target: + # load label dictionary + label_dict = cls.load_dictionary( + args, + os.path.join(args.data, 'label', 'dict.txt'), + source=False, + ) + print('| [label] dictionary: {} types'.format(len(label_dict))) + else: + label_dict = data_dict + return SentencePredictionTask(args, data_dict, label_dict) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split (e.g., train, valid, test).""" + def get_path(type, split): + return os.path.join(self.args.data, type, split) + + def make_dataset(type, dictionary): + split_path = get_path(type, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.source_dictionary, + self.args.dataset_impl, + combine=combine, + ) + return dataset + + input0 = make_dataset('input0', self.source_dictionary) + assert input0 is not None, 'could not find dataset: {}'.format(get_path(type, split)) + input1 = make_dataset('input1', self.source_dictionary) + + if self.args.init_token is not None: + input0 = PrependTokenDataset(input0, self.args.init_token) + + if input1 is None: + src_tokens = input0 + else: + if self.args.separator_token is not None: + input1 = PrependTokenDataset(input1, self.args.separator_token) + + src_tokens = ConcatSentencesDataset(input0, input1) + + with data_utils.numpy_seed(self.args.seed): + shuffle = np.random.permutation(len(src_tokens)) + + if self.args.truncate_sequence: + src_tokens = TruncateDataset(src_tokens, self.args.max_positions) + + dataset = { + 'id': IdDataset(), + 'net_input': { + 'src_tokens': RightPadDataset( + src_tokens, + pad_idx=self.source_dictionary.pad(), + ), + 'src_lengths': NumelDataset(src_tokens, reduce=False), + }, + 'nsentences': NumSamplesDataset(), + 'ntokens': NumelDataset(src_tokens, reduce=True), + } + 
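+
+        # Attach the prediction target: classification labels come from the
+        # binarized 'label' dataset (strip the trailing EOS and shift by
+        # nspecial so the first class maps to 0); with --regression-target
+        # (e.g. STS-B), raw float labels are read from <split>.label instead.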
+ if not self.args.regression_target: + label_dataset = make_dataset('label', self.target_dictionary) + if label_dataset is not None: + dataset.update( + target=OffsetTokensDataset( + StripTokenDataset( + label_dataset, + id_to_strip=self.target_dictionary.eos(), + ), + offset=-self.target_dictionary.nspecial, + ) + ) + else: + label_path = f"{get_path('label', split)}.label" + if os.path.exists(label_path): + dataset.update( + target=RawLabelDataset([ + float(x.strip()) for x in open(label_path).readlines() + ]) + ) + + nested_dataset = NestedDictionaryDataset( + dataset, + sizes=[src_tokens.sizes], + ) + + if self.args.no_shuffle: + dataset = nested_dataset + else: + dataset = SortDataset( + nested_dataset, + # shuffle + sort_order=[shuffle], + ) + + print(f"| Loaded {split} with #samples: {len(dataset)}") + + self.datasets[split] = dataset + return self.datasets[split] + + def build_model(self, args): + from fairseq import models + model = models.build_model(args, self) + + model.register_classification_head( + 'sentence_classification_head', + num_classes=self.args.num_classes, + ) + + return model + + def max_positions(self): + return self.args.max_positions + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.label_dictionary diff --git a/train.py b/train.py index 9531829625..c4a95023c3 100644 --- a/train.py +++ b/train.py @@ -130,7 +130,7 @@ def train(args, trainer, task, epoch_itr): for k, v in log_output.items(): if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']: continue # these are already logged above - if 'loss' in k: + if 'loss' in k or k == 'accuracy': extra_meters[k].update(v, log_output['sample_size']) else: extra_meters[k].update(v) @@ -236,16 +236,20 @@ def validate(args, trainer, task, epoch_itr, subsets): extra_meters[k].update(v) # log validation stats - stats = get_valid_stats(trainer) + stats = get_valid_stats(trainer, args, extra_meters) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag=subset, step=trainer.get_num_updates()) - valid_losses.append(stats[args.best_checkpoint_metric].avg) + valid_losses.append( + stats[args.best_checkpoint_metric].avg + if args.best_checkpoint_metric == 'loss' + else stats[args.best_checkpoint_metric] + ) return valid_losses -def get_valid_stats(trainer): +def get_valid_stats(trainer, args, extra_meters=None): stats = collections.OrderedDict() stats['loss'] = trainer.get_meter('valid_loss') if trainer.get_meter('valid_nll_loss').count > 0: @@ -256,8 +260,23 @@ def get_valid_stats(trainer): stats['ppl'] = utils.get_perplexity(nll_loss.avg) stats['num_updates'] = trainer.get_num_updates() if hasattr(checkpoint_utils.save_checkpoint, 'best'): - stats['best_loss'] = min( - checkpoint_utils.save_checkpoint.best, stats['loss'].avg) + key = f'best_{args.best_checkpoint_metric}' + best_function = max if args.maximize_best_checkpoint_metric else min + + current_metric = None + if args.best_checkpoint_metric == 'loss': + current_metric = stats['loss'].avg + elif args.best_checkpoint_metric in extra_meters: + current_metric = extra_meters[args.best_checkpoint_metric].avg + elif args.best_checkpoint_metric in stats: + current_metric = stats[args.best_checkpoint_metric] + else: + raise ValueError("best_checkpoint_metric not found in logs") + + stats[key] = best_function( + checkpoint_utils.save_checkpoint.best, + current_metric, + ) return stats