1
1
import git
2
2
import os
3
+ import re
3
4
import sys
4
5
import traceback
5
6
from datasets import load_dataset
11
12
Files ,
12
13
RUN_PYTEST_LOG_DIR ,
13
14
RepoInstance ,
15
+ SimpleInstance ,
14
16
)
15
17
from commit0 .harness .spec import make_spec
16
18
from commit0 .harness .utils import (
@@ -46,7 +48,7 @@ def main(
46
48
Tests are run either locally through docker
47
49
or remotely through Modal.
48
50
"""
49
- dataset : Iterator [RepoInstance ] = load_dataset (dataset_name , split = dataset_split ) # type: ignore
51
+ dataset : Iterator [Union [ RepoInstance , SimpleInstance ] ] = load_dataset (dataset_name , split = dataset_split ) # type: ignore
50
52
spec = None
51
53
example = None
52
54
repo_name = None
@@ -56,10 +58,13 @@ def main(
56
58
if "swe" in dataset_name .lower ():
57
59
repo_name = example ["instance_id" ]
58
60
dataset_type = "swebench"
61
+ elif "humaneval" in dataset_name .lower ():
62
+ repo_name = example ["instance_id" ]
63
+ dataset_type = "simple"
59
64
else :
60
65
repo_name = example ["repo" ].split ("/" )[- 1 ]
61
66
dataset_type = "commit0"
62
- if repo_name in os .path .basename (repo_or_repo_dir ):
67
+ if repo_name in os .path .basename (repo_or_repo_dir ) or repo_or_repo_dir . endswith ( repo_name ) :
63
68
spec = make_spec (example , dataset_type )
64
69
break
65
70
assert spec is not None , "No spec available"
@@ -73,46 +78,61 @@ def main(
73
78
log_file = log_dir / "run_pytest.log"
74
79
logger = setup_logger (repo_name , log_file , verbose = verbose )
75
80
76
- try :
77
- local_repo = git .Repo (repo_or_repo_dir )
78
- logger .info (f"Loaded a git repo from { repo_or_repo_dir } " )
79
- except (git .exc .NoSuchPathError , git .exc .InvalidGitRepositoryError ): # type: ignore
80
- repo_dir = os .path .join (base_dir , repo_name )
81
- logger .error (f"{ repo_or_repo_dir } is not a git dir, trying { repo_dir } again" )
81
+ if dataset_type != "simple" : # if dataset_type is not simple, load git repo
82
82
try :
83
- local_repo = git .Repo (repo_dir )
84
- logger .info (f"Retried succeeded. Loaded a git repo from { repo_dir } " )
85
- except git .exc .NoSuchPathError : # type: ignore
86
- raise Exception (
87
- f"{ repo_dir } and { repo_or_repo_dir } are not git directories.\n Usage: commit0 test {{repo_dir}} {{branch}} {{test_ids}}"
88
- )
89
- except Exception as e :
90
- raise e
91
- commit_id = ""
92
- if branch == "reference" :
93
- commit_id = example ["reference_commit" ]
94
- else :
95
- # Check if it's a local branch
96
- if branch in local_repo .branches :
97
- commit_id = local_repo .commit (branch ).hexsha
83
+ local_repo = git .Repo (repo_or_repo_dir )
84
+ logger .info (f"Loaded a git repo from { repo_or_repo_dir } " )
85
+ except (git .exc .NoSuchPathError , git .exc .InvalidGitRepositoryError ): # type: ignore
86
+ repo_dir = os .path .join (base_dir , repo_name )
87
+ logger .error (f"{ repo_or_repo_dir } is not a git dir, trying { repo_dir } again" )
88
+ try :
89
+ local_repo = git .Repo (repo_dir )
90
+ logger .info (f"Retried succeeded. Loaded a git repo from { repo_dir } " )
91
+ except git .exc .NoSuchPathError : # type: ignore
92
+ raise Exception (
93
+ f"{ repo_dir } and { repo_or_repo_dir } are not git directories.\n Usage: commit0 test {{repo_dir}} {{branch}} {{test_ids}}"
94
+ )
95
+ except Exception as e :
96
+ raise e
97
+ commit_id = ""
98
+ if branch == "reference" :
99
+ commit_id = example ["reference_commit" ]
98
100
else :
99
- found_remote_branch = False
100
- for remote in local_repo .remotes :
101
- remote .fetch () # Fetch latest updates from each remote
101
+ # Check if it's a local branch
102
+ if branch in local_repo .branches :
103
+ commit_id = local_repo .commit (branch ).hexsha
104
+ else :
105
+ found_remote_branch = False
106
+ for remote in local_repo .remotes :
107
+ remote .fetch () # Fetch latest updates from each remote
102
108
103
- # Check if the branch exists in this remote
104
- for ref in remote .refs :
105
- if (
106
- ref .remote_head == branch
107
- ): # Compare branch name without remote prefix
108
- commit_id = local_repo .commit (ref .name ).hexsha
109
- found_remote_branch = True
110
- break # Branch found, no need to keep checking this remote
111
- if found_remote_branch :
112
- break # Stop checking other remotes if branch is found
113
- if not found_remote_branch :
114
- raise Exception (f"Branch { branch } does not exist locally or remotely." )
115
- if "swe" in dataset_name .lower ():
109
+ # Check if the branch exists in this remote
110
+ for ref in remote .refs :
111
+ if (
112
+ ref .remote_head == branch
113
+ ): # Compare branch name without remote prefix
114
+ commit_id = local_repo .commit (ref .name ).hexsha
115
+ found_remote_branch = True
116
+ break # Branch found, no need to keep checking this remote
117
+ if found_remote_branch :
118
+ break # Stop checking other remotes if branch is found
119
+ if not found_remote_branch :
120
+ raise Exception (f"Branch { branch } does not exist locally or remotely." )
121
+ if dataset_type == "simple" :
122
+ if branch == "reference" :
123
+ patch = example ["prompt" ] + "\n \n " + example ["canonical_solution" ] + "\n \n " + example ["test" ]
124
+ else :
125
+ solution = open (test_ids ).read ()
126
+ pattern = r"```python\n(.*?)```"
127
+ matches = re .finditer (pattern , solution , re .DOTALL )
128
+ matches = [match .group (1 ).strip () for match in matches ]
129
+ if len (matches ) > 0 :
130
+ solution = "\n \n " .join (matches )
131
+ else :
132
+ solution = example ["prompt" ] + "\n \n " + solution
133
+ patch = solution + "\n \n " + example ["test" ]
134
+ patch = patch + "\n \n " + f"check({ example ['entry_point' ]} )"
135
+ elif "swe" in dataset_name .lower ():
116
136
if branch == "reference" :
117
137
patch = example ["test" ]["patch" ] + "\n \n " + example ["test" ]["test_patch" ]
118
138
else :
@@ -127,12 +147,15 @@ def main(
127
147
patch_file = Path (log_dir / "patch.diff" )
128
148
patch_file .write_text (patch , encoding = "utf-8" , errors = "ignore" )
129
149
130
- # make eval file
131
- if coverage :
132
- coverage_text = f" --cov={ example ['src_dir' ]} --cov-branch --cov-report json"
150
+ if dataset_type != "simple" :
151
+ # make eval file
152
+ if coverage :
153
+ coverage_text = f" --cov={ example ['src_dir' ]} --cov-branch --cov-report json"
154
+ else :
155
+ coverage_text = ""
156
+ eval_script = spec .eval_script .format (test_ids = test_ids , coverage = coverage_text )
133
157
else :
134
- coverage_text = ""
135
- eval_script = spec .eval_script .format (test_ids = test_ids , coverage = coverage_text )
158
+ eval_script = spec .eval_script
136
159
eval_file = Path (log_dir / "eval.sh" )
137
160
eval_file .write_text (eval_script )
138
161
0 commit comments