-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathproblem_filtering.py
106 lines (86 loc) · 3.39 KB
/
problem_filtering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import json
import re
from dataclasses import dataclass
from typing import List, Dict, Tuple

from str2bool import str2bool


@dataclass
class ProcessedItem:
    """A single kept example, serialized as one JSON object per output line.

    Instances are turned into dicts with ``vars()``, so the output JSON keys
    are exactly the field names below.
    """

    # Prompt text taken from the input record's "prompt" key.
    prompt: str
    # Target text taken from the input record's "rationale_and_problem" key.
    completion: str
# Whole-word patterns. Plain substring tests would also fire on "imperfect"
# and "unacceptable", promoting negative verdicts to positive labels.
_PERFECT_PATTERN = re.compile(r"\bperfect\b")
_ACCEPTABLE_PATTERN = re.compile(r"\bacceptable\b")


def parse_judgement(judgement_text: str) -> str:
    """Map a free-form judge verdict onto one of three rating labels.

    Matching is case-insensitive and on whole words only. "perfect" takes
    precedence over "acceptable"; text containing neither word is "bad".

    NOTE(review): negations such as "not perfect" still match "perfect";
    confirm the judge prompt constrains the verdict wording.

    Args:
        judgement_text: Raw judgement text from an input record.

    Returns:
        ``"perfect"``, ``"acceptable"`` or ``"bad"``.
    """
    text = judgement_text.lower()
    if _PERFECT_PATTERN.search(text):
        return "perfect"
    if _ACCEPTABLE_PATTERN.search(text):
        return "acceptable"
    return "bad"
def load_and_process_file(file_path: str) -> Tuple[List[Dict], List[str]]:
    """Read one judged JSONL file and derive a rating label for every record.

    Each non-blank line is parsed as JSON and is expected to be an object with
    a ``"judgement"`` key, whose text is normalized by ``parse_judgement``.

    Args:
        file_path: Path to a UTF-8 encoded JSON Lines file.

    Returns:
        ``(items, ratings)``: the parsed records in file order, and a parallel
        list of ``"perfect"`` / ``"acceptable"`` / ``"bad"`` labels.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``"judgement"`` key.
        ValueError: if ``parse_judgement`` returns an unknown label.
    """
    items = []
    ratings = []
    with open(file_path, encoding="utf-8") as f:
        # Stream the file rather than readlines(), and skip blank lines
        # (e.g. an extra trailing newline) that json.loads would reject.
        for line in f:
            if not line.strip():
                continue
            json_obj = json.loads(line)
            judgement = parse_judgement(json_obj["judgement"])
            # Explicit check instead of `assert`, which is stripped under -O.
            if judgement not in ("perfect", "acceptable", "bad"):
                raise ValueError(
                    f"unexpected rating {judgement!r} in {file_path}"
                )
            ratings.append(judgement)
            items.append(json_obj)
    return items, ratings
def process_items(items: List[Dict], ratings_list: List[List[str]],
                  only_perfect: bool) -> List[ProcessedItem]:
    """Keep the items whose ratings from every reward model pass the filter.

    Args:
        items: Input records; each kept record must have ``"prompt"`` and
            ``"rationale_and_problem"`` keys.
        ratings_list: One label list per reward model, where
            ``ratings_list[r][i]`` is model ``r``'s rating of ``items[i]``.
        only_perfect: If true, keep an item only when every rating is
            ``"perfect"``. Otherwise keep it when at least one rating is
            ``"perfect"`` and none is ``"bad"``. (With an empty
            ``ratings_list`` the first rule keeps everything, the second
            nothing.)

    Returns:
        The kept items as ``ProcessedItem`` objects, in input order.

    Raises:
        ValueError: if any ratings list differs in length from ``items``.
    """
    # Ratings are matched to items purely by position. Reject misaligned
    # inputs up front: otherwise extra ratings are silently ignored and a
    # short list fails with a bare IndexError (and only on some paths).
    for reward_idx, ratings in enumerate(ratings_list):
        if len(ratings) != len(items):
            raise ValueError(
                f"ratings list {reward_idx} has {len(ratings)} entries, "
                f"expected {len(items)}"
            )

    processed_items = []
    for idx, item in enumerate(items):
        # This item's verdicts, one per reward model.
        verdicts = [ratings[idx] for ratings in ratings_list]
        if only_perfect:
            keep = all(verdict == "perfect" for verdict in verdicts)
        else:
            keep = "perfect" in verdicts and "bad" not in verdicts
        if keep:
            processed_items.append(ProcessedItem(
                prompt=item["prompt"],
                completion=item["rationale_and_problem"],
            ))
    return processed_items
def main():
    """Command-line entry point: filter judged items into a JSONL file.

    Loads ``--n_rewards`` files whose paths are ``--template.format(i)``, takes
    the items from the first file and one rating list from each file, keeps the
    items that pass the ``--only_perfect`` rule (see ``process_items``), and
    writes one ``{"prompt": ..., "completion": ...}`` object per line to
    ``--output_path``.
    """
    parser = argparse.ArgumentParser(description='Process and filter completion data')
    parser.add_argument('--template', type=str,
                        required=True,
                        help='Template for input files')
    parser.add_argument('--output_path', type=str,
                        required=True,
                        help='Path for output file')
    parser.add_argument('--only_perfect', type=str2bool,
                        default=True,
                        help='Only include items rated as perfect by all rewards')
    parser.add_argument('--n_rewards', type=int, default=2,
                        help='Number of reward models')
    args = parser.parse_args()
    # With no reward files there is nothing to filter; fail loudly instead
    # of silently writing an empty output file.
    if args.n_rewards < 1:
        parser.error("--n_rewards must be at least 1")

    items_list = []
    ratings_list = []
    # Load and process files for each reward model. Items come from the first
    # file only; later files contribute ratings aligned by record order.
    for reward_idx in range(args.n_rewards):
        file_path = args.template.format(reward_idx)
        items, ratings = load_and_process_file(file_path)
        if reward_idx == 0:
            items_list = items
        ratings_list.append(ratings)

    # Verify consistency with an explicit error (an `assert` is stripped
    # under -O and would not say which file is misaligned).
    for reward_idx, ratings in enumerate(ratings_list):
        if len(ratings) != len(items_list):
            raise ValueError(
                f"{args.template.format(reward_idx)} has {len(ratings)} "
                f"records, expected {len(items_list)}"
            )

    processed_items = process_items(
        items_list, ratings_list,
        args.only_perfect
    )

    # Save results as JSON Lines, one kept item per line.
    with open(args.output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(vars(item)) + "\n" for item in processed_items)


if __name__ == "__main__":
    main()