Skip to content

Commit 3ed0d9a

Browse files
authored
Merge pull request #7 from commit-0/modal
Modal execution
2 parents 667c5ad + 44631f9 commit 3ed0d9a

File tree

2 files changed

+138
-22
lines changed

2 files changed

+138
-22
lines changed

commit0/harness/run_pytest_ids.py

Lines changed: 137 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import argparse
22
import docker
3+
from enum import StrEnum, auto
4+
import modal
5+
import os
36
import traceback
47
import yaml
58
from pathlib import Path
@@ -28,30 +31,12 @@
2831
)
2932

3033

31-
def main(repo: str, test_ids: list[str], timeout: int, branch_name: str) -> None:
32-
with open("config.yml", "r") as file:
33-
data = yaml.safe_load(file)
34-
spec = make_spec(data["repos"][repo])
35-
test_ids = " ".join(test_ids)
36-
hashed_test_ids = get_hash_string(test_ids)
34+
class ExecutionBackend(StrEnum):
35+
DOCKER = auto()
36+
MODAL = auto()
3737

38-
# set up logging
39-
log_dir = RUN_PYTEST_LOG_DIR / repo / hashed_test_ids
40-
log_dir.mkdir(parents=True, exist_ok=True)
41-
log_file = log_dir / "run_pytest.log"
42-
logger = setup_logger(repo, log_file)
43-
44-
# make eval file
45-
eval_script = spec.eval_script.format(
46-
local_repo=f"{data['base_repo_dir']}/{repo}",
47-
branch_name=branch_name,
48-
test_ids=test_ids,
49-
ip=get_ip(data["backend"]),
50-
user=get_user(),
51-
)
52-
eval_file = Path(log_dir / "eval.sh")
53-
eval_file.write_text(eval_script)
5438

39+
def run_docker(spec, logger, eval_file, timeout, log_dir):
5540
client = docker.from_env()
5641
container = None
5742
try:
@@ -115,6 +100,130 @@ def main(repo: str, test_ids: list[str], timeout: int, branch_name: str) -> None
115100
close_logger(logger)
116101

117102

103+
def run_modal(spec, logger, eval_file, timeout, log_dir):
104+
# get image name to pull from dockerhub
105+
# spec.repo_image_key
106+
reponame = spec.repo.split("/")[-1]
107+
image_name = f"wentingzhao/{reponame}"
108+
image = modal.Image.from_registry(image_name)
109+
110+
with modal.NetworkFileSystem.ephemeral() as nfs:
111+
# create sleepy sandbox
112+
sandbox = modal.Sandbox.create(
113+
"sleep",
114+
"infinity",
115+
image=image,
116+
network_file_systems={
117+
"/vol": nfs,
118+
},
119+
)
120+
121+
# get ssh pubkey
122+
process = sandbox.exec("bash", "-c", "cat /root/.ssh/id_rsa.pub")
123+
public_key = "".join([line for line in process.stdout]).strip()
124+
125+
# add to authorized keys locally. copy-pasted from utils
126+
local_authorized_keys_path = os.path.expanduser("~/.ssh/authorized_keys")
127+
os.makedirs(os.path.dirname(local_authorized_keys_path), exist_ok=True)
128+
if not os.path.exists(local_authorized_keys_path):
129+
# Since the file does not exist, create it
130+
open(local_authorized_keys_path, "a").close()
131+
write = True
132+
else:
133+
with open(local_authorized_keys_path, "r") as authorized_keys_file:
134+
content = authorized_keys_file.read()
135+
if public_key not in content:
136+
write = True
137+
else:
138+
write = False
139+
if write:
140+
with open(local_authorized_keys_path, "a") as authorized_keys_file:
141+
authorized_keys_file.write(public_key + "\n")
142+
143+
# copy eval file
144+
with open(eval_file, "rb") as f:
145+
nfs.write_file("eval.sh", f)
146+
sandbox.exec("bash", "-c", "cp /vol/eval.sh /eval.sh")
147+
148+
# DBG: check if eval file properly copied
149+
process = sandbox.exec("bash", "-c", "ls /")
150+
for line in process.stdout:
151+
print(line)
152+
# /DBG
153+
154+
# execute tests
155+
process = sandbox.exec("bash", "-c", "/bin/bash /eval.sh")
156+
output = []
157+
for line in process.stdout:
158+
output.append(line)
159+
output = "".join(line)
160+
logger.info(output)
161+
print(output)
162+
163+
output = []
164+
for line in process.stderr:
165+
output.append(line)
166+
output = "".join(line)
167+
logger.info(output)
168+
print(output)
169+
170+
timed_out = False
171+
total_runtime = 1
172+
173+
test_output = extract_test_output(
174+
output, "--json-report --json-report-file=report.json"
175+
)
176+
177+
# stdout might be more straightforward
178+
print(test_output)
179+
test_output_path = log_dir / "test_output.txt"
180+
with open(test_output_path, "w") as f:
181+
f.write(test_output)
182+
if timed_out:
183+
f.write(f"\n\nTimeout error: {timeout} seconds exceeded.")
184+
raise EvaluationError(
185+
repo,
186+
f"Test timed out after {timeout} seconds.",
187+
logger,
188+
)
189+
190+
191+
def main(
192+
repo: str,
193+
test_ids: list[str],
194+
timeout: int,
195+
branch_name: str,
196+
backend: ExecutionBackend,
197+
) -> None:
198+
with open("config.yml", "r") as file:
199+
data = yaml.safe_load(file)
200+
spec = make_spec(data["repos"][repo])
201+
test_ids = " ".join(test_ids)
202+
hashed_test_ids = get_hash_string(test_ids)
203+
204+
# set up logging
205+
log_dir = RUN_PYTEST_LOG_DIR / repo / hashed_test_ids
206+
log_dir.mkdir(parents=True, exist_ok=True)
207+
log_file = log_dir / "run_pytest.log"
208+
logger = setup_logger(repo, log_file)
209+
210+
# make eval file
211+
eval_script = spec.eval_script.format(
212+
local_repo=f"{data['base_repo_dir']}/{repo}",
213+
branch_name=branch_name,
214+
test_ids=test_ids,
215+
ip=get_ip(data["backend"]),
216+
user=get_user(),
217+
)
218+
eval_file = Path(log_dir / "eval.sh")
219+
eval_file.write_text(eval_script)
220+
221+
if ExecutionBackend(backend) == ExecutionBackend.DOCKER:
222+
run_docker(spec, logger, eval_file, timeout, log_dir)
223+
elif ExecutionBackend(backend) == ExecutionBackend.MODAL:
224+
run_modal(spec, logger, eval_file, timeout, log_dir)
225+
226+
118227
if __name__ == "__main__":
119228
parser = argparse.ArgumentParser()
120229
parser.add_argument("--repo", type=str, help="which repo to run unit tests")
@@ -130,5 +239,11 @@ def main(repo: str, test_ids: list[str], timeout: int, branch_name: str) -> None
130239
default=1_800,
131240
help="Timeout (in seconds) for running tests for each instance",
132241
)
242+
parser.add_argument(
243+
"--backend",
244+
choices=[backend.value for backend in ExecutionBackend],
245+
default=ExecutionBackend.DOCKER.value,
246+
help="Execution backend [docker, modal]",
247+
)
133248
args = parser.parse_args()
134249
main(**vars(args))

commit0/harness/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import socket
66
import os
7+
import requests
78
from commit0.harness.constants import EVAL_BACKENDS
89

910

0 commit comments

Comments
 (0)