diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index 9cfa34a3..3461909a 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -26,7 +26,7 @@ jobs:
           virtualenvs-in-project: true

       - name: Install linting tools
-        run: poetry install --no-root --only lint
+        run: poetry install --no-root

       - name: Run Black
         run: poetry run black . --check --verbose --diff --color
diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py
index 3524502a..746a346c 100644
--- a/hatchet_sdk/clients/admin.py
+++ b/hatchet_sdk/clients/admin.py
@@ -54,16 +54,10 @@ class ChildTriggerWorkflowOptions(TypedDict):
     sticky: bool | None = None


-class WorkflowRunDict(TypedDict):
-    workflow_name: str
-    input: Any
-    options: Optional[dict]
-
-
 class ChildWorkflowRunDict(TypedDict):
     workflow_name: str
     input: Any
-    options: ChildTriggerWorkflowOptions[dict]
+    options: ChildTriggerWorkflowOptions
     key: str | None = None


@@ -73,6 +67,12 @@ class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, TypedDict):
     namespace: str | None = None


+class WorkflowRunDict(TypedDict):
+    workflow_name: str
+    input: Any
+    options: TriggerWorkflowOptions | None
+
+
 class DedupeViolationErr(Exception):
     """Raised by the Hatchet library to indicate that a workflow has already been run with this deduplication value."""

@@ -260,7 +260,9 @@ async def run_workflow(

     @tenacity_retry
     async def run_workflows(
-        self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None
+        self,
+        workflows: list[WorkflowRunDict],
+        options: TriggerWorkflowOptions | None = None,
     ) -> List[WorkflowRunRef]:
         if len(workflows) == 0:
             raise ValueError("No workflows to run")
diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py
index cd1c195b..fc2887bd 100644
--- a/hatchet_sdk/clients/dispatcher/action_listener.py
+++ b/hatchet_sdk/clients/dispatcher/action_listener.py
@@ -61,7 +61,7 @@ class Action:
     worker_id: str
     tenant_id: str
     workflow_run_id: str
-    get_group_key_run_id: Optional[str]
+    get_group_key_run_id: str
     job_id: str
     job_name: str
     job_run_id: str
diff --git a/hatchet_sdk/clients/dispatcher/dispatcher.py b/hatchet_sdk/clients/dispatcher/dispatcher.py
index b92addab..6578528b 100644
--- a/hatchet_sdk/clients/dispatcher/dispatcher.py
+++ b/hatchet_sdk/clients/dispatcher/dispatcher.py
@@ -137,14 +137,14 @@ def put_overrides_data(self, data: OverridesData):

         return response

-    def release_slot(self, step_run_id: str):
+    def release_slot(self, step_run_id: str) -> None:
         self.client.ReleaseSlot(
             ReleaseSlotRequest(stepRunId=step_run_id),
             timeout=DEFAULT_REGISTER_TIMEOUT,
             metadata=get_metadata(self.token),
         )

-    def refresh_timeout(self, step_run_id: str, increment_by: str):
+    def refresh_timeout(self, step_run_id: str, increment_by: str) -> None:
         self.client.RefreshTimeout(
             RefreshTimeoutRequest(
                 stepRunId=step_run_id,
diff --git a/hatchet_sdk/clients/rest/tenacity_utils.py b/hatchet_sdk/clients/rest/tenacity_utils.py
index 42f51394..55fe6d18 100644
--- a/hatchet_sdk/clients/rest/tenacity_utils.py
+++ b/hatchet_sdk/clients/rest/tenacity_utils.py
@@ -1,10 +1,15 @@
+from typing import Callable, ParamSpec, TypeVar
+
 import grpc
 import tenacity

 from hatchet_sdk.logger import logger

+P = ParamSpec("P")
+R = TypeVar("R")
+

-def tenacity_retry(func):
+def tenacity_retry(func: Callable[P, R]) -> Callable[P, R]:
     return tenacity.retry(
         reraise=True,
         wait=tenacity.wait_exponential_jitter(),
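Side note on the `tenacity_utils.py` change above: this is the standard `ParamSpec` pattern for typing a decorator so that the wrapped function's signature survives. A minimal standalone sketch of the same pattern (`log_calls` and `add` are illustrative names, not part of the SDK):

```python
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def log_calls(func: Callable[P, R]) -> Callable[P, R]:
    # P.args/P.kwargs tie the wrapper to func's exact parameters, so mypy
    # still checks call sites against the original signature.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper

@log_calls
def add(a: int, b: int) -> int:
    return a + b

add(1, 2)      # OK
# add("1", 2)  # mypy: incompatible type "str"
```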
diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py
index 5f6c8726..e7ccb36b 100644
--- a/hatchet_sdk/context/context.py
+++ b/hatchet_sdk/context/context.py
@@ -2,7 +2,9 @@
 import json
 import traceback
 from concurrent.futures import Future, ThreadPoolExecutor
-from typing import List
+from typing import Any, cast
+
+from pydantic import StrictStr

 from hatchet_sdk.clients.events import EventClient
 from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
@@ -10,8 +12,8 @@
 from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
 from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
 from hatchet_sdk.context.worker_context import WorkerContext
-from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData
-from hatchet_sdk.contracts.workflows_pb2 import (
+from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData  # type: ignore
+from hatchet_sdk.contracts.workflows_pb2 import (  # type: ignore[attr-defined]
     BulkTriggerWorkflowRequest,
     TriggerWorkflowRequest,
 )
@@ -24,25 +26,32 @@
     TriggerWorkflowOptions,
     WorkflowRunDict,
 )
-from ..clients.dispatcher.dispatcher import Action, DispatcherClient
+from ..clients.dispatcher.dispatcher import (  # type: ignore[attr-defined]
+    Action,
+    DispatcherClient,
+)
 from ..logger import logger

 DEFAULT_WORKFLOW_POLLING_INTERVAL = 5  # Seconds


-def get_caller_file_path():
+def get_caller_file_path() -> str:
     caller_frame = inspect.stack()[2]

     return caller_frame.filename


 class BaseContext:
+
+    action: Action
+    spawn_index: int
+
     def _prepare_workflow_options(
         self,
-        key: str = None,
+        key: str | None = None,
         options: ChildTriggerWorkflowOptions | None = None,
-        worker_id: str = None,
-    ):
+        worker_id: str | None = None,
+    ) -> TriggerWorkflowOptions:
         workflow_run_id = self.action.workflow_run_id
         step_run_id = self.action.step_run_id

@@ -54,7 +63,7 @@ def _prepare_workflow_options(
         if options is not None and "additional_metadata" in options:
             meta = options["additional_metadata"]

-        trigger_options: TriggerWorkflowOptions = {
+        trigger_options: TriggerWorkflowOptions = {  # type: ignore[typeddict-item]
             "parent_id": workflow_run_id,
             "parent_step_run_id": step_run_id,
             "child_key": key,
@@ -95,9 +104,9 @@ def __init__(
     async def spawn_workflow(
         self,
         workflow_name: str,
-        input: dict = {},
-        key: str = None,
-        options: ChildTriggerWorkflowOptions = None,
+        input: dict[str, Any] = {},
+        key: str | None = None,
+        options: ChildTriggerWorkflowOptions | None = None,
     ) -> WorkflowRunRef:
         worker_id = self.worker.id()
         # if (
@@ -118,15 +127,15 @@

     @tenacity_retry
     async def spawn_workflows(
-        self, child_workflow_runs: List[ChildWorkflowRunDict]
-    ) -> List[WorkflowRunRef]:
+        self, child_workflow_runs: list[ChildWorkflowRunDict]
+    ) -> list[WorkflowRunRef]:

         if len(child_workflow_runs) == 0:
             raise Exception("no child workflows to spawn")

         worker_id = self.worker.id()

-        bulk_trigger_workflow_runs: WorkflowRunDict = []
+        bulk_trigger_workflow_runs: list[WorkflowRunDict] = []
         for child_workflow_run in child_workflow_runs:
             workflow_name = child_workflow_run["workflow_name"]
             input = child_workflow_run["input"]
@@ -134,7 +143,8 @@ async def spawn_workflows(
             key = child_workflow_run.get("key")
             options = child_workflow_run.get("options", {})

-            trigger_options = self._prepare_workflow_options(key, options, worker_id)
+            ## TODO: figure out why this is failing
+            trigger_options = self._prepare_workflow_options(key, options, worker_id)  # type: ignore[arg-type]

             bulk_trigger_workflow_runs.append(
                 WorkflowRunDict(
@@ -179,11 +189,11 @@ def __init__(
         # Check the type of action.action_payload before attempting to load it as JSON
         if isinstance(action.action_payload, (str, bytes, bytearray)):
             try:
-                self.data = json.loads(action.action_payload)
+                self.data = cast(dict[str, Any], json.loads(action.action_payload))
             except Exception as e:
                 logger.error(f"Error parsing action payload: {e}")
                 # Assign an empty dictionary if parsing fails
-                self.data = {}
+                self.data: dict[str, Any] = {}  # type: ignore[no-redef]
         else:
             # Directly assign the payload to self.data if it's already a dict
             self.data = (
@@ -218,33 +228,34 @@ def __init__(
         else:
             self.input = self.data.get("input", {})

-    def step_output(self, step: str):
+    def step_output(self, step: str) -> dict[str, Any]:
         try:
-            return self.data["parents"][step]
+            return cast(dict[str, Any], self.data["parents"][step])
         except KeyError:
             raise ValueError(f"Step output for '{step}' not found")

     def triggered_by_event(self) -> bool:
-        return self.data.get("triggered_by", "") == "event"
+        return cast(str, self.data.get("triggered_by", "")) == "event"

-    def workflow_input(self):
+    def workflow_input(self) -> dict[str, Any]:
         return self.input

-    def workflow_run_id(self):
+    def workflow_run_id(self) -> str:
         return self.action.workflow_run_id

-    def cancel(self):
+    def cancel(self) -> None:
         logger.debug("cancelling step...")
         self.exit_flag = True

     # done returns true if the context has been cancelled
-    def done(self):
+    def done(self) -> bool:
         return self.exit_flag

-    def playground(self, name: str, default: str = None):
+    def playground(self, name: str, default: str | None = None) -> str | None:
         # if the key exists in the overrides_data field, return the value
         if name in self.overrides_data:
-            return self.overrides_data[name]
+            ## TODO: Check if this is the right type
+            return cast(str, self.overrides_data[name])

         caller_file = get_caller_file_path()

@@ -259,7 +270,7 @@ def playground(self, name: str, default: str = None):

         return default

-    def _log(self, line: str) -> (bool, Exception):  # type: ignore
+    def _log(self, line: str) -> tuple[bool, Exception | None]:
         try:
             self.event_client.log(message=line, step_run_id=self.stepRunId)
             return True, None
@@ -267,7 +278,7 @@ def _log(self, line: str) -> (bool, Exception):  # type: ignore
             # we don't want to raise an exception here, as it will kill the log thread
             return False, e

-    def log(self, line, raise_on_error: bool = False):
+    def log(self, line: Any, raise_on_error: bool = False) -> None:
         if self.stepRunId == "":
             return

@@ -277,9 +288,9 @@ def log(self, line, raise_on_error: bool = False):
         except Exception:
             line = str(line)

-        future: Future = self.logger_thread_pool.submit(self._log, line)
+        future = self.logger_thread_pool.submit(self._log, line)

-        def handle_result(future: Future):
+        def handle_result(future: Future[tuple[bool, Exception | None]]) -> None:
             success, exception = future.result()
             if not success and exception:
                 if raise_on_error:
@@ -297,22 +308,22 @@ def handle_result(future: Future):

         future.add_done_callback(handle_result)

-    def release_slot(self):
+    def release_slot(self) -> None:
         return self.dispatcher_client.release_slot(self.stepRunId)

-    def _put_stream(self, data: str | bytes):
+    def _put_stream(self, data: str | bytes) -> None:
         try:
             self.event_client.stream(data=data, step_run_id=self.stepRunId)
         except Exception as e:
             logger.error(f"Error putting stream event: {e}")

-    def put_stream(self, data: str | bytes):
+    def put_stream(self, data: str | bytes) -> None:
         if self.stepRunId == "":
             return
         self.stream_event_thread_pool.submit(self._put_stream, data)

-    def refresh_timeout(self, increment_by: str):
+    def refresh_timeout(self, increment_by: str) -> None:
         try:
             return self.dispatcher_client.refresh_timeout(
                 step_run_id=self.stepRunId, increment_by=increment_by
@@ -320,28 +331,28 @@ def refresh_timeout(self, increment_by: str):
         except Exception as e:
             logger.error(f"Error refreshing timeout: {e}")

-    def retry_count(self):
+    def retry_count(self) -> int:
         return self.action.retry_count

-    def additional_metadata(self):
+    def additional_metadata(self) -> dict[str, Any] | None:
         return self.action.additional_metadata

-    def child_index(self):
+    def child_index(self) -> int | None:
         return self.action.child_workflow_index

-    def child_key(self):
+    def child_key(self) -> str | None:
         return self.action.child_workflow_key

-    def parent_workflow_run_id(self):
+    def parent_workflow_run_id(self) -> str | None:
         return self.action.parent_workflow_run_id

-    def fetch_run_failures(self):
+    def fetch_run_failures(self) -> list[dict[str, StrictStr]]:
         data = self.rest_client.workflow_run_get(self.action.workflow_run_id)
         other_job_runs = [
-            run for run in data.job_runs if run.job_id != self.action.job_id
+            run for run in (data.job_runs or []) if run.job_id != self.action.job_id
         ]
         # TODO: Parse Step Runs using a Pydantic Model rather than a hand crafted dictionary
-        failed_step_runs = [
+        return [
             {
                 "step_id": step_run.step_id,
                 "step_run_action_name": step_run.step.action,
@@ -350,7 +361,5 @@ def fetch_run_failures(self):
             for job_run in other_job_runs
             if job_run.step_runs
             for step_run in job_run.step_runs
-            if step_run.error
+            if step_run.error and step_run.step
         ]
-
-        return failed_step_runs
diff --git a/hatchet_sdk/context/worker_context.py b/hatchet_sdk/context/worker_context.py
index 96cc76bc..dfd111bb 100644
--- a/hatchet_sdk/context/worker_context.py
+++ b/hatchet_sdk/context/worker_context.py
@@ -21,7 +21,7 @@ async def async_upsert_labels(self, labels: dict[str, str | int]):
         await self.client.async_upsert_worker_labels(self._worker_id, labels)
         self._labels.update(labels)

-    def id(self):
+    def id(self) -> str:
         return self._worker_id

     # def has_workflow(self, workflow_name: str):
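With the annotations above, the `Context` accessors now advertise concrete types to callers. A hypothetical step body (the `"fetch"` parent step and the surrounding workflow are invented for illustration) showing what that buys at the call site:

```python
from typing import Any

from hatchet_sdk import Context

def process(self, context: Context) -> dict[str, Any]:
    data = context.step_output("fetch")   # typed as dict[str, Any]
    run_id = context.workflow_run_id()    # typed as str
    attempt = context.retry_count()       # typed as int

    context.log(f"run {run_id}, attempt {attempt}, {len(data)} keys")
    return {"rows": len(data)}
```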
diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py
index c08a2803..987584a7 100644
--- a/hatchet_sdk/hatchet.py
+++ b/hatchet_sdk/hatchet.py
@@ -1,6 +1,6 @@
 import asyncio
 import logging
-from typing import Any, Callable, Optional, ParamSpec, TypeVar
+from typing import Any, Callable, Optional, Type

 from typing_extensions import deprecated

@@ -25,14 +25,15 @@
 from .clients.run_event_listener import RunEventListenerClient
 from .logger import logger
 from .worker.worker import Worker
-from .workflow import ConcurrencyExpression, WorkflowMeta
-
-P = ParamSpec("P")
-R = TypeVar("R")
+from .workflow import (
+    ConcurrencyExpression,
+    WorkflowInterface,
+    WorkflowMeta,
+    WorkflowStepProtocol,
+)


-## TODO: Fix return type here to properly type hint the metaclass
-def workflow(  # type: ignore[no-untyped-def]
+def workflow(
     name: str = "",
     on_events: list[str] | None = None,
     on_crons: list[str] | None = None,
@@ -42,11 +43,11 @@
     sticky: StickyStrategy = None,
     default_priority: int | None = None,
     concurrency: ConcurrencyExpression | None = None,
-):
+) -> Callable[[Type[WorkflowInterface]], WorkflowMeta]:
     on_events = on_events or []
     on_crons = on_crons or []

-    def inner(cls: Any) -> WorkflowMeta:
+    def inner(cls: Type[WorkflowInterface]) -> WorkflowMeta:
         cls.on_events = on_events
         cls.on_crons = on_crons
         cls.name = name or str(cls.__name__)
@@ -60,7 +61,7 @@ def inner(cls: Any) -> WorkflowMeta:
         # with WorkflowMeta as its metaclass

         ## TODO: Figure out how to type this metaclass correctly
-        return WorkflowMeta(cls.name, cls.__bases__, dict(cls.__dict__))  # type: ignore[no-untyped-call]
+        return WorkflowMeta(cls.name, cls.__bases__, dict(cls.__dict__))

     return inner

@@ -72,10 +73,10 @@ def step(
     retries: int = 0,
     rate_limits: list[RateLimit] | None = None,
     desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
-) -> Callable[[Callable[P, R]], Callable[P, R]]:
+) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
     parents = parents or []

-    def inner(func: Callable[P, R]) -> Callable[P, R]:
+    def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
         limits = None
         if rate_limits:
             limits = [
@@ -83,18 +84,17 @@ def inner(func: Callable[P, R]) -> Callable[P, R]:
                 for rate_limit in rate_limits or []
             ]

-        ## TODO: Use Protocol here to help with MyPy errors
-        func._step_name = name.lower() or str(func.__name__).lower()  # type: ignore[attr-defined]
-        func._step_parents = parents  # type: ignore[attr-defined]
-        func._step_timeout = timeout  # type: ignore[attr-defined]
-        func._step_retries = retries  # type: ignore[attr-defined]
-        func._step_rate_limits = limits  # type: ignore[attr-defined]
+        func._step_name = name.lower() or str(func.__name__).lower()
+        func._step_parents = parents
+        func._step_timeout = timeout
+        func._step_retries = retries
+        func._step_rate_limits = limits

-        func._step_desired_worker_labels = {}  # type: ignore[attr-defined]
+        func._step_desired_worker_labels = {}

         for key, d in desired_worker_labels.items():
             value = d["value"] if "value" in d else None
-            func._step_desired_worker_labels[key] = DesiredWorkerLabels(  # type: ignore[attr-defined]
+            func._step_desired_worker_labels[key] = DesiredWorkerLabels(
                 strValue=str(value) if not isinstance(value, int) else None,
                 intValue=value if isinstance(value, int) else None,
                 required=d["required"] if "required" in d else None,
@@ -112,8 +112,8 @@ def on_failure_step(
     timeout: str = "",
     retries: int = 0,
     rate_limits: list[RateLimit] | None = None,
-) -> Callable[[Callable[P, R]], Callable[P, R]]:
-    def inner(func: Callable[P, R]) -> Callable[P, R]:
+) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
+    def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
         limits = None
         if rate_limits:
             limits = [
@@ -121,11 +121,10 @@ def inner(func: Callable[P, R]) -> Callable[P, R]:
                 for rate_limit in rate_limits or []
             ]

-        ## TODO: Use Protocol here to help with MyPy errors
-        func._on_failure_step_name = name.lower() or str(func.__name__).lower()  # type: ignore[attr-defined]
-        func._on_failure_step_timeout = timeout  # type: ignore[attr-defined]
-        func._on_failure_step_retries = retries  # type: ignore[attr-defined]
-        func._on_failure_step_rate_limits = limits  # type: ignore[attr-defined]
+        func._on_failure_step_name = name.lower() or str(func.__name__).lower()
+        func._on_failure_step_timeout = timeout
+        func._on_failure_step_retries = retries
+        func._on_failure_step_rate_limits = limits
         return func

     return inner

@@ -135,12 +134,11 @@ def concurrency(
     name: str = "",
     max_runs: int = 1,
     limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS,
-) -> Callable[[Callable[P, R]], Callable[P, R]]:
-    def inner(func: Callable[P, R]) -> Callable[P, R]:
-        ## TODO: Use Protocol here to help with MyPy errors
-        func._concurrency_fn_name = name.lower() or str(func.__name__).lower()  # type: ignore[attr-defined]
-        func._concurrency_max_runs = max_runs  # type: ignore[attr-defined]
-        func._concurrency_limit_strategy = limit_strategy  # type: ignore[attr-defined]
+) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
+    def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
+        func._concurrency_fn_name = name.lower() or str(func.__name__).lower()
+        func._concurrency_max_runs = max_runs
+        func._concurrency_limit_strategy = limit_strategy

         return func
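The move from `Callable[P, R]` to `WorkflowStepProtocol` is what eliminates the per-line `type: ignore[attr-defined]` comments: a `Protocol` can declare both `__call__` and the metadata attributes the decorators attach. A condensed standalone sketch of the idea (`StepLike` and `step` here are illustrative; the real protocol lives in `workflow.py` below):

```python
from typing import Any, Callable, Protocol

class StepLike(Protocol):
    # A callable that also carries step metadata as attributes.
    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

    __name__: str
    _step_name: str

def step(name: str = "") -> Callable[[StepLike], StepLike]:
    def inner(func: StepLike) -> StepLike:
        # Legal under mypy: the Protocol declares _step_name,
        # so no "type: ignore[attr-defined]" is needed.
        func._step_name = name or func.__name__.lower()
        return func

    return inner
```

A caveat of the pattern: a plain function does not carry `_step_name` before decoration, so strict checkers may still flag the decoration site; the protocol mainly legalizes the attribute writes inside `inner`.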
diff --git a/hatchet_sdk/utils/tracing.py b/hatchet_sdk/utils/tracing.py
index 2c9f2c8f..bd543382 100644
--- a/hatchet_sdk/utils/tracing.py
+++ b/hatchet_sdk/utils/tracing.py
@@ -59,7 +59,10 @@ def inject_carrier_into_metadata(
     return metadata


-def parse_carrier_from_metadata(metadata: dict[str, Any]) -> Context:
+def parse_carrier_from_metadata(metadata: dict[str, Any] | None) -> Context | None:
+    if not metadata:
+        return None
+
     return (
         TraceContextTextMapPropagator().extract(_ctx)
         if (_ctx := metadata.get(OTEL_CARRIER_KEY))
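For reference, `parse_carrier_from_metadata` builds on the standard W3C trace-context round-trip from `opentelemetry-api`; a self-contained sketch (Hatchet's `OTEL_CARRIER_KEY` nesting is omitted here):

```python
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

propagator = TraceContextTextMapPropagator()

# Producer side: serialize the current span context into a plain dict...
carrier: dict[str, str] = {}
propagator.inject(carrier)  # e.g. {"traceparent": "00-<trace_id>-<span_id>-01"}

# ...ship it inside workflow metadata, then restore it on the consumer side.
ctx = propagator.extract(carrier)  # a Context usable as the parent of new spans
```

Returning `Context | None` for missing metadata fits this model: `None` simply means "start a fresh trace".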
diff --git a/hatchet_sdk/worker/runner/run_loop_manager.py b/hatchet_sdk/worker/runner/run_loop_manager.py
index 2ad32d31..fb74c821 100644
--- a/hatchet_sdk/worker/runner/run_loop_manager.py
+++ b/hatchet_sdk/worker/runner/run_loop_manager.py
@@ -18,7 +18,7 @@ class WorkerActionRunLoopManager:
     name: str
     action_registry: Dict[str, Callable[..., Any]]
-    max_runs: int
+    max_runs: int | None
     config: ClientConfig
     action_queue: Queue
     event_queue: Queue
@@ -48,23 +48,23 @@ async def async_start(self, retry_count=1):
             self._async_start,
         )(retry_count=retry_count)

-    async def _async_start(self, retry_count=1):
+    async def _async_start(self, retry_count: int = 1) -> None:
         logger.info("starting runner...")
         self.loop = asyncio.get_running_loop()

         # needed for graceful termination
         k = self.loop.create_task(self._start_action_loop())
         await k

-    def cleanup(self):
+    def cleanup(self) -> None:
         self.killing = True
         self.action_queue.put(STOP_LOOP)

-    async def wait_for_tasks(self):
+    async def wait_for_tasks(self) -> None:
         if self.runner:
             await self.runner.wait_for_tasks()

-    async def _start_action_loop(self):
+    async def _start_action_loop(self) -> None:
         self.runner = Runner(
             self.name,
             self.event_queue,
@@ -88,7 +88,7 @@ async def _start_action_loop(self):
     async def _get_action(self):
         return await self.loop.run_in_executor(None, self.action_queue.get)

-    async def exit_gracefully(self):
+    async def exit_gracefully(self) -> None:
         if self.killing:
             return
@@ -101,6 +101,6 @@
         # task list.
         await asyncio.sleep(1)

-    def exit_forcefully(self):
+    def exit_forcefully(self) -> None:
         logger.info("forcefully exiting runner...")
         self.cleanup()
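`_get_action` above offloads the blocking `Queue.get` to the default executor so the event loop stays responsive while waiting for work; the same pattern in isolation (a runnable toy, not SDK code):

```python
import asyncio
from multiprocessing import Queue, get_context

async def main() -> None:
    queue: "Queue[str]" = get_context("spawn").Queue()
    queue.put("action")

    loop = asyncio.get_running_loop()
    # Queue.get blocks the calling thread, so run it in the default
    # thread-pool executor instead of blocking the event loop.
    item = await loop.run_in_executor(None, queue.get)
    print(item)

asyncio.run(main())
```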
diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py
index 14f7a027..144a3939 100644
--- a/hatchet_sdk/worker/runner/runner.py
+++ b/hatchet_sdk/worker/runner/runner.py
@@ -8,7 +8,7 @@
 from enum import Enum
 from multiprocessing import Queue
 from threading import Thread, current_thread
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, TypeVar, cast

 from opentelemetry.trace import StatusCode

@@ -18,9 +18,9 @@
 from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher
 from hatchet_sdk.clients.run_event_listener import new_listener
 from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
-from hatchet_sdk.context import Context
+from hatchet_sdk.context import Context  # type: ignore[attr-defined]
 from hatchet_sdk.context.worker_context import WorkerContext
-from hatchet_sdk.contracts.dispatcher_pb2 import (
+from hatchet_sdk.contracts.dispatcher_pb2 import (  # type: ignore[attr-defined]
     GROUP_KEY_EVENT_TYPE_COMPLETED,
     GROUP_KEY_EVENT_TYPE_FAILED,
     GROUP_KEY_EVENT_TYPE_STARTED,
@@ -48,11 +48,11 @@ class Runner:
     def __init__(
         self,
         name: str,
-        event_queue: Queue,
+        event_queue: "Queue[Any]",
         max_runs: int | None = None,
         handle_kill: bool = True,
         action_registry: dict[str, Callable[..., Any]] = {},
-        config: ClientConfig = {},
+        config: ClientConfig = ClientConfig(),
         labels: dict[str, str | int] = {},
     ):
         # We store the config so we can dynamically create clients for the dispatcher client.
@@ -60,7 +60,7 @@ def __init__(
         self.client = new_client_raw(config)
         self.name = self.client.config.namespace + name
         self.max_runs = max_runs
-        self.tasks: Dict[str, asyncio.Task] = {}  # Store run ids and futures
+        self.tasks: Dict[str, asyncio.Task[Any]] = {}  # Store run ids and futures
         self.contexts: Dict[str, Context] = {}  # Store run ids and contexts
         self.action_registry: dict[str, Callable[..., Any]] = action_registry

@@ -89,7 +89,7 @@ def __init__(
     def create_workflow_run_url(self, action: Action) -> str:
         return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}"

-    def run(self, action: Action):
+    def run(self, action: Action) -> None:
         ctx = parse_carrier_from_metadata(action.additional_metadata)

         with self.otel_tracer.start_as_current_span(
@@ -122,8 +122,8 @@ def run(self, action: Action):
                     span.add_event(log)
                     logger.error(log)

-    def step_run_callback(self, action: Action):
-        def inner_callback(task: asyncio.Task):
+    def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
+        def inner_callback(task: asyncio.Task[Any]) -> None:
             self.cleanup_run_id(action.step_run_id)

             errored = False
@@ -164,8 +164,10 @@ def inner_callback(task: asyncio.Task):

         return inner_callback

-    def group_key_run_callback(self, action: Action):
-        def inner_callback(task: asyncio.Task):
+    def group_key_run_callback(
+        self, action: Action
+    ) -> Callable[[asyncio.Task[Any]], None]:
+        def inner_callback(task: asyncio.Task[Any]) -> None:
             self.cleanup_run_id(action.get_group_key_run_id)

             errored = False
@@ -204,7 +206,10 @@ def inner_callback(task: asyncio.Task):

         return inner_callback

-    def thread_action_func(self, context, action_func, action: Action):
+    ## TODO: Stricter type hinting here
+    def thread_action_func(
+        self, context: Context, action_func: Callable[..., Any], action: Action
+    ) -> Any:
         if action.step_run_id is not None and action.step_run_id != "":
             self.threads[action.step_run_id] = current_thread()
         elif (
@@ -215,10 +220,15 @@ def thread_action_func(self, context, action_func, action: Action):

         return action_func(context)

+    ## TODO: Stricter type hinting here
     # We wrap all actions in an async func
     async def async_wrapped_action_func(
-        self, context: Context, action_func, action: Action, run_id: str
-    ):
+        self,
+        context: Context,
+        action_func: Callable[..., Any],
+        action: Action,
+        run_id: str,
+    ) -> Any:
         wr.set(context.workflow_run_id())
         sr.set(context.step_run_id)

@@ -254,7 +264,7 @@ async def async_wrapped_action_func(
         finally:
             self.cleanup_run_id(run_id)

-    def cleanup_run_id(self, run_id: str):
+    def cleanup_run_id(self, run_id: str | None) -> None:
         if run_id in self.tasks:
             del self.tasks[run_id]

@@ -267,7 +277,7 @@ def cleanup_run_id(self, run_id: str):
     def create_context(
         self, action: Action, action_func: Callable[..., Any] | None
     ) -> Context | DurableContext:
-        if hasattr(action_func, "durable") and action_func.durable:
+        if hasattr(action_func, "durable") and getattr(action_func, "durable"):
             return DurableContext(
                 action,
                 self.dispatcher_client,
@@ -292,7 +302,7 @@ def create_context(
             self.client.config.namespace,
         )

-    async def handle_start_step_run(self, action: Action):
+    async def handle_start_step_run(self, action: Action) -> None:
         with self.otel_tracer.start_as_current_span(
             f"hatchet.worker.handle_start_step_run.{action.step_id}",
         ) as span:
@@ -336,7 +346,7 @@ async def handle_start_step_run(self, action: Action):

             span.add_event("Finished step run")

-    async def handle_start_group_key_run(self, action: Action):
+    async def handle_start_group_key_run(self, action: Action) -> None:
         with self.otel_tracer.start_as_current_span(
             f"hatchet.worker.handle_start_step_run.{action.step_id}"
         ) as span:
@@ -353,6 +363,7 @@ async def handle_start_group_key_run(self, action: Action):
                 self.worker_context,
                 self.client.config.namespace,
             )
+
             self.contexts[action.get_group_key_run_id] = context

             # Find the corresponding action function from the registry
@@ -387,18 +398,18 @@

             span.add_event("Finished group key run")

-    def force_kill_thread(self, thread):
+    def force_kill_thread(self, thread: Thread) -> None:
         """Terminate a python threading.Thread."""
         try:
             if not thread.is_alive():
                 return

-            logger.info(f"Forcefully terminating thread {thread.ident}")
+            ident = cast(int, thread.ident)
+
+            logger.info(f"Forcefully terminating thread {ident}")

             exc = ctypes.py_object(SystemExit)
-            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
-                ctypes.c_long(thread.ident), exc
-            )
+            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
             if res == 0:
                 raise ValueError("Invalid thread ID")
             elif res != 1:
@@ -408,7 +419,7 @@ def force_kill_thread(self, thread):
                 ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
                 raise SystemError("PyThreadState_SetAsyncExc failed")

-            logger.info(f"Successfully terminated thread {thread.ident}")
+            logger.info(f"Successfully terminated thread {ident}")

             # Immediately add a new thread to the thread pool, because we've actually killed a worker
             # in the ThreadPoolExecutor
@@ -416,7 +427,7 @@
         except Exception as e:
             logger.exception(f"Failed to terminate thread: {e}")

-    async def handle_cancel_action(self, run_id: str):
+    async def handle_cancel_action(self, run_id: str) -> None:
         with self.otel_tracer.start_as_current_span(
             "hatchet.worker.handle_cancel_action"
         ) as span:
@@ -427,7 +438,9 @@ async def handle_cancel_action(self, run_id: str):
             # call cancel to signal the context to stop
             if run_id in self.contexts:
                 context = self.contexts.get(run_id)
-                context.cancel()
+
+                if context:
+                    context.cancel()

             await asyncio.sleep(1)

@@ -458,7 +471,7 @@ def serialize_output(self, output: Any) -> str:
                 output_bytes = str(output)

         return output_bytes

-    async def wait_for_tasks(self):
+    async def wait_for_tasks(self) -> None:
         running = len(self.tasks.keys())
         while running > 0:
             logger.info(f"waiting for {running} tasks to finish...")
@@ -466,6 +479,6 @@
             running = len(self.tasks.keys())


-def errorWithTraceback(message: str, e: Exception):
+def errorWithTraceback(message: str, e: Exception) -> str:
     trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
     return f"{message}\n{trace}"
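`force_kill_thread` leans on CPython's `PyThreadState_SetAsyncExc`, which schedules an exception inside another thread. A minimal demonstration of the mechanism (CPython-only, and unsafe in general: the target may be holding locks when it dies):

```python
import ctypes
import threading
import time

def worker() -> None:
    try:
        while True:
            time.sleep(0.1)
    except SystemExit:
        print("worker received SystemExit")

t = threading.Thread(target=worker)
t.start()
time.sleep(0.3)

# Returns the number of thread states modified: 0 means a bad id,
# >1 means multiple states were hit and the call must be undone.
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
    ctypes.c_long(t.ident), ctypes.py_object(SystemExit)
)
if res > 1:
    ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(t.ident), None)
    raise SystemError("PyThreadState_SetAsyncExc failed")

t.join()
```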
diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py
index db944511..242a48b8 100644
--- a/hatchet_sdk/worker/worker.py
+++ b/hatchet_sdk/worker/worker.py
@@ -1,22 +1,30 @@
 import asyncio
 import multiprocessing
+import multiprocessing.context
 import os
 import signal
 import sys
+from concurrent.futures import Future
 from dataclasses import dataclass, field
 from enum import Enum
-from multiprocessing import Process, Queue
-from typing import Any, Callable, Dict, Optional
+from multiprocessing import Queue
+from multiprocessing.process import BaseProcess
+from types import FrameType
+from typing import Any, Callable, TypeVar

 from hatchet_sdk.client import Client, new_client_raw
-from hatchet_sdk.context import Context
-from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts
+from hatchet_sdk.contracts.workflows_pb2 import (  # type: ignore[attr-defined]
+    CreateWorkflowVersionOpts,
+)
 from hatchet_sdk.loader import ClientConfig
 from hatchet_sdk.logger import logger
 from hatchet_sdk.v2.callable import HatchetCallable
+from hatchet_sdk.v2.concurrency import ConcurrencyFunction
 from hatchet_sdk.worker.action_listener_process import worker_action_listener_process
 from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager
-from hatchet_sdk.workflow import WorkflowMeta
+from hatchet_sdk.workflow import WorkflowInterface
+
+T = TypeVar("T")


 class WorkerStatus(Enum):
@@ -28,46 +36,58 @@ class WorkerStatus(Enum):

 @dataclass
 class WorkerStartOptions:
-    loop: asyncio.AbstractEventLoop = field(default=None)
+    loop: asyncio.AbstractEventLoop | None = field(default=None)


-@dataclass
 class Worker:
-    name: str
-    config: ClientConfig = field(default_factory=dict)
-    max_runs: Optional[int] = None
-    debug: bool = False
-    labels: dict[str, str | int] = field(default_factory=dict)
-    handle_kill: bool = True
-
-    client: Client = field(init=False)
-    tasks: Dict[str, asyncio.Task] = field(default_factory=dict)
-    contexts: Dict[str, Context] = field(default_factory=dict)
-    action_registry: Dict[str, Callable[..., Any]] = field(default_factory=dict)
-    killing: bool = field(init=False, default=False)
-    _status: WorkerStatus = field(init=False, default=WorkerStatus.INITIALIZED)
-
-    action_listener_process: Process = field(init=False, default=None)
-    action_listener_health_check: asyncio.Task = field(init=False, default=None)
-    action_runner: WorkerActionRunLoopManager = field(init=False, default=None)
-    ctx = multiprocessing.get_context("spawn")
-
-    action_queue: Queue = field(init=False, default_factory=ctx.Queue)
-    event_queue: Queue = field(init=False, default_factory=ctx.Queue)
-
-    loop: asyncio.AbstractEventLoop = field(init=False, default=None)
-    owned_loop: bool = True
-
-    def __post_init__(self):
+    def __init__(
+        self,
+        name: str,
+        config: ClientConfig = ClientConfig(),
+        max_runs: int | None = None,
+        labels: dict[str, str | int] = {},
+        debug: bool = False,
+        owned_loop: bool = True,
+        handle_kill: bool = True,
+    ) -> None:
+        self.name = name
+        self.config = config
+        self.max_runs = max_runs
+        self.debug = debug
+        self.labels = labels
+        self.handle_kill = handle_kill
+        self.owned_loop = owned_loop
+
+        self.client: Client
+
+        self.action_registry: dict[str, Callable[..., Any]] = {}
+        self.killing: bool = False
+        self._status: WorkerStatus
+
+        self.action_listener_process: BaseProcess
+        self.action_listener_health_check: asyncio.Task[Any]
+        self.action_runner: WorkerActionRunLoopManager
+
+        self.ctx = multiprocessing.get_context("spawn")
+
+        self.action_queue: "Queue[Any]" = self.ctx.Queue()
+        self.event_queue: "Queue[Any]" = self.ctx.Queue()
+
+        self.loop: asyncio.AbstractEventLoop
+
         self.client = new_client_raw(self.config, self.debug)
         self.name = self.client.config.namespace + self.name

-        if self.owned_loop:
-            self._setup_signal_handlers()
+        self._setup_signal_handlers()

-    def register_function(self, action: str, func: HatchetCallable):
+    def register_function(
+        self, action: str, func: HatchetCallable[Any] | ConcurrencyFunction
+    ) -> None:
         self.action_registry[action] = func

-    def register_workflow_from_opts(self, name: str, opts: CreateWorkflowVersionOpts):
+    def register_workflow_from_opts(
+        self, name: str, opts: CreateWorkflowVersionOpts
+    ) -> None:
         try:
             self.client.admin.put_workflow(opts.name, opts)
         except Exception as e:
@@ -75,7 +95,7 @@ def register_workflow_from_opts(self, name: str, opts: CreateWorkflowVersionOpts):
             logger.error(e)
             sys.exit(1)

-    def register_workflow(self, workflow: WorkflowMeta):
+    def register_workflow(self, workflow: WorkflowInterface) -> None:
         namespace = self.client.config.namespace

         try:
@@ -87,14 +107,16 @@ def register_workflow(self, workflow: WorkflowMeta):
             logger.error(e)
             sys.exit(1)

-        def create_action_function(action_func):
-            def action_function(context):
+        def create_action_function(
+            action_func: Callable[..., Any]
+        ) -> Callable[..., Any]:
+            def action_function(context: Any) -> Any:
                 return action_func(workflow, context)

             if asyncio.iscoroutinefunction(action_func):
-                action_function.is_coroutine = True
+                setattr(action_function, "is_coroutine", True)
             else:
-                action_function.is_coroutine = False
+                setattr(action_function, "is_coroutine", False)

             return action_function

@@ -104,7 +126,7 @@ def action_function(context):
     def status(self) -> WorkerStatus:
         return self._status

-    def setup_loop(self, loop: asyncio.AbstractEventLoop = None):
+    def setup_loop(self, loop: asyncio.AbstractEventLoop | None = None) -> bool:
         try:
             loop = loop or asyncio.get_running_loop()
             self.loop = loop
@@ -118,8 +140,11 @@ def setup_loop(self, loop: asyncio.AbstractEventLoop = None):
             created_loop = True
         return created_loop

-    def start(self, options: WorkerStartOptions = WorkerStartOptions()):
+    def start(
+        self, options: WorkerStartOptions = WorkerStartOptions()
+    ) -> Future[asyncio.Task[Any] | None]:
         self.owned_loop = self.setup_loop(options.loop)
+
         f = asyncio.run_coroutine_threadsafe(
             self.async_start(options, _from_start=True), self.loop
         )
@@ -136,7 +161,7 @@ async def async_start(
         self,
         options: WorkerStartOptions = WorkerStartOptions(),
         _from_start: bool = False,
-    ):
+    ) -> Any | None:
         main_pid = os.getpid()
         logger.info("------------------------------------------")
logger.info("STARTING HATCHET...") @@ -148,13 +173,14 @@ async def async_start( logger.error( "no actions registered, register workflows or actions before starting worker" ) - return + return None # non blocking setup if not _from_start: self.setup_loop(options.loop) self.action_listener_process = self._start_listener() + self.action_runner = self._run_action_runner() self.action_listener_health_check = self.loop.create_task( self._check_listener_health() @@ -162,9 +188,9 @@ async def async_start( return await self.action_listener_health_check - def _run_action_runner(self): + def _run_action_runner(self) -> WorkerActionRunLoopManager: # Retrieve the shared queue - runner = WorkerActionRunLoopManager( + return WorkerActionRunLoopManager( self.name, self.action_registry, self.max_runs, @@ -177,10 +203,9 @@ def _run_action_runner(self): self.labels, ) - return runner - - def _start_listener(self): + def _start_listener(self) -> multiprocessing.context.SpawnProcess: action_list = [str(key) for key in self.action_registry.keys()] + try: process = self.ctx.Process( target=worker_action_listener_process, @@ -204,7 +229,7 @@ def _start_listener(self): logger.error(f"failed to start action listener: {e}") sys.exit(1) - async def _check_listener_health(self): + async def _check_listener_health(self) -> None: logger.debug("starting action listener health check...") try: while not self.killing: @@ -224,21 +249,21 @@ async def _check_listener_health(self): logger.error(f"error checking listener health: {e}") ## Cleanup methods - def _setup_signal_handlers(self): + def _setup_signal_handlers(self) -> None: signal.signal(signal.SIGTERM, self._handle_exit_signal) signal.signal(signal.SIGINT, self._handle_exit_signal) signal.signal(signal.SIGQUIT, self._handle_force_quit_signal) - def _handle_exit_signal(self, signum, frame): + def _handle_exit_signal(self, signum: int, frame: FrameType | None) -> None: sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT" logger.info(f"received signal {sig_name}...") self.loop.create_task(self.exit_gracefully()) - def _handle_force_quit_signal(self, signum, frame): + def _handle_force_quit_signal(self, signum: int, frame: FrameType | None) -> None: logger.info("received SIGQUIT...") self.exit_forcefully() - async def close(self): + async def close(self) -> None: logger.info(f"closing worker '{self.name}'...") self.killing = True # self.action_queue.close() @@ -249,7 +274,7 @@ async def close(self): await self.action_listener_health_check - async def exit_gracefully(self): + async def exit_gracefully(self) -> None: logger.debug(f"gracefully stopping worker: {self.name}") if self.killing: @@ -270,7 +295,7 @@ async def exit_gracefully(self): logger.info("👋") - def exit_forcefully(self): + def exit_forcefully(self) -> None: self.killing = True logger.debug(f"forcefully stopping worker: {self.name}") @@ -286,7 +311,7 @@ def exit_forcefully(self): ) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup -def register_on_worker(callable: HatchetCallable, worker: Worker): +def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None: worker.register_function(callable.get_action_name(), callable) if callable.function_on_failure is not None: diff --git a/hatchet_sdk/workflow.py b/hatchet_sdk/workflow.py index 2466bf7a..c95765e7 100644 --- a/hatchet_sdk/workflow.py +++ b/hatchet_sdk/workflow.py @@ -1,17 +1,44 @@ import functools -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, Protocol, 
diff --git a/hatchet_sdk/workflow.py b/hatchet_sdk/workflow.py
index 2466bf7a..c95765e7 100644
--- a/hatchet_sdk/workflow.py
+++ b/hatchet_sdk/workflow.py
@@ -1,17 +1,44 @@
 import functools
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable, Protocol, Type, TypeVar, Union, cast

 from hatchet_sdk import ConcurrencyLimitStrategy
-from hatchet_sdk.contracts.workflows_pb2 import (
+from hatchet_sdk.contracts.workflows_pb2 import (  # type: ignore[attr-defined]
     CreateWorkflowJobOpts,
     CreateWorkflowStepOpts,
     CreateWorkflowVersionOpts,
+    StickyStrategy,
     WorkflowConcurrencyOpts,
     WorkflowKind,
 )
 from hatchet_sdk.logger import logger

-stepsType = List[Tuple[str, Callable[..., Any]]]
+
+class WorkflowStepProtocol(Protocol):
+    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
+
+    __name__: str
+
+    _step_name: str
+    _step_timeout: str | None
+    _step_parents: list[str]
+    _step_retries: int | None
+    _step_rate_limits: list[str] | None
+    _step_desired_worker_labels: dict[str, str]
+
+    _concurrency_fn_name: str
+    _concurrency_max_runs: int | None
+    _concurrency_limit_strategy: str | None
+
+    _on_failure_step_name: str
+    _on_failure_step_timeout: str | None
+    _on_failure_step_retries: int
+    _on_failure_step_rate_limits: list[str] | None
+
+
+StepsType = list[tuple[str, WorkflowStepProtocol]]
+
+T = TypeVar("T")
+TW = TypeVar("TW", bound="WorkflowInterface")


 class ConcurrencyExpression:
@@ -35,28 +62,46 @@ def __init__(
         self.limit_strategy = limit_strategy


+class WorkflowInterface(Protocol):
+    def get_name(self, namespace: str) -> str: ...
+
+    def get_actions(self, namespace: str) -> list[tuple[str, Callable[..., Any]]]: ...
+
+    def get_create_opts(self, namespace: str) -> Any: ...
+
+    on_events: list[str]
+    on_crons: list[str]
+    name: str
+    version: str
+    timeout: str
+    schedule_timeout: str
+    sticky: Union[StickyStrategy.Value, None]
+    default_priority: int | None
+    concurrency_expression: ConcurrencyExpression | None
+
+
 class WorkflowMeta(type):
-    def __new__(cls, name, bases, attrs):
-        concurrencyActions: stepsType = [
-            (getattr(func, "_concurrency_fn_name"), attrs.pop(func_name))
-            for func_name, func in list(attrs.items())
-            if hasattr(func, "_concurrency_fn_name")
-        ]
-        steps: stepsType = [
-            (getattr(func, "_step_name"), attrs.pop(func_name))
-            for func_name, func in list(attrs.items())
-            if hasattr(func, "_step_name")
-        ]
-        onFailureSteps: stepsType = [
-            (getattr(func, "_on_failure_step_name"), attrs.pop(func_name))
-            for func_name, func in list(attrs.items())
-            if hasattr(func, "_on_failure_step_name")
-        ]
+    def __new__(
+        cls: Type["WorkflowMeta"],
+        name: str,
+        bases: tuple[type, ...],
+        attrs: dict[str, Any],
+    ) -> "WorkflowMeta":
+        def _create_steps_actions_list(name: str) -> StepsType:
+            return [
+                (getattr(func, name), attrs.pop(func_name))
+                for func_name, func in list(attrs.items())
+                if hasattr(func, name)
+            ]
+
+        concurrencyActions = _create_steps_actions_list("_concurrency_fn_name")
+        steps = _create_steps_actions_list("_step_name")
+        onFailureSteps = _create_steps_actions_list("_on_failure_step_name")

         # Define __init__ and get_step_order methods
         original_init = attrs.get("__init__")  # Get the original __init__ if it exists

-        def __init__(self, *args, **kwargs):
+        def __init__(self: TW, *args: Any, **kwargs: Any) -> None:
             if original_init:
                 original_init(self, *args, **kwargs)  # Call original __init__

@@ -64,7 +109,7 @@
         def get_service_name(namespace: str) -> str:
             return f"{namespace}{name.lower()}"

         @functools.cache
-        def get_actions(self, namespace: str) -> stepsType:
+        def get_actions(self: TW, namespace: str) -> StepsType:
             serviceName = get_service_name(namespace)

             func_actions = [
                 (serviceName + ":" + func_name, func) for func_name, func in steps
@@ -87,8 +132,8 @@ def get_actions(self, namespace: str) -> stepsType:
         for step_name, step_func in steps:
             attrs[step_name] = step_func

-        def get_name(self, namespace: str):
-            return namespace + attrs["name"]
+        def get_name(self: TW, namespace: str) -> str:
+            return namespace + cast(str, attrs["name"])

         attrs["get_name"] = get_name

@@ -99,11 +144,11 @@ def get_name(self, namespace: str):
             default_priority = attrs["default_priority"]

         @functools.cache
-        def get_create_opts(self, namespace: str):
+        def get_create_opts(self: TW, namespace: str) -> CreateWorkflowVersionOpts:
             serviceName = get_service_name(namespace)
             name = self.get_name(namespace)
             event_triggers = [namespace + event for event in attrs["on_events"]]

-            createStepOpts: List[CreateWorkflowStepOpts] = [
+            createStepOpts: list[CreateWorkflowStepOpts] = [
                 CreateWorkflowStepOpts(
                     readable_id=step_name,
                     action=serviceName + ":" + step_name,
@@ -140,7 +185,7 @@ def get_create_opts(self, namespace: str):
                     "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method."
                 )

-            on_failure_job: List[CreateWorkflowJobOpts] | None = None
+            on_failure_job: list[CreateWorkflowJobOpts] | None = None

             if len(onFailureSteps) > 0:
                 func_name, func = onFailureSteps[0]
diff --git a/lint.sh b/lint.sh
index 5ff37475..8b0263f4 100755
--- a/lint.sh
+++ b/lint.sh
@@ -1 +1,3 @@
-pre-commit run --all-files || pre-commit run --all-files
+poetry run black . --color
+poetry run isort .
+poetry run mypy --config-file=pyproject.toml
diff --git a/pyproject.toml b/pyproject.toml
index 1c6b6744..cd63afac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "hatchet-sdk"
-version = "0.40.1"
+version = "0.40.2"
 description = ""
 authors = ["Alexander Belanger "]
 readme = "README.md"
@@ -69,8 +69,16 @@ extend_exclude = "hatchet_sdk/contracts/"

 [tool.mypy]
 strict = true
-files = ["hatchet_sdk/hatchet.py"]
+files = [
+    "hatchet_sdk/hatchet.py",
+    "hatchet_sdk/worker/worker.py",
+    "hatchet_sdk/context/context.py",
+    "hatchet_sdk/worker/runner/runner.py",
+    "hatchet_sdk/workflow.py"
+]
+exclude = "^examples/*"
 follow_imports = "silent"
+disable_error_code = ["unused-coroutine"]

 [tool.poetry.scripts]
 api = "examples.api.api:main"
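Taken together, the typed decorators, `WorkflowInterface`, and the expanded strict-mypy file list mean a workflow definition along these lines (a hypothetical example following the SDK's documented decorator usage, not part of this diff) should now type-check end to end:

```python
from hatchet_sdk import Context, Hatchet

hatchet = Hatchet()

@hatchet.workflow(name="example", on_events=["user:created"])
class ExampleWorkflow:
    @hatchet.step()
    def start(self, context: Context) -> dict[str, str]:
        return {"status": "started"}

    @hatchet.step(parents=["start"])
    def finish(self, context: Context) -> dict[str, str]:
        prev = context.step_output("start")  # typed: dict[str, Any]
        return {"status": f"done after {prev['status']}"}
```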