Skip to content

Commit 38205e6

Browse files
committed
added humaneval
1 parent f6ea9ed commit 38205e6

File tree

7 files changed

+136
-65
lines changed

7 files changed

+136
-65
lines changed

commit0/cli.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def setup(
118118
) -> None:
119119
"""Commit0 clone a repo split."""
120120
check_commit0_path()
121-
if "commit0" in dataset_name.lower():
121+
if "commit0" in dataset_name.split('/')[-1].lower():
122122
check_valid(repo_split, SPLIT)
123123

124124
base_dir = str(Path(base_dir).resolve())
@@ -169,7 +169,7 @@ def build(
169169
check_commit0_path()
170170

171171
commit0_config = read_commit0_config_file(commit0_config_file)
172-
if "commit0" in commit0_config["dataset_name"].lower():
172+
if "commit0" in commit0_config["dataset_name"].split('/')[-1].lower():
173173
check_valid(commit0_config["repo_split"], SPLIT)
174174

175175
typer.echo(
@@ -251,16 +251,20 @@ def test(
251251
commit0_config = read_commit0_config_file(commit0_config_file)
252252
if repo_or_repo_path.endswith("/"):
253253
repo_or_repo_path = repo_or_repo_path[:-1]
254-
if "commit0" in commit0_config["dataset_name"].lower():
254+
if "commit0" in commit0_config["dataset_name"].split('/')[-1].lower():
255255
check_valid(repo_or_repo_path.split("/")[-1], SPLIT)
256256

257257
if reference:
258258
branch = "reference"
259-
if branch is None and not reference:
260-
git_path = os.path.join(
261-
commit0_config["base_dir"], repo_or_repo_path.split("/")[-1]
262-
)
263-
branch = get_active_branch(git_path)
259+
else:
260+
if "humaneval" not in commit0_config["dataset_name"].split('/')[-1].lower():
261+
if branch is None and not reference:
262+
git_path = os.path.join(
263+
commit0_config["base_dir"], repo_or_repo_path.split("/")[-1]
264+
)
265+
branch = get_active_branch(git_path)
266+
else:
267+
branch = test_ids
264268

265269
if stdin:
266270
# Read test names from stdin
@@ -317,7 +321,7 @@ def evaluate(
317321
branch = "reference"
318322

319323
commit0_config = read_commit0_config_file(commit0_config_file)
320-
if "commit0" in commit0_config["dataset_name"].lower():
324+
if "commit0" in commit0_config["dataset_name"].split('/')[-1].lower():
321325
check_valid(commit0_config["repo_split"], SPLIT)
322326

323327
typer.echo(f"Evaluating repository split: {commit0_config['repo_split']}")
@@ -393,7 +397,7 @@ def save(
393397
"""Save Commit0 split you choose in Setup Stage to GitHub."""
394398
check_commit0_path()
395399
commit0_config = read_commit0_config_file(commit0_config_file)
396-
if "commit0" in commit0_config["dataset_name"].lower():
400+
if "commit0" in commit0_config["dataset_name"].split('/')[-1].lower():
397401
check_valid(commit0_config["repo_split"], SPLIT)
398402

399403
typer.echo(f"Saving repository split: {commit0_config['repo_split']}")

commit0/harness/build.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from datasets import load_dataset
55
from typing import Iterator
66

7-
from commit0.harness.constants import RepoInstance, SPLIT
7+
from commit0.harness.constants import RepoInstance, SimpleInstance, SPLIT
88
from commit0.harness.docker_build import build_repo_images
99
from commit0.harness.spec import make_spec
1010

@@ -17,23 +17,25 @@
1717
def main(
1818
dataset_name: str,
1919
dataset_split: str,
20-
repo_split: str,
20+
split: str,
2121
num_workers: int,
2222
verbose: int,
2323
) -> None:
24-
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
24+
dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset(dataset_name, split=dataset_split) # type: ignore
2525
specs = []
2626
if "swe" in dataset_name.lower():
2727
dataset_type = "swebench"
28+
elif "humaneval" in dataset_name.lower():
29+
dataset_type = "simple"
2830
else:
2931
dataset_type = "commit0"
3032
for example in dataset:
31-
repo_name = example["repo"].split("/")[-1]
32-
if "swe" in dataset_name.lower():
33-
if repo_split != "all" and repo_split not in example["instance_id"]:
33+
if "swe" in dataset_name.lower() or dataset_type == "simple":
34+
if split != "all" and split not in example["instance_id"]:
3435
continue
3536
else:
36-
if repo_split != "all" and repo_name not in SPLIT[repo_split]:
37+
repo_name = example["repo"].split("/")[-1]
38+
if split != "all" and repo_name not in SPLIT[split]:
3739
continue
3840
spec = make_spec(example, dataset_type)
3941
specs.append(spec)

commit0/harness/constants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ class RepoInstance(TypedDict):
1313
src_dir: str
1414

1515

16+
class SimpleInstance(TypedDict):
17+
instance_id: str
18+
prompt: str
19+
canonical_solution: str
20+
test: str
21+
entry_point: str
22+
23+
1624
class Files(TypedDict):
1725
eval_script: Dict[str, Path]
1826
patch: Dict[str, Path]

commit0/harness/execution_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def __init__(
102102
self.client = docker.from_env()
103103
self.container = create_container(
104104
client=self.client,
105-
image_name=spec.repo_image_tag,
105+
image_name=spec.repo_image_key,
106106
container_name=spec.get_container_name(),
107107
nano_cpus=num_cpus,
108108
logger=logger,

commit0/harness/run_pytest_ids.py

Lines changed: 67 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import git
22
import os
3+
import re
34
import sys
45
import traceback
56
from datasets import load_dataset
@@ -11,6 +12,7 @@
1112
Files,
1213
RUN_PYTEST_LOG_DIR,
1314
RepoInstance,
15+
SimpleInstance,
1416
)
1517
from commit0.harness.spec import make_spec
1618
from commit0.harness.utils import (
@@ -46,7 +48,7 @@ def main(
4648
Tests are run either locally through docker
4749
or remotely through Modal.
4850
"""
49-
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
51+
dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset(dataset_name, split=dataset_split) # type: ignore
5052
spec = None
5153
example = None
5254
repo_name = None
@@ -56,10 +58,13 @@ def main(
5658
if "swe" in dataset_name.lower():
5759
repo_name = example["instance_id"]
5860
dataset_type = "swebench"
61+
elif "humaneval" in dataset_name.lower():
62+
repo_name = example["instance_id"]
63+
dataset_type = "simple"
5964
else:
6065
repo_name = example["repo"].split("/")[-1]
6166
dataset_type = "commit0"
62-
if repo_name in os.path.basename(repo_or_repo_dir):
67+
if repo_name in os.path.basename(repo_or_repo_dir) or repo_or_repo_dir.endswith(repo_name):
6368
spec = make_spec(example, dataset_type)
6469
break
6570
assert spec is not None, "No spec available"
@@ -73,46 +78,61 @@ def main(
7378
log_file = log_dir / "run_pytest.log"
7479
logger = setup_logger(repo_name, log_file, verbose=verbose)
7580

76-
try:
77-
local_repo = git.Repo(repo_or_repo_dir)
78-
logger.info(f"Loaded a git repo from {repo_or_repo_dir}")
79-
except (git.exc.NoSuchPathError, git.exc.InvalidGitRepositoryError): # type: ignore
80-
repo_dir = os.path.join(base_dir, repo_name)
81-
logger.error(f"{repo_or_repo_dir} is not a git dir, trying {repo_dir} again")
81+
if dataset_type != "simple": # if dataset_type is not simple, load git repo
8282
try:
83-
local_repo = git.Repo(repo_dir)
84-
logger.info(f"Retried succeeded. Loaded a git repo from {repo_dir}")
85-
except git.exc.NoSuchPathError: # type: ignore
86-
raise Exception(
87-
f"{repo_dir} and {repo_or_repo_dir} are not git directories.\nUsage: commit0 test {{repo_dir}} {{branch}} {{test_ids}}"
88-
)
89-
except Exception as e:
90-
raise e
91-
commit_id = ""
92-
if branch == "reference":
93-
commit_id = example["reference_commit"]
94-
else:
95-
# Check if it's a local branch
96-
if branch in local_repo.branches:
97-
commit_id = local_repo.commit(branch).hexsha
83+
local_repo = git.Repo(repo_or_repo_dir)
84+
logger.info(f"Loaded a git repo from {repo_or_repo_dir}")
85+
except (git.exc.NoSuchPathError, git.exc.InvalidGitRepositoryError): # type: ignore
86+
repo_dir = os.path.join(base_dir, repo_name)
87+
logger.error(f"{repo_or_repo_dir} is not a git dir, trying {repo_dir} again")
88+
try:
89+
local_repo = git.Repo(repo_dir)
90+
logger.info(f"Retried succeeded. Loaded a git repo from {repo_dir}")
91+
except git.exc.NoSuchPathError: # type: ignore
92+
raise Exception(
93+
f"{repo_dir} and {repo_or_repo_dir} are not git directories.\nUsage: commit0 test {{repo_dir}} {{branch}} {{test_ids}}"
94+
)
95+
except Exception as e:
96+
raise e
97+
commit_id = ""
98+
if branch == "reference":
99+
commit_id = example["reference_commit"]
98100
else:
99-
found_remote_branch = False
100-
for remote in local_repo.remotes:
101-
remote.fetch() # Fetch latest updates from each remote
101+
# Check if it's a local branch
102+
if branch in local_repo.branches:
103+
commit_id = local_repo.commit(branch).hexsha
104+
else:
105+
found_remote_branch = False
106+
for remote in local_repo.remotes:
107+
remote.fetch() # Fetch latest updates from each remote
102108

103-
# Check if the branch exists in this remote
104-
for ref in remote.refs:
105-
if (
106-
ref.remote_head == branch
107-
): # Compare branch name without remote prefix
108-
commit_id = local_repo.commit(ref.name).hexsha
109-
found_remote_branch = True
110-
break # Branch found, no need to keep checking this remote
111-
if found_remote_branch:
112-
break # Stop checking other remotes if branch is found
113-
if not found_remote_branch:
114-
raise Exception(f"Branch {branch} does not exist locally or remotely.")
115-
if "swe" in dataset_name.lower():
109+
# Check if the branch exists in this remote
110+
for ref in remote.refs:
111+
if (
112+
ref.remote_head == branch
113+
): # Compare branch name without remote prefix
114+
commit_id = local_repo.commit(ref.name).hexsha
115+
found_remote_branch = True
116+
break # Branch found, no need to keep checking this remote
117+
if found_remote_branch:
118+
break # Stop checking other remotes if branch is found
119+
if not found_remote_branch:
120+
raise Exception(f"Branch {branch} does not exist locally or remotely.")
121+
if dataset_type == "simple":
122+
if branch == "reference":
123+
patch = example["prompt"] + "\n\n" + example["canonical_solution"] + "\n\n" + example["test"]
124+
else:
125+
solution = open(test_ids).read()
126+
pattern = r"```python\n(.*?)```"
127+
matches = re.finditer(pattern, solution, re.DOTALL)
128+
matches = [match.group(1).strip() for match in matches]
129+
if len(matches) > 0:
130+
solution = "\n\n".join(matches)
131+
else:
132+
solution = example["prompt"] + "\n\n" + solution
133+
patch = solution + "\n\n" + example["test"]
134+
patch = patch + "\n\n" + f"check({example['entry_point']})"
135+
elif "swe" in dataset_name.lower():
116136
if branch == "reference":
117137
patch = example["test"]["patch"] + "\n\n" + example["test"]["test_patch"]
118138
else:
@@ -127,12 +147,15 @@ def main(
127147
patch_file = Path(log_dir / "patch.diff")
128148
patch_file.write_text(patch, encoding="utf-8", errors="ignore")
129149

130-
# make eval file
131-
if coverage:
132-
coverage_text = f" --cov={example['src_dir']} --cov-branch --cov-report json"
150+
if dataset_type != "simple":
151+
# make eval file
152+
if coverage:
153+
coverage_text = f" --cov={example['src_dir']} --cov-branch --cov-report json"
154+
else:
155+
coverage_text = ""
156+
eval_script = spec.eval_script.format(test_ids=test_ids, coverage=coverage_text)
133157
else:
134-
coverage_text = ""
135-
eval_script = spec.eval_script.format(test_ids=test_ids, coverage=coverage_text)
158+
eval_script = spec.eval_script
136159
eval_file = Path(log_dir / "eval.sh")
137160
eval_file.write_text(eval_script)
138161

commit0/harness/setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def main(
2323
base_dir: str,
2424
) -> None:
2525
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
26+
if "humaneval" in dataset_name.lower():
27+
return
2628
for example in dataset:
2729
repo_name = example["repo"].split("/")[-1]
2830
clone_url = f"https://github.com/{example['repo']}.git"

commit0/harness/spec.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from commit0.harness.constants import (
77
RepoInstance,
8+
SimpleInstance,
89
)
910
from commit0.harness.dockerfiles import (
1011
get_dockerfile_base,
@@ -19,7 +20,7 @@ class Spec(ABC):
1920
repo: str
2021
# repo dir on docker
2122
repo_directory: str
22-
instance: RepoInstance
23+
instance: Union[RepoInstance, SimpleInstance]
2324

2425
@property
2526
def setup_script(self) -> str:
@@ -175,6 +176,31 @@ def make_eval_script_list(self) -> list[str]:
175176
return eval_script_list
176177

177178

179+
class SimpleSpec(Spec):
180+
def make_repo_script_list(self) -> list[str]:
181+
"""Create a list of bash commands to set up the repository for testing.
182+
This is the setup script for the instance image.
183+
"""
184+
setup_commands = [
185+
f"mkdir {self.repo_directory} && cd {self.repo_directory}",
186+
f"uv venv --python 3.12",
187+
"source .venv/bin/activate",
188+
"which python",
189+
]
190+
return setup_commands
191+
192+
def make_eval_script_list(self) -> list[str]:
193+
"""Run the tests."""
194+
eval_script_list = [
195+
f"cd {self.repo_directory}",
196+
"source .venv/bin/activate",
197+
"cat /patch.diff > test.py",
198+
"uv run test.py > test_output.txt 2>&1",
199+
"echo $? > pytest_exit_code.txt",
200+
]
201+
return eval_script_list
202+
203+
178204
class SWEBenchSpec(Spec):
179205
def make_repo_script_list(self) -> list[str]:
180206
"""Create a list of bash commands to set up the repository for testing.
@@ -277,7 +303,7 @@ def make_eval_script_list(self) -> list[str]:
277303

278304

279305
def get_specs_from_dataset(
280-
dataset: Union[list[RepoInstance], list[Spec]], dataset_type: str
306+
dataset: Union[list[Union[RepoInstance, SimpleInstance]], list[Spec]], dataset_type: str
281307
) -> list[Spec]:
282308
"""Idempotent function that converts a list of RepoInstance objects to a list of Spec objects."""
283309
if isinstance(dataset[0], Spec):
@@ -290,7 +316,7 @@ def get_specs_from_dataset(
290316
)
291317

292318

293-
def make_spec(instance: RepoInstance, dataset_type: str) -> Spec:
319+
def make_spec(instance: Union[RepoInstance, SimpleInstance], dataset_type: str) -> Spec:
294320
if isinstance(instance, Spec):
295321
return instance
296322
repo_directory = "/testbed"
@@ -306,6 +332,12 @@ def make_spec(instance: RepoInstance, dataset_type: str) -> Spec:
306332
repo_directory=repo_directory,
307333
instance=instance,
308334
)
335+
elif dataset_type == "simple":
336+
return SimpleSpec(
337+
repo="simple", # all benchmarks with mere function writing will share the simple docker image
338+
repo_directory=repo_directory,
339+
instance=instance,
340+
)
309341
else:
310342
raise NotImplementedError(
311343
f"{dataset_type} is not supported.\nWe only support commit0 and swebench instances for now."

0 commit comments

Comments
 (0)