From b46c0cab31398b66b7f5fb20f5fd75a59f5019e4 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Tue, 1 Oct 2024 11:56:33 -0700 Subject: [PATCH] Update execute.py --- bigcode_eval/tasks/custom_metrics/execute.py | 34 +++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/bigcode_eval/tasks/custom_metrics/execute.py b/bigcode_eval/tasks/custom_metrics/execute.py index 53517a805..be90cd240 100644 --- a/bigcode_eval/tasks/custom_metrics/execute.py +++ b/bigcode_eval/tasks/custom_metrics/execute.py @@ -23,7 +23,18 @@ import platform import signal import tempfile +from uuid import uuid4 +_STORE = None + +def _get_store(): + """convinience method, to fetches same distributed dict and reuses same manager""" + global _STORE + if _STORE is None: + manager = multiprocessing.Manager() + store = manager.dict() + _STORE = store + return _STORE def check_correctness(check_program, timeout, task_id, completion_id): """ @@ -33,27 +44,25 @@ def check_correctness(check_program, timeout, task_id, completion_id): :param completion_id: an optional completion ID so we can match the results later even if execution finishes asynchronously. """ - manager = multiprocessing.Manager() - result = manager.list() - - p = multiprocessing.Process(target=unsafe_execute, args=(check_program, result, timeout)) + store = _get_store() + key = uuid4().hex + p = multiprocessing.Process(target=unsafe_execute, args=(check_program, store, key, timeout)) p.start() p.join(timeout=timeout + 1) if p.is_alive(): p.kill() - if not result: - result.append("timed out") + result = store.pop(key, "timed out") return dict( task_id=task_id, - passed=result[0] == "passed", - result=result[0], + passed=result == "passed", + result=result, completion_id=completion_id, ) -def unsafe_execute(check_program, result, timeout): +def unsafe_execute(check_program, store, key: str, timeout: int): with create_tempdir(): @@ -74,11 +83,12 @@ def unsafe_execute(check_program, result, timeout): with swallow_io(): with time_limit(timeout): exec(check_program, exec_globals) - result.append("passed") + store[key] = "passed" except TimeoutException: - result.append("timed out") + store[key] = f"timed out" except BaseException as e: - result.append(f"failed: {e}") + store[key] = f"failed: {e}" + # Needed for cleaning up. shutil.rmtree = rmtree