Skip to content

Commit 94c5284

Browse files
committed
extending to more datasets
1 parent 08617ff commit 94c5284

File tree

4 files changed

+17
-13
lines changed

4 files changed

+17
-13
lines changed

commit0/cli.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,18 +257,19 @@ def test(
257257
if reference:
258258
branch = "reference"
259259
else:
260-
if "humaneval" not in commit0_config["dataset_name"].split("/")[-1].lower():
260+
dataset_name = commit0_config["dataset_name"].lower()
261+
if "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name:
262+
branch = repo_or_repo_path
263+
else:
261264
if branch is None and not reference:
262265
git_path = os.path.join(
263266
commit0_config["base_dir"], repo_or_repo_path.split("/")[-1]
264267
)
265268
branch = get_active_branch(git_path)
266-
else:
267-
branch = test_ids
268269

269270
if stdin:
270271
# Read test names from stdin
271-
test_ids = sys.stdin.read().strip()
272+
test_ids = sys.stdin.read()
272273
elif test_ids is None:
273274
typer.echo("Error: test_ids must be provided or use --stdin option", err=True)
274275
raise typer.Exit(code=1)

commit0/harness/build.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@ def main(
2525
dataset_name, split=dataset_split
2626
) # type: ignore
2727
specs = []
28-
if "swe" in dataset_name.lower():
28+
dataset_name = dataset_name.lower()
29+
if "swe" in dataset_name:
2930
dataset_type = "swebench"
30-
elif "humaneval" in dataset_name.lower():
31+
elif "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name:
3132
dataset_type = "simple"
3233
else:
3334
dataset_type = "commit0"
3435
for example in dataset:
35-
if "swe" in dataset_name.lower() or dataset_type == "simple":
36+
if "swe" in dataset_name or dataset_type == "simple":
3637
if split != "all" and split not in example["instance_id"]:
3738
continue
3839
else:

commit0/harness/run_pytest_ids.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,18 @@ def main(
5151
dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset(
5252
dataset_name, split=dataset_split
5353
) # type: ignore
54+
dataset_name = dataset_name.lower()
5455
spec = None
5556
example = None
5657
repo_name = None
5758
dataset_type = None
5859
for example in dataset:
5960
if repo_or_repo_dir.endswith("/"):
6061
repo_or_repo_dir = repo_or_repo_dir[:-1]
61-
if "swe" in dataset_name.lower():
62+
if "swe" in dataset_name:
6263
repo_name = example["instance_id"]
6364
dataset_type = "swebench"
64-
elif "humaneval" in dataset_name.lower():
65+
elif "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name:
6566
repo_name = example["instance_id"]
6667
dataset_type = "simple"
6768
else:
@@ -130,7 +131,7 @@ def main(
130131
)
131132

132133
# make patch file
133-
if "swe" in dataset_name.lower():
134+
if "swe" in dataset_name:
134135
if branch == "reference":
135136
patch = (
136137
example["test"]["patch"] + "\n\n" + example["test"]["test_patch"]
@@ -164,7 +165,7 @@ def main(
164165
+ example["test"]
165166
)
166167
else:
167-
solution = open(test_ids).read()
168+
solution = test_ids
168169
prompt = example["prompt"] if "prompt" in example.keys() else ""
169170
matches = extract_code_blocks(solution)
170171
if len(matches) > 0:

commit0/harness/setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@ def main(
2323
base_dir: str,
2424
) -> None:
2525
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
26-
if "humaneval" in dataset_name.lower():
26+
dataset_name = dataset_name.lower()
27+
if "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name:
2728
return
2829
for example in dataset:
2930
repo_name = example["repo"].split("/")[-1]
3031
clone_url = f"https://github.com/{example['repo']}.git"
31-
if "swe" in dataset_name.lower():
32+
if "swe" in dataset_name:
3233
if repo_split != "all" and repo_split not in example["instance_id"]:
3334
continue
3435
clone_dir = os.path.abspath(os.path.join(base_dir, example["instance_id"]))

0 commit comments

Comments
 (0)