Skip to content

Commit

Permalink
WIP cmdline/slurm
Browse files Browse the repository at this point in the history
  • Loading branch information
skrawcz committed Jan 15, 2024
1 parent 93733f0 commit d3c08a5
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 0 deletions.
57 changes: 57 additions & 0 deletions examples/cmdline_orchestrator/cmdline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import functools
import subprocess

from hamilton.execution.executors import DefaultExecutionManager, TaskExecutor
from hamilton.execution.grouping import TaskImplementation


class CMDLineExecutionManager(DefaultExecutionManager):
    """Execution manager that routes ``cmdline``-tagged nodes to the remote executor."""

    def get_executor_for_task(self, task: TaskImplementation) -> TaskExecutor:
        """Pick the executor for a task that wraps exactly one node.

        :param task: Task to get executor for; it must contain exactly one node.
        :return: The remote executor when the node carries a "cmdline" tag,
            the local executor otherwise.
        :raises ValueError: If the task does not contain exactly one node.
        """
        if len(task.nodes) != 1:
            raise ValueError("Only single node tasks supported")
        sole_node = next(iter(task.nodes))
        # The tag name is hard coded for now.
        return self.remote_executor if "cmdline" in sole_node.tags else self.local_executor


import inspect


# Sentinel meaning "no return type could be derived from the annotation".
_NO_RETURN_TYPE = object()


def _generator_return_type(annotation):
    """Return the third element of a generator-style annotation, or the sentinel.

    Supports both the ad-hoc list/tuple form ``[yield, send, return]`` and
    subscripted generics such as ``typing.Generator[yield, send, return]``.
    """
    from typing import get_args

    if isinstance(annotation, (list, tuple)):
        parts = tuple(annotation)
    else:
        # Empty tuple for anything that is not a subscripted generic
        # (missing annotation, plain classes, string annotations, ...).
        parts = get_args(annotation)
    return parts[2] if len(parts) >= 3 else _NO_RETURN_TYPE


def cmdline_decorator(func):
    """Decorator to run the result of a function as a command line command.

    Two kinds of functions are supported:

    * A regular function returns the command. The command is run and the
      captured stdout (text) is returned. The exit status is not checked.
    * A generator function yields the command exactly once. The
      ``subprocess.CompletedProcess`` is sent back into the generator and
      whatever the generator then returns becomes the result.

    For generator functions the wrapper's ``return`` annotation is replaced by
    the generator's return type when it can be derived from the annotation
    (``[Y, S, R]`` or ``Generator[Y, S, R]``); otherwise it is left untouched.

    NOTE: the command is executed with ``shell=True``, so it must never be
    built from untrusted input.

    :param func: Function (or generator function) that produces the command.
    :return: Wrapped function that executes the command when called.
    :raises ValueError: (at call time) if a generator function yields no
        command or yields more than once.
    """
    # Decided once at decoration time; it cannot change between calls.
    is_generator = inspect.isgeneratorfunction(func)

    def run_command(cmd):
        # Single place where the shell is invoked, shared by both code paths.
        return subprocess.run(cmd, shell=True, capture_output=True, text=True)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_generator:
            # Plain function: its return value is the command; return stdout.
            return run_command(func(*args, **kwargs)).stdout

        gen = func(*args, **kwargs)
        try:
            # Runs the "preprocess" part of the generator up to its yield.
            cmd = next(gen)
        except StopIteration:
            # Do not leak a bare StopIteration out of a regular function.
            raise ValueError("Generator must yield the command to run.") from None

        result = run_command(cmd)
        try:
            # Resume the generator ("postprocess") with the completed process.
            gen.send(result)
        except StopIteration as stop:
            return stop.value
        # Reaching this point means the generator yielded a second time.
        gen.close()
        raise ValueError("Generator cannot have multiple yields.")

    if is_generator:
        # get the return type and set it as the return type of the wrapper
        return_type = _generator_return_type(inspect.signature(func).return_annotation)
        if return_type is not _NO_RETURN_TYPE:
            wrapper.__annotations__["return"] = return_type
    return wrapper
