forked from facebookresearch/fairseq
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding glue data preprocessing scripts (facebookresearch#771)
Summary: 1) Added glue data pre-processing script. 2) updated README with usage. TODO: 1) releasing fairseq dictionary and remove hardcoded path. 2) remove hard-coded path for bpe-encoding, myleott what do you recommend for above TODOs? Pull Request resolved: fairinternal/fairseq-py#771 Reviewed By: myleott Differential Revision: D16547679 Pulled By: myleott fbshipit-source-id: 6a6562d9b6215523d048fdf3daee63ffac21e231
- Loading branch information
1 parent
33597e5
commit 138dc8e
Showing
14 changed files
with
892 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
#!/usr/bin/env python | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import argparse | ||
import contextlib | ||
import sys | ||
|
||
from collections import Counter | ||
from multiprocessing import Pool | ||
|
||
from fairseq.data.encoders.gpt2_bpe import get_encoder | ||
|
||
|
||
def main():
    """
    Helper script to encode raw text with the GPT-2 BPE using multiple
    processes.

    Reads one or more parallel input files (or stdin), BPE-encodes each
    line, and writes the space-separated token ids to the corresponding
    output file (or stdout). Lines that become empty after stripping are
    filtered out unless --keep-empty is given; filter counts are reported
    on stderr.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-json",
        help='path to encoder.json',
    )
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help='path to vocab.bpe',
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=['-'],
        help="input files to filter/encode",
    )
    parser.add_argument(
        "--outputs",
        nargs="+",
        default=['-'],
        help="path to save encoded outputs",
    )
    parser.add_argument(
        "--keep-empty",
        action="store_true",
        help="keep empty lines",
    )
    parser.add_argument("--workers", type=int, default=20)
    args = parser.parse_args()

    # Validate through the parser so misuse exits with a usage message
    # instead of an AssertionError (asserts are stripped under `python -O`).
    if len(args.inputs) != len(args.outputs):
        parser.error("number of input and output paths should match")

    with contextlib.ExitStack() as stack:
        # `inp`/`out` rather than `input`/`output` to avoid shadowing the
        # builtins of the same names.
        inputs = [
            stack.enter_context(open(inp, "r", encoding="utf-8"))
            if inp != "-" else sys.stdin
            for inp in args.inputs
        ]
        outputs = [
            stack.enter_context(open(out, "w", encoding="utf-8"))
            if out != "-" else sys.stdout
            for out in args.outputs
        ]

        encoder = MultiprocessingEncoder(args)
        pool = Pool(args.workers, initializer=encoder.initializer)
        # zip(*inputs) yields one tuple of parallel lines per call;
        # chunksize=100 amortizes IPC overhead across worker processes.
        encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)

        stats = Counter()
        for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
            if filt == "PASS":
                for enc_line, output_h in zip(enc_lines, outputs):
                    print(enc_line, file=output_h)
            else:
                # Tally filtered lines (e.g. "EMPTY") for the final report.
                stats["num_filtered_" + filt] += 1
            if i % 10000 == 0:
                print("processed {} lines".format(i), file=sys.stderr)

        for k, v in stats.most_common():
            print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
|
||
|
||
class MultiprocessingEncoder(object):
    """Picklable BPE wrapper for use with ``multiprocessing.Pool``.

    The BPE object itself is built lazily inside each worker process via
    ``initializer`` and kept in a module-level global, so only the small
    ``args`` namespace has to be pickled and shipped to the workers.
    """

    def __init__(self, args):
        self.args = args

    def initializer(self):
        # Runs once per worker process: construct the encoder and stash it
        # in a global so the per-line methods can reach it.
        global bpe
        bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)

    def encode(self, line):
        """BPE-encode one line, returning its token ids as strings."""
        global bpe
        return [str(token_id) for token_id in bpe.encode(line)]

    def decode(self, tokens):
        """Inverse of :meth:`encode`: map token ids back to raw text."""
        global bpe
        return bpe.decode(tokens)

    def encode_lines(self, lines):
        """
        Encode a set of lines. All lines will be encoded together.
        """
        encoded = []
        for raw_line in lines:
            stripped = raw_line.strip()
            # Reject the whole tuple of parallel lines as soon as any one
            # of them is empty (unless --keep-empty was requested).
            if not stripped and not self.args.keep_empty:
                return ["EMPTY", None]
            encoded.append(" ".join(self.encode(stripped)))
        return ["PASS", encoded]

    def decode_lines(self, lines):
        """Decode whitespace-separated token-id lines back into text."""
        decoded = []
        for line in lines:
            token_ids = map(int, line.strip().split())
            decoded.append(self.decode(token_ids))
        return ["PASS", decoded]
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
#!/bin/bash
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.


# raw glue data as downloaded by glue download script (https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
if [[ $# -ne 2 ]]; then
  echo "Run as following:"
  echo "./examples/roberta/preprocess_GLUE_tasks.sh <glue_data_folder> <task_name>"
  exit 1
fi

GLUE_DATA_FOLDER=$1

# download bpe encoder.json, vocabulary and fairseq dictionary
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

TASKS=$2 # QQP

if [ "$TASKS" = "ALL" ]
then
  TASKS="QQP MNLI QNLI MRPC RTE STS-B SST-2 CoLA"
fi

for TASK in $TASKS
do
  echo "Preprocessing $TASK"

  TASK_DATA_FOLDER="$GLUE_DATA_FOLDER/$TASK"
  echo "Raw data as downloaded from glue website: $TASK_DATA_FOLDER"

  # Per-task TSV layout: which (1-indexed) columns hold the input
  # sentence(s) and the label, for the train/dev vs. test files.
  SPLITS="train dev test"
  INPUT_COUNT=2
  if [ "$TASK" = "QQP" ]
  then
    INPUT_COLUMNS=( 4 5 )
    TEST_INPUT_COLUMNS=( 2 3 )
    LABEL_COLUMN=6
  elif [ "$TASK" = "MNLI" ]
  then
    SPLITS="train dev_matched dev_mismatched test_matched test_mismatched"
    INPUT_COLUMNS=( 9 10 )
    TEST_INPUT_COLUMNS=( 9 10 )
    DEV_LABEL_COLUMN=16
    LABEL_COLUMN=12
  elif [ "$TASK" = "QNLI" ]
  then
    INPUT_COLUMNS=( 2 3 )
    TEST_INPUT_COLUMNS=( 2 3 )
    LABEL_COLUMN=4
  elif [ "$TASK" = "MRPC" ]
  then
    INPUT_COLUMNS=( 4 5 )
    TEST_INPUT_COLUMNS=( 4 5 )
    LABEL_COLUMN=1
  elif [ "$TASK" = "RTE" ]
  then
    INPUT_COLUMNS=( 2 3 )
    TEST_INPUT_COLUMNS=( 2 3 )
    LABEL_COLUMN=4
  elif [ "$TASK" = "STS-B" ]
  then
    INPUT_COLUMNS=( 8 9 )
    TEST_INPUT_COLUMNS=( 8 9 )
    LABEL_COLUMN=10
  # Following are single sentence tasks.
  elif [ "$TASK" = "SST-2" ]
  then
    INPUT_COLUMNS=( 1 )
    TEST_INPUT_COLUMNS=( 2 )
    LABEL_COLUMN=2
    INPUT_COUNT=1
  elif [ "$TASK" = "CoLA" ]
  then
    INPUT_COLUMNS=( 4 )
    TEST_INPUT_COLUMNS=( 2 )
    LABEL_COLUMN=2
    INPUT_COUNT=1
  else
    # Fail fast on an unrecognized task instead of proceeding with unset
    # (or stale, from a previous loop iteration) column variables.
    echo "Unknown task: $TASK" >&2
    exit 1
  fi

  # Strip out header and filter lines that don't have expected number of fields.
  rm -rf "$TASK_DATA_FOLDER/processed"
  mkdir "$TASK_DATA_FOLDER/processed"
  for SPLIT in $SPLITS
  do
    # CoLA train and dev doesn't have header.
    if [[ ( "$TASK" = "CoLA") && ( "$SPLIT" != "test" ) ]]
    then
      cp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
    else
      tail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
    fi

    # Remove unformatted lines from train and dev files for QQP dataset.
    if [[ ( "$TASK" = "QQP") && ( "$SPLIT" != "test" ) ]]
    then
      awk -F '\t' -v NUM_FIELDS=6 'NF==NUM_FIELDS{print}{}' "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
    else
      cp "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
    fi
    rm "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
  done

  # Split into input0, input1 and label
  for SPLIT in $SPLITS
  do
    for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
    do
      if [[ "$SPLIT" != test* ]]
      then
        COLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]}
      else
        COLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]}
      fi
      cut -f"$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.raw.input$INPUT_TYPE";
    done

    # Test splits ship without labels, so only extract labels for train/dev.
    if [[ "$SPLIT" != test* ]]
    then
      if [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ]
      then
        cut -f"$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
      else
        cut -f"$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
      fi
    fi

    # BPE encode.
    for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
    do
      LANG="input$INPUT_TYPE"
      echo "BPE encoding $SPLIT/$LANG"
      python -m examples.roberta.multiprocessing_bpe_encoder \
      --encoder-json encoder.json \
      --vocab-bpe vocab.bpe \
      --inputs "$TASK_DATA_FOLDER/processed/$SPLIT.raw.$LANG" \
      --outputs "$TASK_DATA_FOLDER/processed/$SPLIT.$LANG" \
      --workers 60 \
      --keep-empty;
    done
  done

  # Remove output directory.
  rm -rf "$TASK-bin"

  # MNLI has matched/mismatched dev and test sets; LANG is substituted
  # per input column below.
  DEVPREF="$TASK_DATA_FOLDER/processed/dev.LANG"
  TESTPREF="$TASK_DATA_FOLDER/processed/test.LANG"
  if [ "$TASK" = "MNLI" ]
  then
    DEVPREF="$TASK_DATA_FOLDER/processed/dev_matched.LANG,$TASK_DATA_FOLDER/processed/dev_mismatched.LANG"
    TESTPREF="$TASK_DATA_FOLDER/processed/test_matched.LANG,$TASK_DATA_FOLDER/processed/test_mismatched.LANG"
  fi

  # Run fairseq preprocessing:
  for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
  do
    LANG="input$INPUT_TYPE"
    python preprocess.py \
      --only-source \
      --trainpref "$TASK_DATA_FOLDER/processed/train.$LANG" \
      --validpref "${DEVPREF//LANG/$LANG}" \
      --testpref "${TESTPREF//LANG/$LANG}" \
      --destdir "$TASK-bin/$LANG" \
      --workers 60 \
      --srcdict dict.txt;
  done
  if [[ "$TASK" != "STS-B" ]]
  then
    python preprocess.py \
      --only-source \
      --trainpref "$TASK_DATA_FOLDER/processed/train.label" \
      --validpref "${DEVPREF//LANG/'label'}" \
      --destdir "$TASK-bin/label" \
      --workers 60;
  else
    # For STS-B output range is converted to be between: [0.0, 1.0]
    # -p so this does not fail regardless of whether $TASK-bin exists yet.
    mkdir -p "$TASK-bin/label"
    awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train.label" > "$TASK-bin/label/train.label"
    awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev.label" > "$TASK-bin/label/valid.label"
  fi
done
Oops, something went wrong.