|
1 | 1 | import argparse
|
| 2 | +from datasets import load_dataset |
2 | 3 | import docker
|
3 | 4 | from enum import StrEnum, auto
|
4 | 5 | import os
|
|
7 | 8 | from pathlib import Path
|
8 | 9 | import logging
|
9 | 10 |
|
| 11 | +from omegaconf import DictConfig, OmegaConf |
| 12 | +import hydra |
| 13 | + |
10 | 14 | from commit0.harness.constants import RUN_PYTEST_LOG_DIR
|
11 | 15 | from commit0.harness.docker_build import (
|
12 | 16 | close_logger,
|
|
32 | 36 |
|
33 | 37 |
|
34 | 38 | class ExecutionBackend(StrEnum):
|
35 |
| - DOCKER = auto() |
| 39 | + LOCAL = auto() |
36 | 40 | MODAL = auto()
|
37 | 41 |
|
38 | 42 |
|
@@ -194,72 +198,41 @@ def run_modal(
|
194 | 198 | )
|
195 | 199 |
|
196 | 200 |
|
197 |
| -def main( |
198 |
| - repo: str, |
199 |
| - test_ids_ls: list[str], |
200 |
| - timeout: int, |
201 |
| - branch_name: str, |
202 |
| - backend: ExecutionBackend, |
203 |
| -) -> None: |
204 |
| - with open("config.yml", "r") as file: |
205 |
| - data = yaml.safe_load(file) |
206 |
| - spec = make_spec(data["repos"][repo]) |
207 |
| - test_ids = " ".join(test_ids_ls) |
208 |
| - hashed_test_ids = get_hash_string(test_ids) |
| 201 | +@hydra.main(version_base=None, config_path="configs", config_name="base") |
| 202 | +def main(config: DictConfig) -> None: |
| 203 | + OmegaConf.to_yaml(config) |
| 204 | + dataset = load_dataset(config.dataset_name, split="test") |
| 205 | + for example in dataset: |
| 206 | + if example["repo"].endswith(config.repo): |
| 207 | + spec = make_spec(example) |
| 208 | + break |
209 | 209 |
|
| 210 | + hashed_test_ids = get_hash_string(config.test_ids) |
210 | 211 | # set up logging
|
211 |
| - log_dir = RUN_PYTEST_LOG_DIR / repo / hashed_test_ids |
| 212 | + log_dir = RUN_PYTEST_LOG_DIR / config.repo / hashed_test_ids |
212 | 213 | log_dir.mkdir(parents=True, exist_ok=True)
|
213 | 214 | log_file = log_dir / "run_pytest.log"
|
214 |
| - logger = setup_logger(repo, log_file) |
| 215 | + logger = setup_logger(config.repo, log_file) |
215 | 216 |
|
216 | 217 | # make eval file
|
217 | 218 | eval_script = spec.eval_script.format(
|
218 |
| - local_repo=f"{data['base_repo_dir']}/{repo}", |
219 |
| - branch_name=branch_name, |
220 |
| - test_ids=test_ids, |
221 |
| - ip=get_ip(data["backend"]), |
| 219 | + local_repo=f"{config.base_dir}/{config.repo}", |
| 220 | + branch_name=config.branch, |
| 221 | + test_ids=config.test_ids, |
| 222 | + ip=get_ip(config.backend), |
222 | 223 | user=get_user(),
|
223 | 224 | )
|
224 | 225 | eval_file = Path(log_dir / "eval.sh")
|
225 | 226 | eval_file.write_text(eval_script)
|
226 | 227 |
|
227 |
| - if ExecutionBackend(backend) == ExecutionBackend.DOCKER: |
228 |
| - run_docker(spec, logger, eval_file, timeout, log_dir) |
229 |
| - elif ExecutionBackend(backend) == ExecutionBackend.MODAL: |
230 |
| - run_modal(spec, logger, eval_file, timeout, log_dir) |
231 |
| - |
232 |
| - |
233 |
| -def add_init_args(parser: argparse.ArgumentParser) -> None: |
234 |
| - parser.add_argument("--repo", type=str, help="which repo to run unit tests") |
235 |
| - parser.add_argument( |
236 |
| - "--test_ids", type=str, nargs="+", help="which test ids / files / directories" |
237 |
| - ) |
238 |
| - parser.add_argument( |
239 |
| - "--branch_name", type=str, help="which git branch to run unit tests" |
240 |
| - ) |
241 |
| - parser.add_argument( |
242 |
| - "--timeout", |
243 |
| - type=int, |
244 |
| - default=1_800, |
245 |
| - help="Timeout (in seconds) for running tests for each instance", |
246 |
| - ) |
247 |
| - parser.add_argument( |
248 |
| - "--backend", |
249 |
| - choices=[backend.value for backend in ExecutionBackend], |
250 |
| - default=ExecutionBackend.DOCKER.value, |
251 |
| - help="Execution backend [docker, modal]", |
252 |
| - ) |
| 228 | + if ExecutionBackend(config.backend) == ExecutionBackend.LOCAL: |
| 229 | + run_docker(spec, logger, eval_file, config.timeout, log_dir) |
| 230 | + elif ExecutionBackend(config.backend) == ExecutionBackend.MODAL: |
| 231 | + run_modal(spec, logger, eval_file, config.timeout, log_dir) |
253 | 232 |
|
254 | 233 |
|
255 | 234 | def run(args: argparse.Namespace) -> None:
|
256 |
| - main( |
257 |
| - repo=args.repo, |
258 |
| - test_ids_ls=args.test_ids, |
259 |
| - timeout=args.timeout, |
260 |
| - branch_name=args.branch_name, |
261 |
| - backend=args.backend, |
262 |
| - ) |
| 235 | + main() |
263 | 236 |
|
264 | 237 |
|
265 | 238 | __all__ = []
|
0 commit comments