41 changes: 41 additions & 0 deletions examples/cmdline_orchestrator/funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import time
from subprocess import CompletedProcess

from cmdline import cmdline_decorator

from hamilton.function_modifiers import tag


@tag(cmdline="yes", cache="pickle")
@cmdline_decorator
def echo_1(start: str) -> str:
    """Build the echo command for step 1.

    Sleeps for two seconds, then returns a command string that echoes
    ``start`` prefixed with ``1: ``.
    """
    time.sleep(2)
    command = f'echo "1: {start}"'
    return command


@tag(cmdline="yes", cache="pickle")
@cmdline_decorator
def echo_2(echo_1: str) -> str:
    """Build the echo command for step 2.

    Sleeps for two seconds, then returns a command string that echoes the
    ``echo_1`` value prefixed with ``2: ``.
    """
    time.sleep(2)
    command = f'echo "2: {echo_1}"'
    return command


@tag(cmdline="yes", cache="pickle")
@cmdline_decorator
def echo_2b(echo_1: str) -> [str, CompletedProcess, str]:
    """Generator-style command: yield the command, then post-process its result.

    The code before the ``yield`` prepares the command; the value received
    from the ``yield`` is used as a completed process whose ``stdout`` gets
    ``!!!`` appended and returned.
    """
    # Stage 1: work done before the command is handed over.
    print("preprocess")
    time.sleep(2)
    finished = yield f'echo "2b: {echo_1}"'
    # Stage 2: work done once the completed process has been sent back in.
    print("postprocess")
    time.sleep(2)
    return finished.stdout + "!!!"


@tag(cmdline="yes", cache="pickle")
@cmdline_decorator
def echo_3(echo_2: str, echo_2b: str) -> str:
    """Build the echo command for step 3.

    Joins the two upstream values with ``:::`` and returns a command string
    that echoes the joined text prefixed with ``3: ``.
    """
    combined = echo_2 + ":::" + echo_2b
    return f'echo "3: {combined}"'
39 changes: 39 additions & 0 deletions examples/cmdline_orchestrator/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os

from hamilton.execution.executors import MultiThreadingExecutor, SynchronousLocalTaskExecutor
from hamilton.experimental.h_cache import CachingGraphAdapter

if __name__ == "__main__":
    import funcs
    from cmdline import CMDLineExecutionManager
    from dagworks import adapters

    from hamilton import driver

    # Tracker credentials are read from the environment; a missing variable
    # raises KeyError here, before anything else is built.
    tracker = adapters.DAGWorksTracker(
        username="[email protected]",
        api_key=os.environ["DAGWORKS_API_KEY"],
        project_id=os.environ["DAGWORKS_PROJECT_ID"],
        dag_name="toy-cmdline-dag",
        tags={"env": "local"},
    )

    # Assemble the driver one step at a time.
    builder = driver.Builder().enable_dynamic_execution(allow_experimental_mode=True)

    # First argument is the synchronous local executor, second a
    # multi-threading executor constructed with 5.
    local_executor = SynchronousLocalTaskExecutor()
    threaded_executor = MultiThreadingExecutor(5)
    builder = builder.with_execution_manager(
        CMDLineExecutionManager(local_executor, threaded_executor)
    )

    builder = builder.with_modules(funcs)

    # Adapters: the tracker plus a caching adapter pointed at ./cache.
    cache_adapter = CachingGraphAdapter("./cache")
    builder = builder.with_adapters(tracker, cache_adapter)

    dr = builder.build()

    dr.display_all_functions("graph.dot")
    available = dr.list_available_variables()
    print(available)

    # Only the final node is requested.
    outcome = dr.execute(["echo_3"], inputs={"start": "hello"})
    print(outcome)

0 comments on commit d3c08a5

Please sign in to comment.