-
Notifications
You must be signed in to change notification settings — Fork 20
/
Copy pathconvert_dataset.py
91 lines (74 loc) · 2.99 KB
/
convert_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Convert Mctx and Mgen Data to xTuring Style."""
import argparse
from utils import load_dataset, write_dataset
# Instruction text for context-selection ("mctx") data. All six tasks share
# one sentence pattern; only the name of the query field differs
# ("question" for QA tasks, "claim" for fever, "query" for wow).
_MCTX_TEMPLATE = "Given the ['{}', 'context'], predict the most helpful sentences in the context."
MCTX_INSTR_MAP = {
    "nq": _MCTX_TEMPLATE.format("question"),
    "tqa": _MCTX_TEMPLATE.format("question"),
    "fever": _MCTX_TEMPLATE.format("claim"),
    "wow": _MCTX_TEMPLATE.format("query"),
    "hotpotqa": _MCTX_TEMPLATE.format("question"),
    "eli5": _MCTX_TEMPLATE.format("question"),
}
# Instruction text for answer-generation ("mgen") data. Each task plugs a
# (query-field, output-noun) pair into one shared sentence template.
_MGEN_TEMPLATE = "Given the ['context', '{0}'], predict the {1} to the {0}."
_MGEN_FIELDS = {
    "nq": ("question", "answer"),
    "tqa": ("question", "answer"),
    "fever": ("claim", "judgement"),
    "wow": ("query", "response"),
    "hotpotqa": ("question", "answer"),
    "eli5": ("question", "answer"),
}
MGEN_INSTR_MAP = {
    task: _MGEN_TEMPLATE.format(field, output)
    for task, (field, output) in _MGEN_FIELDS.items()
}
# Marker that introduces the query inside an example's "input" string; the
# converter keeps everything from this marker onwards. Four QA tasks share
# "question:"; fever and wow use their own field names.
KEYWORD_MAP = {task: "question:" for task in ("nq", "tqa", "hotpotqa", "eli5")}
KEYWORD_MAP["fever"] = "claim:"
KEYWORD_MAP["wow"] = "query:"
def main():
    """Convert Mctx/Mgen examples into xTuring-style instruction records.

    Reads examples from ``args.input_data_path``, trims each example's
    "input" to the portion starting at the task keyword, and writes records
    of the form ``{"instruction", "text", "target"}`` to ``args.output_path``.

    Raises:
        ValueError: if ``args.output_path`` does not end with ``.jsonl``.
    """
    src_data = load_dataset(args.input_data_path)

    # Pick the instruction sentence and the marker that locates the relevant
    # part of the raw input. For "mgen" data the input always starts being
    # useful at the context, regardless of task.
    if args.dataset_type == "mctx":
        instruct_text = MCTX_INSTR_MAP[args.dataset_name]
        keyword = KEYWORD_MAP[args.dataset_name]
    else:
        instruct_text = MGEN_INSTR_MAP[args.dataset_name]
        keyword = "context:"

    tgt_data = []
    for i, ex in enumerate(src_data):
        start = ex["input"].find(keyword)
        if start < 0:
            # Keyword missing: report the example and keep the full input
            # instead of crashing (str.index would raise ValueError here).
            print(i, " | ", ex["input"])
            start = 0
        tgt_data.append({
            "instruction": instruct_text,
            "text": ex["input"][start:],
            "target": ex["output"],
        })

    # Validate with an explicit raise rather than `assert`, which is
    # silently stripped when Python runs with -O.
    if not args.output_path.endswith(".jsonl"):
        raise ValueError(f"output_path must end with .jsonl: {args.output_path}")
    write_dataset(args.output_path, tgt_data)

    # Guard against an empty dataset before peeking at the first record.
    if args.print_example and tgt_data:
        for k, v in tgt_data[0].items():
            print(f"=== {k.upper()} ===")
            print(v)
if __name__ == "__main__":
    # Command-line entry point. `args` is deliberately left as a module-level
    # global because main() reads it directly.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_data_path", type=str, required=True,
                        help="Path to the original data.")
    parser.add_argument("--output_path", type=str, required=True,
                        help="Directory to the output data.")
    parser.add_argument("--dataset_name", type=str, required=True,
                        choices=["nq", "tqa", "fever", "wow", "hotpotqa", "eli5"],
                        help="Name of the dataset.")
    parser.add_argument("--dataset_type", type=str, required=True,
                        choices=["mctx", "mgen"],
                        help="Type of the dataset.")
    parser.add_argument("--print_example", action="store_true")
    args = parser.parse_args()
    main()