diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 9cfa34a3..3461909a 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -26,7 +26,7 @@ jobs: virtualenvs-in-project: true - name: Install linting tools - run: poetry install --no-root --only lint + run: poetry install --no-root - name: Run Black run: poetry run black . --check --verbose --diff --color diff --git a/examples/pydantic/test_pydantic.py b/examples/pydantic/test_pydantic.py new file mode 100644 index 00000000..f4c16f0e --- /dev/null +++ b/examples/pydantic/test_pydantic.py @@ -0,0 +1,30 @@ +import pytest + +from hatchet_sdk import Hatchet + + +# requires scope module or higher for shared event loop +@pytest.mark.asyncio(scope="session") +@pytest.mark.parametrize("worker", ["pydantic"], indirect=True) +async def test_run_validation_error(hatchet: Hatchet, worker): + run = hatchet.admin.run_workflow( + "Parent", + {}, + ) + + with pytest.raises(Exception, match="1 validation error for ParentInput"): + await run.result() + + +# requires scope module or higher for shared event loop +@pytest.mark.asyncio(scope="session") +@pytest.mark.parametrize("worker", ["pydantic"], indirect=True) +async def test_run(hatchet: Hatchet, worker): + run = hatchet.admin.run_workflow( + "Parent", + {"x": "foobar"}, + ) + + result = await run.result() + + assert len(result["spawn"]) == 3 diff --git a/examples/pydantic/trigger.py b/examples/pydantic/trigger.py new file mode 100644 index 00000000..7c27b7df --- /dev/null +++ b/examples/pydantic/trigger.py @@ -0,0 +1,19 @@ +import asyncio + +from dotenv import load_dotenv + +from hatchet_sdk import new_client + + +async def main(): + load_dotenv() + hatchet = new_client() + + hatchet.admin.run_workflow( + "Parent", + {"x": "foo bar baz"}, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/pydantic/worker.py b/examples/pydantic/worker.py new file mode 100644 index 00000000..7b66e9dc --- /dev/null +++ b/examples/pydantic/worker.py @@ -0,0 +1,82 @@ +from typing import cast + +from dotenv import load_dotenv +from pydantic import BaseModel + +from hatchet_sdk import Context, Hatchet + +load_dotenv() + +hatchet = Hatchet(debug=True) + + +# ❓ Pydantic +# This workflow shows example usage of Pydantic within Hatchet +class ParentInput(BaseModel): + x: str + + +@hatchet.workflow(input_validator=ParentInput) +class Parent: + @hatchet.step(timeout="5m") + async def spawn(self, context: Context): + ## Use `typing.cast` to cast your `workflow_input` + ## to the type of your `input_validator` + input = cast(ParentInput, context.workflow_input()) ## This is a `ParentInput` + + child = await context.aio.spawn_workflow( + "Child", + {"a": 1, "b": "10"}, + ) + + return await child.result() + + +class ChildInput(BaseModel): + a: int + b: int + + +class StepResponse(BaseModel): + status: str + + +@hatchet.workflow(input_validator=ChildInput) +class Child: + @hatchet.step() + def process(self, context: Context) -> StepResponse: + ## This is an instance of `ChildInput` + input = cast(ChildInput, context.workflow_input()) + + return StepResponse(status="success") + + @hatchet.step(parents=["process"]) + def process2(self, context: Context) -> StepResponse: + ## This is an instance of `StepResponse` + process_output = cast(StepResponse, context.step_output("process")) + + return {"status": "step 2 - success"} + + @hatchet.step(parents=["process2"]) + def process3(self, context: Context) -> StepResponse: + ## This is an instance of `StepResponse`, even 
though the + ## response of `process2` was a dictionary. Note that + ## Hatchet will attempt to parse that dictionary into + ## an object of type `StepResponse` + process_2_output = cast(StepResponse, context.step_output("process2")) + + return StepResponse(status="step 3 - success") + + +# ‼️ + + +def main(): + worker = hatchet.worker("pydantic-worker") + worker.register_workflow(Parent()) + worker.register_workflow(Child()) + worker.start() + + +if __name__ == "__main__": + main() diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index 3524502a..746a346c 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -54,16 +54,10 @@ class ChildTriggerWorkflowOptions(TypedDict): sticky: bool | None = None -class WorkflowRunDict(TypedDict): - workflow_name: str - input: Any - options: Optional[dict] - - class ChildWorkflowRunDict(TypedDict): workflow_name: str input: Any - options: ChildTriggerWorkflowOptions[dict] + options: ChildTriggerWorkflowOptions key: str | None = None @@ -73,6 +67,12 @@ class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, TypedDict): namespace: str | None = None +class WorkflowRunDict(TypedDict): + workflow_name: str + input: Any + options: TriggerWorkflowOptions | None + + class DedupeViolationErr(Exception): """Raised by the Hatchet library to indicate that a workflow has already been run with this deduplication value.""" @@ -260,7 +260,9 @@ async def run_workflow( @tenacity_retry async def run_workflows( - self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None + self, + workflows: list[WorkflowRunDict], + options: TriggerWorkflowOptions | None = None, ) -> List[WorkflowRunRef]: if len(workflows) == 0: raise ValueError("No workflows to run") diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index cd1c195b..fc2887bd 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -61,7 +61,7 @@ class Action: worker_id: str tenant_id: str workflow_run_id: str - get_group_key_run_id: Optional[str] + get_group_key_run_id: str job_id: str job_name: str job_run_id: str diff --git a/hatchet_sdk/clients/dispatcher/dispatcher.py b/hatchet_sdk/clients/dispatcher/dispatcher.py index b92addab..6578528b 100644 --- a/hatchet_sdk/clients/dispatcher/dispatcher.py +++ b/hatchet_sdk/clients/dispatcher/dispatcher.py @@ -137,14 +137,14 @@ def put_overrides_data(self, data: OverridesData): return response - def release_slot(self, step_run_id: str): + def release_slot(self, step_run_id: str) -> None: self.client.ReleaseSlot( ReleaseSlotRequest(stepRunId=step_run_id), timeout=DEFAULT_REGISTER_TIMEOUT, metadata=get_metadata(self.token), ) - def refresh_timeout(self, step_run_id: str, increment_by: str): + def refresh_timeout(self, step_run_id: str, increment_by: str) -> None: self.client.RefreshTimeout( RefreshTimeoutRequest( stepRunId=step_run_id, diff --git a/hatchet_sdk/clients/rest/tenacity_utils.py b/hatchet_sdk/clients/rest/tenacity_utils.py index 42f51394..55fe6d18 100644 --- a/hatchet_sdk/clients/rest/tenacity_utils.py +++ b/hatchet_sdk/clients/rest/tenacity_utils.py @@ -1,10 +1,15 @@ +from typing import Callable, ParamSpec, TypeVar + import grpc import tenacity from hatchet_sdk.logger import logger +P = ParamSpec("P") +R = TypeVar("R") + -def tenacity_retry(func): +def tenacity_retry(func: Callable[P, R]) -> Callable[P, R]: return tenacity.retry( reraise=True, 
wait=tenacity.wait_exponential_jitter(), diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py index 5f6c8726..ad8ecf1a 100644 --- a/hatchet_sdk/context/context.py +++ b/hatchet_sdk/context/context.py @@ -2,7 +2,10 @@ import json import traceback from concurrent.futures import Future, ThreadPoolExecutor -from typing import List +from typing import Any, Generic, Type, TypeVar, cast, overload +from warnings import warn + +from pydantic import BaseModel, StrictStr from hatchet_sdk.clients.events import EventClient from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry @@ -10,11 +13,13 @@ from hatchet_sdk.clients.run_event_listener import RunEventListenerClient from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener from hatchet_sdk.context.worker_context import WorkerContext -from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData -from hatchet_sdk.contracts.workflows_pb2 import ( +from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData # type: ignore +from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] BulkTriggerWorkflowRequest, TriggerWorkflowRequest, ) +from hatchet_sdk.utils.types import WorkflowValidator +from hatchet_sdk.utils.typing import is_basemodel_subclass from hatchet_sdk.workflow_run import WorkflowRunRef from ..clients.admin import ( @@ -24,25 +29,34 @@ TriggerWorkflowOptions, WorkflowRunDict, ) -from ..clients.dispatcher.dispatcher import Action, DispatcherClient +from ..clients.dispatcher.dispatcher import ( # type: ignore[attr-defined] + Action, + DispatcherClient, +) from ..logger import logger DEFAULT_WORKFLOW_POLLING_INTERVAL = 5 # Seconds +T = TypeVar("T", bound=BaseModel) -def get_caller_file_path(): + +def get_caller_file_path() -> str: caller_frame = inspect.stack()[2] return caller_frame.filename class BaseContext: + + action: Action + spawn_index: int + def _prepare_workflow_options( self, - key: str = None, + key: str | None = None, options: ChildTriggerWorkflowOptions | None = None, - worker_id: str = None, - ): + worker_id: str | None = None, + ) -> TriggerWorkflowOptions: workflow_run_id = self.action.workflow_run_id step_run_id = self.action.step_run_id @@ -54,7 +68,8 @@ def _prepare_workflow_options( if options is not None and "additional_metadata" in options: meta = options["additional_metadata"] - trigger_options: TriggerWorkflowOptions = { + ## TODO: Pydantic here to simplify this + trigger_options: TriggerWorkflowOptions = { # type: ignore[typeddict-item] "parent_id": workflow_run_id, "parent_step_run_id": step_run_id, "child_key": key, @@ -95,9 +110,9 @@ def __init__( async def spawn_workflow( self, workflow_name: str, - input: dict = {}, - key: str = None, - options: ChildTriggerWorkflowOptions = None, + input: dict[str, Any] = {}, + key: str | None = None, + options: ChildTriggerWorkflowOptions | None = None, ) -> WorkflowRunRef: worker_id = self.worker.id() # if ( @@ -118,15 +133,15 @@ async def spawn_workflow( @tenacity_retry async def spawn_workflows( - self, child_workflow_runs: List[ChildWorkflowRunDict] - ) -> List[WorkflowRunRef]: + self, child_workflow_runs: list[ChildWorkflowRunDict] + ) -> list[WorkflowRunRef]: if len(child_workflow_runs) == 0: raise Exception("no child workflows to spawn") worker_id = self.worker.id() - bulk_trigger_workflow_runs: WorkflowRunDict = [] + bulk_trigger_workflow_runs: list[WorkflowRunDict] = [] for child_workflow_run in child_workflow_runs: workflow_name = child_workflow_run["workflow_name"] input = 
child_workflow_run["input"] @@ -134,7 +149,8 @@ async def spawn_workflows( key = child_workflow_run.get("key") options = child_workflow_run.get("options", {}) - trigger_options = self._prepare_workflow_options(key, options, worker_id) + ## TODO: figure out why this is failing + trigger_options = self._prepare_workflow_options(key, options, worker_id) # type: ignore[arg-type] bulk_trigger_workflow_runs.append( WorkflowRunDict( @@ -161,8 +177,10 @@ def __init__( workflow_run_event_listener: RunEventListenerClient, worker: WorkerContext, namespace: str = "", + validator_registry: dict[str, WorkflowValidator] = {}, ): self.worker = worker + self.validator_registry = validator_registry self.aio = ContextAioImpl( action, @@ -179,11 +197,11 @@ def __init__( # Check the type of action.action_payload before attempting to load it as JSON if isinstance(action.action_payload, (str, bytes, bytearray)): try: - self.data = json.loads(action.action_payload) + self.data = cast(dict[str, Any], json.loads(action.action_payload)) except Exception as e: logger.error(f"Error parsing action payload: {e}") # Assign an empty dictionary if parsing fails - self.data = {} + self.data: dict[str, Any] = {} # type: ignore[no-redef] else: # Directly assign the payload to self.data if it's already a dict self.data = ( @@ -191,6 +209,7 @@ def __init__( ) self.action = action + # FIXME: stepRunId is a legacy field, we should remove it self.stepRunId = action.step_run_id @@ -218,33 +237,53 @@ def __init__( else: self.input = self.data.get("input", {}) - def step_output(self, step: str): + def step_output(self, step: str) -> dict[str, Any] | BaseModel: + validators = self.validator_registry.get(step) + try: - return self.data["parents"][step] + parent_step_data = cast(dict[str, Any], self.data["parents"][step]) except KeyError: raise ValueError(f"Step output for '{step}' not found") + if validators and (v := validators.step_output): + return v.model_validate(parent_step_data) + + return parent_step_data + def triggered_by_event(self) -> bool: - return self.data.get("triggered_by", "") == "event" + return cast(str, self.data.get("triggered_by", "")) == "event" + + def workflow_input(self) -> dict[str, Any] | T: + if (r := self.validator_registry.get(self.action.action_id)) and ( + i := r.workflow_input + ): + return cast( + T, + i.model_validate(self.input), + ) - def workflow_input(self): return self.input - def workflow_run_id(self): + def workflow_run_id(self) -> str: return self.action.workflow_run_id - def cancel(self): + def cancel(self) -> None: logger.debug("cancelling step...") self.exit_flag = True # done returns true if the context has been cancelled - def done(self): + def done(self) -> bool: return self.exit_flag - def playground(self, name: str, default: str = None): + def playground(self, name: str, default: str | None = None) -> str | None: # if the key exists in the overrides_data field, return the value if name in self.overrides_data: - return self.overrides_data[name] + warn( + "Use of `overrides_data` is deprecated.", + DeprecationWarning, + stacklevel=1, + ) + return str(self.overrides_data[name]) caller_file = get_caller_file_path() @@ -259,7 +298,7 @@ def playground(self, name: str, default: str = None): return default - def _log(self, line: str) -> (bool, Exception): # type: ignore + def _log(self, line: str) -> tuple[bool, Exception | None]: try: self.event_client.log(message=line, step_run_id=self.stepRunId) return True, None @@ -267,7 +306,7 @@ def _log(self, line: str) -> (bool, Exception): # type: 
ignore # we don't want to raise an exception here, as it will kill the log thread return False, e - def log(self, line, raise_on_error: bool = False): + def log(self, line: Any, raise_on_error: bool = False) -> None: if self.stepRunId == "": return @@ -277,9 +316,9 @@ def log(self, line, raise_on_error: bool = False): except Exception: line = str(line) - future: Future = self.logger_thread_pool.submit(self._log, line) + future = self.logger_thread_pool.submit(self._log, line) - def handle_result(future: Future): + def handle_result(future: Future[tuple[bool, Exception | None]]) -> None: success, exception = future.result() if not success and exception: if raise_on_error: @@ -297,22 +336,22 @@ def handle_result(future: Future): future.add_done_callback(handle_result) - def release_slot(self): + def release_slot(self) -> None: return self.dispatcher_client.release_slot(self.stepRunId) - def _put_stream(self, data: str | bytes): + def _put_stream(self, data: str | bytes) -> None: try: self.event_client.stream(data=data, step_run_id=self.stepRunId) except Exception as e: logger.error(f"Error putting stream event: {e}") - def put_stream(self, data: str | bytes): + def put_stream(self, data: str | bytes) -> None: if self.stepRunId == "": return self.stream_event_thread_pool.submit(self._put_stream, data) - def refresh_timeout(self, increment_by: str): + def refresh_timeout(self, increment_by: str) -> None: try: return self.dispatcher_client.refresh_timeout( step_run_id=self.stepRunId, increment_by=increment_by @@ -320,28 +359,28 @@ def refresh_timeout(self, increment_by: str): except Exception as e: logger.error(f"Error refreshing timeout: {e}") - def retry_count(self): + def retry_count(self) -> int: return self.action.retry_count - def additional_metadata(self): + def additional_metadata(self) -> dict[str, Any] | None: return self.action.additional_metadata - def child_index(self): + def child_index(self) -> int | None: return self.action.child_workflow_index - def child_key(self): + def child_key(self) -> str | None: return self.action.child_workflow_key - def parent_workflow_run_id(self): + def parent_workflow_run_id(self) -> str | None: return self.action.parent_workflow_run_id - def fetch_run_failures(self): + def fetch_run_failures(self) -> list[dict[str, StrictStr]]: data = self.rest_client.workflow_run_get(self.action.workflow_run_id) other_job_runs = [ - run for run in data.job_runs if run.job_id != self.action.job_id + run for run in (data.job_runs or []) if run.job_id != self.action.job_id ] # TODO: Parse Step Runs using a Pydantic Model rather than a hand crafted dictionary - failed_step_runs = [ + return [ { "step_id": step_run.step_id, "step_run_action_name": step_run.step.action, @@ -350,7 +389,5 @@ def fetch_run_failures(self): for job_run in other_job_runs if job_run.step_runs for step_run in job_run.step_runs - if step_run.error + if step_run.error and step_run.step ] - - return failed_step_runs diff --git a/hatchet_sdk/context/worker_context.py b/hatchet_sdk/context/worker_context.py index 96cc76bc..dfd111bb 100644 --- a/hatchet_sdk/context/worker_context.py +++ b/hatchet_sdk/context/worker_context.py @@ -21,7 +21,7 @@ async def async_upsert_labels(self, labels: dict[str, str | int]): await self.client.async_upsert_worker_labels(self._worker_id, labels) self._labels.update(labels) - def id(self): + def id(self) -> str: return self._worker_id # def has_workflow(self, workflow_name: str): diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index bc51e777..3a0b5426 
100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -1,7 +1,8 @@ import asyncio import logging -from typing import Any, Callable, Optional, ParamSpec, TypeVar +from typing import Any, Callable, Optional, Type, TypeVar, cast, get_type_hints +from pydantic import BaseModel from typing_extensions import deprecated from hatchet_sdk.clients.rest_client import RestApi @@ -27,14 +28,17 @@ from .clients.run_event_listener import RunEventListenerClient from .logger import logger from .worker.worker import Worker -from .workflow import ConcurrencyExpression, WorkflowMeta +from .workflow import ( + ConcurrencyExpression, + WorkflowInterface, + WorkflowMeta, + WorkflowStepProtocol, +) -P = ParamSpec("P") -R = TypeVar("R") +T = TypeVar("T", bound=BaseModel) -## TODO: Fix return type here to properly type hint the metaclass -def workflow( # type: ignore[no-untyped-def] +def workflow( name: str = "", on_events: list[str] | None = None, on_crons: list[str] | None = None, @@ -44,11 +48,12 @@ def workflow( # type: ignore[no-untyped-def] sticky: StickyStrategy = None, default_priority: int | None = None, concurrency: ConcurrencyExpression | None = None, -): + input_validator: Type[T] | None = None, +) -> Callable[[Type[WorkflowInterface]], WorkflowMeta]: on_events = on_events or [] on_crons = on_crons or [] - def inner(cls: Any) -> WorkflowMeta: + def inner(cls: Type[WorkflowInterface]) -> WorkflowMeta: cls.on_events = on_events cls.on_crons = on_crons cls.name = name or str(cls.__name__) @@ -62,7 +67,8 @@ def inner(cls: Any) -> WorkflowMeta: # with WorkflowMeta as its metaclass ## TODO: Figure out how to type this metaclass correctly - return WorkflowMeta(cls.name, cls.__bases__, dict(cls.__dict__)) # type: ignore[no-untyped-call] + cls.input_validator = input_validator + return WorkflowMeta(cls.name, cls.__bases__, dict(cls.__dict__)) return inner @@ -76,10 +82,10 @@ def step( desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, backoff_factor: float | None = None, backoff_max_seconds: int | None = None, -) -> Callable[[Callable[P, R]], Callable[P, R]]: +) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]: parents = parents or [] - def inner(func: Callable[P, R]) -> Callable[P, R]: + def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol: limits = None if rate_limits: limits = [ @@ -87,20 +93,19 @@ def inner(func: Callable[P, R]) -> Callable[P, R]: for rate_limit in rate_limits or [] ] - ## TODO: Use Protocol here to help with MyPy errors - func._step_name = name.lower() or str(func.__name__).lower() # type: ignore[attr-defined] - func._step_parents = parents # type: ignore[attr-defined] - func._step_timeout = timeout # type: ignore[attr-defined] - func._step_retries = retries # type: ignore[attr-defined] - func._step_rate_limits = limits # type: ignore[attr-defined] - func._step_backoff_factor = backoff_factor # type: ignore[attr-defined] - func._step_backoff_max_seconds = backoff_max_seconds # type: ignore[attr-defined] + func._step_name = name.lower() or str(func.__name__).lower() + func._step_parents = parents + func._step_timeout = timeout + func._step_retries = retries + func._step_rate_limits = limits + func._step_backoff_factor = backoff_factor + func._step_backoff_max_seconds = backoff_max_seconds - func._step_desired_worker_labels = {} # type: ignore[attr-defined] + func._step_desired_worker_labels = {} for key, d in desired_worker_labels.items(): value = d["value"] if "value" in d else None - func._step_desired_worker_labels[key] = DesiredWorkerLabels( # 
type: ignore[attr-defined] + func._step_desired_worker_labels[key] = DesiredWorkerLabels( strValue=str(value) if not isinstance(value, int) else None, intValue=value if isinstance(value, int) else None, required=d["required"] if "required" in d else None, @@ -120,8 +125,8 @@ def on_failure_step( rate_limits: list[RateLimit] | None = None, backoff_factor: float | None = None, backoff_max_seconds: int | None = None, -) -> Callable[[Callable[P, R]], Callable[P, R]]: - def inner(func: Callable[P, R]) -> Callable[P, R]: +) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]: + def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol: limits = None if rate_limits: limits = [ @@ -129,13 +134,12 @@ def inner(func: Callable[P, R]) -> Callable[P, R]: for rate_limit in rate_limits or [] ] - ## TODO: Use Protocol here to help with MyPy errors - func._on_failure_step_name = name.lower() or str(func.__name__).lower() # type: ignore[attr-defined] - func._on_failure_step_timeout = timeout # type: ignore[attr-defined] - func._on_failure_step_retries = retries # type: ignore[attr-defined] - func._on_failure_step_rate_limits = limits # type: ignore[attr-defined] - func._on_failure_step_backoff_factor = backoff_factor # type: ignore[attr-defined] - func._on_failure_step_backoff_max_seconds = backoff_max_seconds # type: ignore[attr-defined] + func._on_failure_step_name = name.lower() or str(func.__name__).lower() + func._on_failure_step_timeout = timeout + func._on_failure_step_retries = retries + func._on_failure_step_rate_limits = limits + func._on_failure_step_backoff_factor = backoff_factor + func._on_failure_step_backoff_max_seconds = backoff_max_seconds return func @@ -146,12 +150,11 @@ def concurrency( name: str = "", max_runs: int = 1, limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS, -) -> Callable[[Callable[P, R]], Callable[P, R]]: - def inner(func: Callable[P, R]) -> Callable[P, R]: - ## TODO: Use Protocol here to help with MyPy errors - func._concurrency_fn_name = name.lower() or str(func.__name__).lower() # type: ignore[attr-defined] - func._concurrency_max_runs = max_runs # type: ignore[attr-defined] - func._concurrency_limit_strategy = limit_strategy # type: ignore[attr-defined] +) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]: + def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol: + func._concurrency_fn_name = name.lower() or str(func.__name__).lower() + func._concurrency_max_runs = max_runs + func._concurrency_limit_strategy = limit_strategy return func diff --git a/hatchet_sdk/utils/backoff.py b/hatchet_sdk/utils/backoff.py index 8dab7717..34ddac7f 100644 --- a/hatchet_sdk/utils/backoff.py +++ b/hatchet_sdk/utils/backoff.py @@ -2,7 +2,7 @@ import random -async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5): +async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5) -> None: base_time = 0.1 # starting sleep time in seconds (100 milliseconds) jitter = random.uniform(0, base_time) # add random jitter sleep_time = min(base_time * (2**attempt) + jitter, max_sleep_time) diff --git a/hatchet_sdk/utils/serialization.py b/hatchet_sdk/utils/serialization.py index d2d14f77..7eb1d13a 100644 --- a/hatchet_sdk/utils/serialization.py +++ b/hatchet_sdk/utils/serialization.py @@ -2,7 +2,10 @@ def flatten(xs: dict[str, Any], parent_key: str, separator: str) -> dict[str, Any]: - items = [] + if not xs: + return {} + + items: list[tuple[str, Any]] = [] for k, v in xs.items(): new_key = parent_key + separator + k if 
parent_key else k diff --git a/hatchet_sdk/utils/tracing.py b/hatchet_sdk/utils/tracing.py index 2c9f2c8f..afc398f7 100644 --- a/hatchet_sdk/utils/tracing.py +++ b/hatchet_sdk/utils/tracing.py @@ -6,9 +6,9 @@ from opentelemetry.context import Context from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import SERVICE_NAME, Resource -from opentelemetry.sdk.trace import Tracer, TracerProvider +from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.trace import NoOpTracerProvider +from opentelemetry.trace import NoOpTracerProvider, Tracer from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from hatchet_sdk.loader import ClientConfig @@ -44,7 +44,7 @@ def create_tracer(config: ClientConfig) -> Tracer: def create_carrier() -> dict[str, str]: - carrier = {} + carrier: dict[str, str] = {} TraceContextTextMapPropagator().inject(carrier) return carrier @@ -59,7 +59,10 @@ def inject_carrier_into_metadata( return metadata -def parse_carrier_from_metadata(metadata: dict[str, Any]) -> Context: +def parse_carrier_from_metadata(metadata: dict[str, Any] | None) -> Context | None: + if not metadata: + return None + return ( TraceContextTextMapPropagator().extract(_ctx) if (_ctx := metadata.get(OTEL_CARRIER_KEY)) diff --git a/hatchet_sdk/utils/types.py b/hatchet_sdk/utils/types.py new file mode 100644 index 00000000..30e469f7 --- /dev/null +++ b/hatchet_sdk/utils/types.py @@ -0,0 +1,8 @@ +from typing import Type + +from pydantic import BaseModel + + +class WorkflowValidator(BaseModel): + workflow_input: Type[BaseModel] | None = None + step_output: Type[BaseModel] | None = None diff --git a/hatchet_sdk/utils/typing.py b/hatchet_sdk/utils/typing.py new file mode 100644 index 00000000..4a6d968a --- /dev/null +++ b/hatchet_sdk/utils/typing.py @@ -0,0 +1,9 @@ +from typing import Any, Type, TypeGuard, TypeVar + +from pydantic import BaseModel + +T = TypeVar("T", bound=BaseModel) + + +def is_basemodel_subclass(model: Any) -> bool: + return isinstance(model, type) and issubclass(model, BaseModel) diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 0738c2f2..287d327f 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -42,6 +42,7 @@ def __init__( default_priority: int | None = None, ): self.func = func + self.validators = func.validators on_events = on_events or [] on_crons = on_crons or [] diff --git a/hatchet_sdk/worker/action_listener_process.py b/hatchet_sdk/worker/action_listener_process.py index 3d5bfc22..08508607 100644 --- a/hatchet_sdk/worker/action_listener_process.py +++ b/hatchet_sdk/worker/action_listener_process.py @@ -87,15 +87,13 @@ async def start(self, retry_attempt=0): try: self.dispatcher_client = new_dispatcher(self.config) - self.listener: ActionListener = ( - await self.dispatcher_client.get_action_listener( - GetActionListenerRequest( - worker_name=self.name, - services=["default"], - actions=self.actions, - max_runs=self.max_runs, - _labels=self.labels, - ) + self.listener = await self.dispatcher_client.get_action_listener( + GetActionListenerRequest( + worker_name=self.name, + services=["default"], + actions=self.actions, + max_runs=self.max_runs, + _labels=self.labels, ) ) diff --git a/hatchet_sdk/worker/runner/run_loop_manager.py b/hatchet_sdk/worker/runner/run_loop_manager.py index 2ad32d31..27ed788c 100644 --- a/hatchet_sdk/worker/runner/run_loop_manager.py +++ 
b/hatchet_sdk/worker/runner/run_loop_manager.py @@ -2,23 +2,28 @@ import logging from dataclasses import dataclass, field from multiprocessing import Queue -from typing import Any, Callable, Dict +from typing import Callable, TypeVar +from hatchet_sdk import Context from hatchet_sdk.client import Client, new_client_raw from hatchet_sdk.clients.dispatcher.action_listener import Action from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger +from hatchet_sdk.utils.types import WorkflowValidator from hatchet_sdk.worker.runner.runner import Runner from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs STOP_LOOP = "STOP_LOOP" +T = TypeVar("T") + @dataclass class WorkerActionRunLoopManager: name: str - action_registry: Dict[str, Callable[..., Any]] - max_runs: int + action_registry: dict[str, Callable[[Context], T]] + validator_registry: dict[str, WorkflowValidator] + max_runs: int | None config: ClientConfig action_queue: Queue event_queue: Queue @@ -48,29 +53,30 @@ async def async_start(self, retry_count=1): self._async_start, )(retry_count=retry_count) - async def _async_start(self, retry_count=1): + async def _async_start(self, retry_count: int = 1) -> None: logger.info("starting runner...") self.loop = asyncio.get_running_loop() # needed for graceful termination k = self.loop.create_task(self._start_action_loop()) await k - def cleanup(self): + def cleanup(self) -> None: self.killing = True self.action_queue.put(STOP_LOOP) - async def wait_for_tasks(self): + async def wait_for_tasks(self) -> None: if self.runner: await self.runner.wait_for_tasks() - async def _start_action_loop(self): + async def _start_action_loop(self) -> None: self.runner = Runner( self.name, self.event_queue, self.max_runs, self.handle_kill, self.action_registry, + self.validator_registry, self.config, self.labels, ) @@ -88,7 +94,7 @@ async def _start_action_loop(self): async def _get_action(self): return await self.loop.run_in_executor(None, self.action_queue.get) - async def exit_gracefully(self): + async def exit_gracefully(self) -> None: if self.killing: return @@ -101,6 +107,6 @@ async def exit_gracefully(self): # task list. 
await asyncio.sleep(1) - def exit_forcefully(self): + def exit_forcefully(self) -> None: logger.info("forcefully exiting runner...") self.cleanup() diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index 14f7a027..0a16f867 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -8,9 +8,10 @@ from enum import Enum from multiprocessing import Queue from threading import Thread, current_thread -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Literal, Type, TypeVar, cast, overload from opentelemetry.trace import StatusCode +from pydantic import BaseModel from hatchet_sdk.client import new_client_raw from hatchet_sdk.clients.admin import new_admin @@ -18,9 +19,9 @@ from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher from hatchet_sdk.clients.run_event_listener import new_listener from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener -from hatchet_sdk.context import Context +from hatchet_sdk.context import Context # type: ignore[attr-defined] from hatchet_sdk.context.worker_context import WorkerContext -from hatchet_sdk.contracts.dispatcher_pb2 import ( +from hatchet_sdk.contracts.dispatcher_pb2 import ( # type: ignore[attr-defined] GROUP_KEY_EVENT_TYPE_COMPLETED, GROUP_KEY_EVENT_TYPE_FAILED, GROUP_KEY_EVENT_TYPE_STARTED, @@ -32,6 +33,7 @@ from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger from hatchet_sdk.utils.tracing import create_tracer, parse_carrier_from_metadata +from hatchet_sdk.utils.types import WorkflowValidator from hatchet_sdk.v2.callable import DurableContext from hatchet_sdk.worker.action_listener_process import ActionEvent from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars, sr, wr @@ -48,11 +50,12 @@ class Runner: def __init__( self, name: str, - event_queue: Queue, + event_queue: "Queue[Any]", max_runs: int | None = None, handle_kill: bool = True, action_registry: dict[str, Callable[..., Any]] = {}, - config: ClientConfig = {}, + validator_registry: dict[str, WorkflowValidator] = {}, + config: ClientConfig = ClientConfig(), labels: dict[str, str | int] = {}, ): # We store the config so we can dynamically create clients for the dispatcher client. 
@@ -60,9 +63,10 @@ def __init__( self.client = new_client_raw(config) self.name = self.client.config.namespace + name self.max_runs = max_runs - self.tasks: Dict[str, asyncio.Task] = {} # Store run ids and futures - self.contexts: Dict[str, Context] = {} # Store run ids and contexts + self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures + self.contexts: dict[str, Context] = {} # Store run ids and contexts self.action_registry: dict[str, Callable[..., Any]] = action_registry + self.validator_registry = validator_registry self.event_queue = event_queue @@ -89,7 +93,7 @@ def __init__( def create_workflow_run_url(self, action: Action) -> str: return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}" - def run(self, action: Action): + def run(self, action: Action) -> None: ctx = parse_carrier_from_metadata(action.additional_metadata) with self.otel_tracer.start_as_current_span( @@ -122,8 +126,8 @@ def run(self, action: Action): span.add_event(log) logger.error(log) - def step_run_callback(self, action: Action): - def inner_callback(task: asyncio.Task): + def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]: + def inner_callback(task: asyncio.Task[Any]) -> None: self.cleanup_run_id(action.step_run_id) errored = False @@ -164,8 +168,10 @@ def inner_callback(task: asyncio.Task): return inner_callback - def group_key_run_callback(self, action: Action): - def inner_callback(task: asyncio.Task): + def group_key_run_callback( + self, action: Action + ) -> Callable[[asyncio.Task[Any]], None]: + def inner_callback(task: asyncio.Task[Any]) -> None: self.cleanup_run_id(action.get_group_key_run_id) errored = False @@ -204,7 +210,10 @@ def inner_callback(task: asyncio.Task): return inner_callback - def thread_action_func(self, context, action_func, action: Action): + ## TODO: Stricter type hinting here + def thread_action_func( + self, context: Context, action_func: Callable[..., Any], action: Action + ) -> Any: if action.step_run_id is not None and action.step_run_id != "": self.threads[action.step_run_id] = current_thread() elif ( @@ -215,10 +224,15 @@ def thread_action_func(self, context, action_func, action: Action): return action_func(context) + ## TODO: Stricter type hinting here # We wrap all actions in an async func async def async_wrapped_action_func( - self, context: Context, action_func, action: Action, run_id: str - ): + self, + context: Context, + action_func: Callable[..., Any], + action: Action, + run_id: str, + ) -> Any: wr.set(context.workflow_run_id()) sr.set(context.step_run_id) @@ -240,9 +254,7 @@ async def async_wrapped_action_func( ) loop = asyncio.get_event_loop() - res = await loop.run_in_executor(self.thread_pool, pfunc) - - return res + return await loop.run_in_executor(self.thread_pool, pfunc) except Exception as e: logger.error( errorWithTraceback( @@ -254,7 +266,7 @@ async def async_wrapped_action_func( finally: self.cleanup_run_id(run_id) - def cleanup_run_id(self, run_id: str): + def cleanup_run_id(self, run_id: str | None) -> None: if run_id in self.tasks: del self.tasks[run_id] @@ -267,7 +279,7 @@ def cleanup_run_id(self, run_id: str): def create_context( self, action: Action, action_func: Callable[..., Any] | None ) -> Context | DurableContext: - if hasattr(action_func, "durable") and action_func.durable: + if hasattr(action_func, "durable") and getattr(action_func, "durable"): return DurableContext( action, self.dispatcher_client, @@ -278,6 +290,7 @@ def create_context( 
self.workflow_run_event_listener, self.worker_context, self.client.config.namespace, + validator_registry=self.validator_registry, ) return Context( @@ -290,9 +303,10 @@ def create_context( self.workflow_run_event_listener, self.worker_context, self.client.config.namespace, + validator_registry=self.validator_registry, ) - async def handle_start_step_run(self, action: Action): + async def handle_start_step_run(self, action: Action) -> None: with self.otel_tracer.start_as_current_span( f"hatchet.worker.handle_start_step_run.{action.step_id}", ) as span: @@ -336,7 +350,7 @@ async def handle_start_step_run(self, action: Action): span.add_event("Finished step run") - async def handle_start_group_key_run(self, action: Action): + async def handle_start_group_key_run(self, action: Action) -> None: with self.otel_tracer.start_as_current_span( f"hatchet.worker.handle_start_step_run.{action.step_id}" ) as span: @@ -353,6 +367,7 @@ async def handle_start_group_key_run(self, action: Action): self.worker_context, self.client.config.namespace, ) + self.contexts[action.get_group_key_run_id] = context # Find the corresponding action function from the registry @@ -387,18 +402,18 @@ async def handle_start_group_key_run(self, action: Action): span.add_event("Finished group key run") - def force_kill_thread(self, thread): + def force_kill_thread(self, thread: Thread) -> None: """Terminate a python threading.Thread.""" try: if not thread.is_alive(): return - logger.info(f"Forcefully terminating thread {thread.ident}") + ident = cast(int, thread.ident) + + logger.info(f"Forcefully terminating thread {ident}") exc = ctypes.py_object(SystemExit) - res = ctypes.pythonapi.PyThreadState_SetAsyncExc( - ctypes.c_long(thread.ident), exc - ) + res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc) if res == 0: raise ValueError("Invalid thread ID") elif res != 1: @@ -408,7 +423,7 @@ def force_kill_thread(self, thread): ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0) raise SystemError("PyThreadState_SetAsyncExc failed") - logger.info(f"Successfully terminated thread {thread.ident}") + logger.info(f"Successfully terminated thread {ident}") # Immediately add a new thread to the thread pool, because we've actually killed a worker # in the ThreadPoolExecutor @@ -416,7 +431,7 @@ def force_kill_thread(self, thread): except Exception as e: logger.exception(f"Failed to terminate thread: {e}") - async def handle_cancel_action(self, run_id: str): + async def handle_cancel_action(self, run_id: str) -> None: with self.otel_tracer.start_as_current_span( "hatchet.worker.handle_cancel_action" ) as span: @@ -427,7 +442,9 @@ async def handle_cancel_action(self, run_id: str): # call cancel to signal the context to stop if run_id in self.contexts: context = self.contexts.get(run_id) - context.cancel() + + if context: + context.cancel() await asyncio.sleep(1) @@ -449,16 +466,20 @@ async def handle_cancel_action(self, run_id: str): span.add_event(f"Finished cancelling run id: {run_id}") def serialize_output(self, output: Any) -> str: - output_bytes = "" + + if isinstance(output, BaseModel): + return output.model_dump_json() + if output is not None: try: - output_bytes = json.dumps(output) + return json.dumps(output) except Exception as e: logger.error(f"Could not serialize output: {e}") - output_bytes = str(output) - return output_bytes + return str(output) + + return "" - async def wait_for_tasks(self): + async def wait_for_tasks(self) -> None: running = len(self.tasks.keys()) while running > 0: 
logger.info(f"waiting for {running} tasks to finish...") @@ -466,6 +487,6 @@ async def wait_for_tasks(self): running = len(self.tasks.keys()) -def errorWithTraceback(message: str, e: Exception): +def errorWithTraceback(message: str, e: Exception) -> str: trace = "".join(traceback.format_exception(type(e), e, e.__traceback__)) return f"{message}\n{trace}" diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index db944511..9beacd47 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -1,22 +1,32 @@ import asyncio import multiprocessing +import multiprocessing.context import os import signal import sys +from concurrent.futures import Future from dataclasses import dataclass, field from enum import Enum -from multiprocessing import Process, Queue -from typing import Any, Callable, Dict, Optional +from multiprocessing import Queue +from multiprocessing.process import BaseProcess +from types import FrameType +from typing import Any, Callable, TypeVar, get_type_hints +from hatchet_sdk import Context from hatchet_sdk.client import Client, new_client_raw -from hatchet_sdk.context import Context -from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts +from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] + CreateWorkflowVersionOpts, +) from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger +from hatchet_sdk.utils.types import WorkflowValidator from hatchet_sdk.v2.callable import HatchetCallable +from hatchet_sdk.v2.concurrency import ConcurrencyFunction from hatchet_sdk.worker.action_listener_process import worker_action_listener_process from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager -from hatchet_sdk.workflow import WorkflowMeta +from hatchet_sdk.workflow import WorkflowInterface + +T = TypeVar("T") class WorkerStatus(Enum): @@ -28,46 +38,60 @@ class WorkerStatus(Enum): @dataclass class WorkerStartOptions: - loop: asyncio.AbstractEventLoop = field(default=None) + loop: asyncio.AbstractEventLoop | None = field(default=None) -@dataclass class Worker: - name: str - config: ClientConfig = field(default_factory=dict) - max_runs: Optional[int] = None - debug: bool = False - labels: dict[str, str | int] = field(default_factory=dict) - handle_kill: bool = True - - client: Client = field(init=False) - tasks: Dict[str, asyncio.Task] = field(default_factory=dict) - contexts: Dict[str, Context] = field(default_factory=dict) - action_registry: Dict[str, Callable[..., Any]] = field(default_factory=dict) - killing: bool = field(init=False, default=False) - _status: WorkerStatus = field(init=False, default=WorkerStatus.INITIALIZED) - - action_listener_process: Process = field(init=False, default=None) - action_listener_health_check: asyncio.Task = field(init=False, default=None) - action_runner: WorkerActionRunLoopManager = field(init=False, default=None) - ctx = multiprocessing.get_context("spawn") - - action_queue: Queue = field(init=False, default_factory=ctx.Queue) - event_queue: Queue = field(init=False, default_factory=ctx.Queue) - - loop: asyncio.AbstractEventLoop = field(init=False, default=None) - owned_loop: bool = True - - def __post_init__(self): + def __init__( + self, + name: str, + config: ClientConfig = ClientConfig(), + max_runs: int | None = None, + labels: dict[str, str | int] = {}, + debug: bool = False, + owned_loop: bool = True, + handle_kill: bool = True, + ) -> None: + self.name = name + self.config = config + self.max_runs = max_runs + 
self.debug = debug + self.labels = labels + self.handle_kill = handle_kill + self.owned_loop = owned_loop + + self.client: Client + + self.action_registry: dict[str, Callable[[Context], T]] = {} + self.validator_registry: dict[str, WorkflowValidator] = {} + + self.killing: bool = False + self._status: WorkerStatus + + self.action_listener_process: BaseProcess + self.action_listener_health_check: asyncio.Task[Any] + self.action_runner: WorkerActionRunLoopManager + + self.ctx = multiprocessing.get_context("spawn") + + self.action_queue: "Queue[Any]" = self.ctx.Queue() + self.event_queue: "Queue[Any]" = self.ctx.Queue() + + self.loop: asyncio.AbstractEventLoop + self.client = new_client_raw(self.config, self.debug) self.name = self.client.config.namespace + self.name - if self.owned_loop: - self._setup_signal_handlers() - def register_function(self, action: str, func: HatchetCallable): + self._setup_signal_handlers() + + def register_function( + self, action: str, func: HatchetCallable[Any] | ConcurrencyFunction + ) -> None: self.action_registry[action] = func - def register_workflow_from_opts(self, name: str, opts: CreateWorkflowVersionOpts): + def register_workflow_from_opts( + self, name: str, opts: CreateWorkflowVersionOpts + ) -> None: try: self.client.admin.put_workflow(opts.name, opts) except Exception as e: @@ -75,7 +99,7 @@ def register_workflow_from_opts(self, name: str, opts: CreateWorkflowVersionOpts logger.error(e) sys.exit(1) - def register_workflow(self, workflow: WorkflowMeta): + def register_workflow(self, workflow: WorkflowInterface) -> None: namespace = self.client.config.namespace try: @@ -87,24 +111,30 @@ def register_workflow(self, workflow: WorkflowMeta): logger.error(e) sys.exit(1) - def create_action_function(action_func): - def action_function(context): + def create_action_function( + action_func: Callable[..., T] + ) -> Callable[[Context], T]: + def action_function(context: Context) -> T: return action_func(workflow, context) if asyncio.iscoroutinefunction(action_func): - action_function.is_coroutine = True + setattr(action_function, "is_coroutine", True) else: - action_function.is_coroutine = False + setattr(action_function, "is_coroutine", False) return action_function for action_name, action_func in workflow.get_actions(namespace): self.action_registry[action_name] = create_action_function(action_func) + return_type = get_type_hints(action_func).get("return") + self.validator_registry[action_name] = WorkflowValidator( + workflow_input=workflow.input_validator, step_output=return_type + ) def status(self) -> WorkerStatus: return self._status - def setup_loop(self, loop: asyncio.AbstractEventLoop = None): + def setup_loop(self, loop: asyncio.AbstractEventLoop | None = None) -> bool: try: loop = loop or asyncio.get_running_loop() self.loop = loop @@ -118,17 +148,22 @@ def setup_loop(self, loop: asyncio.AbstractEventLoop = None): created_loop = True return created_loop - def start(self, options: WorkerStartOptions = WorkerStartOptions()): + def start( + self, options: WorkerStartOptions = WorkerStartOptions() + ) -> Future[asyncio.Task[Any] | None]: self.owned_loop = self.setup_loop(options.loop) + f = asyncio.run_coroutine_threadsafe( self.async_start(options, _from_start=True), self.loop ) + # start the loop and wait until its closed if self.owned_loop: self.loop.run_forever() if self.handle_kill: sys.exit(0) + return f ## Start methods @@ -136,7 +171,7 @@ async def async_start( self, options: WorkerStartOptions = WorkerStartOptions(), _from_start: bool = False, - 
): + ) -> Any | None: main_pid = os.getpid() logger.info("------------------------------------------") logger.info("STARTING HATCHET...") @@ -148,25 +183,28 @@ async def async_start( logger.error( "no actions registered, register workflows or actions before starting worker" ) - return + return None # non blocking setup if not _from_start: self.setup_loop(options.loop) self.action_listener_process = self._start_listener() + self.action_runner = self._run_action_runner() + self.action_listener_health_check = self.loop.create_task( self._check_listener_health() ) return await self.action_listener_health_check - def _run_action_runner(self): + def _run_action_runner(self) -> WorkerActionRunLoopManager: # Retrieve the shared queue - runner = WorkerActionRunLoopManager( + return WorkerActionRunLoopManager( self.name, self.action_registry, + self.validator_registry, self.max_runs, self.config, self.action_queue, @@ -177,10 +215,9 @@ def _run_action_runner(self): self.labels, ) - return runner - - def _start_listener(self): + def _start_listener(self) -> multiprocessing.context.SpawnProcess: action_list = [str(key) for key in self.action_registry.keys()] + try: process = self.ctx.Process( target=worker_action_listener_process, @@ -204,7 +241,7 @@ def _start_listener(self): logger.error(f"failed to start action listener: {e}") sys.exit(1) - async def _check_listener_health(self): + async def _check_listener_health(self) -> None: logger.debug("starting action listener health check...") try: while not self.killing: @@ -224,21 +261,21 @@ async def _check_listener_health(self): logger.error(f"error checking listener health: {e}") ## Cleanup methods - def _setup_signal_handlers(self): + def _setup_signal_handlers(self) -> None: signal.signal(signal.SIGTERM, self._handle_exit_signal) signal.signal(signal.SIGINT, self._handle_exit_signal) signal.signal(signal.SIGQUIT, self._handle_force_quit_signal) - def _handle_exit_signal(self, signum, frame): + def _handle_exit_signal(self, signum: int, frame: FrameType | None) -> None: sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT" logger.info(f"received signal {sig_name}...") self.loop.create_task(self.exit_gracefully()) - def _handle_force_quit_signal(self, signum, frame): + def _handle_force_quit_signal(self, signum: int, frame: FrameType | None) -> None: logger.info("received SIGQUIT...") self.exit_forcefully() - async def close(self): + async def close(self) -> None: logger.info(f"closing worker '{self.name}'...") self.killing = True # self.action_queue.close() @@ -249,7 +286,7 @@ async def close(self): await self.action_listener_health_check - async def exit_gracefully(self): + async def exit_gracefully(self) -> None: logger.debug(f"gracefully stopping worker: {self.name}") if self.killing: @@ -270,7 +307,7 @@ async def exit_gracefully(self): logger.info("👋") - def exit_forcefully(self): + def exit_forcefully(self) -> None: self.killing = True logger.debug(f"forcefully stopping worker: {self.name}") @@ -286,7 +323,7 @@ def exit_forcefully(self): ) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup -def register_on_worker(callable: HatchetCallable, worker: Worker): +def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None: worker.register_function(callable.get_action_name(), callable) if callable.function_on_failure is not None: diff --git a/hatchet_sdk/workflow.py b/hatchet_sdk/workflow.py index 14f15a98..6ce7e0c9 100644 --- a/hatchet_sdk/workflow.py +++ b/hatchet_sdk/workflow.py @@ 
-1,17 +1,51 @@ import functools -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, Protocol, Type, TypeVar, Union, cast, get_type_hints + +from pydantic import BaseModel from hatchet_sdk import ConcurrencyLimitStrategy -from hatchet_sdk.contracts.workflows_pb2 import ( +from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] CreateWorkflowJobOpts, CreateWorkflowStepOpts, CreateWorkflowVersionOpts, + StickyStrategy, WorkflowConcurrencyOpts, WorkflowKind, ) from hatchet_sdk.logger import logger +from hatchet_sdk.utils.typing import is_basemodel_subclass + + +class WorkflowStepProtocol(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + + __name__: str + + _step_name: str + _step_timeout: str | None + _step_parents: list[str] + _step_retries: int | None + _step_rate_limits: list[str] | None + _step_desired_worker_labels: dict[str, str] + _step_backoff_factor: float | None + _step_backoff_max_seconds: int | None + + _concurrency_fn_name: str + _concurrency_max_runs: int | None + _concurrency_limit_strategy: str | None + + _on_failure_step_name: str + _on_failure_step_timeout: str | None + _on_failure_step_retries: int + _on_failure_step_rate_limits: list[str] | None + _on_failure_step_backoff_factor: float | None + _on_failure_step_backoff_max_seconds: int | None + + +StepsType = list[tuple[str, WorkflowStepProtocol]] -stepsType = List[Tuple[str, Callable[..., Any]]] +T = TypeVar("T") +TW = TypeVar("TW", bound="WorkflowInterface") class ConcurrencyExpression: @@ -35,28 +69,48 @@ def __init__( self.limit_strategy = limit_strategy +class WorkflowInterface(Protocol): + def get_name(self, namespace: str) -> str: ... + + def get_actions(self, namespace: str) -> list[tuple[str, Callable[..., Any]]]: ... + + def get_create_opts(self, namespace: str) -> Any: ... 
+ + on_events: list[str] + on_crons: list[str] + name: str + version: str + timeout: str + schedule_timeout: str + sticky: Union[StickyStrategy.Value, None] + default_priority: int | None + concurrency_expression: ConcurrencyExpression | None + input_validator: Type[BaseModel] | None + + class WorkflowMeta(type): - def __new__(cls, name, bases, attrs): - concurrencyActions: stepsType = [ - (getattr(func, "_concurrency_fn_name"), attrs.pop(func_name)) - for func_name, func in list(attrs.items()) - if hasattr(func, "_concurrency_fn_name") - ] - steps: stepsType = [ - (getattr(func, "_step_name"), attrs.pop(func_name)) - for func_name, func in list(attrs.items()) - if hasattr(func, "_step_name") - ] - onFailureSteps: stepsType = [ - (getattr(func, "_on_failure_step_name"), attrs.pop(func_name)) - for func_name, func in list(attrs.items()) - if hasattr(func, "_on_failure_step_name") - ] + def __new__( + cls: Type["WorkflowMeta"], + name: str, + bases: tuple[type, ...], + attrs: dict[str, Any], + ) -> "WorkflowMeta": + def _create_steps_actions_list(name: str) -> StepsType: + return [ + (getattr(func, name), attrs.pop(func_name)) + for func_name, func in list(attrs.items()) + if hasattr(func, name) + ] + + concurrencyActions = _create_steps_actions_list("_concurrency_fn_name") + steps = _create_steps_actions_list("_step_name") + + onFailureSteps = _create_steps_actions_list("_on_failure_step_name") # Define __init__ and get_step_order methods original_init = attrs.get("__init__") # Get the original __init__ if it exists - def __init__(self, *args, **kwargs): + def __init__(self: TW, *args: Any, **kwargs: Any) -> None: if original_init: original_init(self, *args, **kwargs) # Call original __init__ @@ -64,7 +118,7 @@ def get_service_name(namespace: str) -> str: return f"{namespace}{name.lower()}" @functools.cache - def get_actions(self, namespace: str) -> stepsType: + def get_actions(self: TW, namespace: str) -> StepsType: serviceName = get_service_name(namespace) func_actions = [ (serviceName + ":" + func_name, func) for func_name, func in steps @@ -87,8 +141,8 @@ def get_actions(self, namespace: str) -> stepsType: for step_name, step_func in steps: attrs[step_name] = step_func - def get_name(self, namespace: str): - return namespace + attrs["name"] + def get_name(self: TW, namespace: str) -> str: + return namespace + cast(str, attrs["name"]) attrs["get_name"] = get_name @@ -99,11 +153,11 @@ def get_name(self, namespace: str): default_priority = attrs["default_priority"] @functools.cache - def get_create_opts(self, namespace: str): + def get_create_opts(self: TW, namespace: str) -> CreateWorkflowVersionOpts: serviceName = get_service_name(namespace) name = self.get_name(namespace) event_triggers = [namespace + event for event in attrs["on_events"]] - createStepOpts: List[CreateWorkflowStepOpts] = [ + createStepOpts: list[CreateWorkflowStepOpts] = [ CreateWorkflowStepOpts( readable_id=step_name, action=serviceName + ":" + step_name, @@ -142,7 +196,7 @@ def get_create_opts(self, namespace: str): "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method." ) - on_failure_job: List[CreateWorkflowJobOpts] | None = None + on_failure_job: list[CreateWorkflowJobOpts] | None = None if len(onFailureSteps) > 0: func_name, func = onFailureSteps[0] diff --git a/lint.sh b/lint.sh index 5ff37475..8b0263f4 100755 --- a/lint.sh +++ b/lint.sh @@ -1 +1,3 @@ -pre-commit run --all-files || pre-commit run --all-files +poetry run black . 
--color +poetry run isort . +poetry run mypy --config-file=pyproject.toml diff --git a/pyproject.toml b/pyproject.toml index 6f03ac8c..65703a53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "hatchet-sdk" -version = "0.42.0" +version = "0.42.1" description = "" authors = ["Alexander Belanger "] readme = "README.md" @@ -69,8 +69,20 @@ extend_exclude = "hatchet_sdk/contracts/" [tool.mypy] strict = true -files = ["hatchet_sdk/hatchet.py"] +files = [ + "hatchet_sdk/hatchet.py", + "hatchet_sdk/worker/worker.py", + "hatchet_sdk/context/context.py", + "hatchet_sdk/worker/runner/runner.py", + "hatchet_sdk/workflow.py", + "hatchet_sdk/utils/serialization.py", + "hatchet_sdk/utils/tracing.py", + "hatchet_sdk/utils/types.py", + "hatchet_sdk/utils/backoff.py" +] +exclude = "^examples/*" follow_imports = "silent" +disable_error_code = ["unused-coroutine"] [tool.poetry.scripts] api = "examples.api.api:main" @@ -92,4 +104,5 @@ timeout = "examples.timeout.worker:main" blocked = "examples.blocked_async.worker:main" existing_loop = "examples.worker_existing_loop.worker:main" bulk_fanout = "examples.bulk_fanout.worker:main" -retries_with_backoff = "examples.retries_with_backoff.worker:main" \ No newline at end of file +retries_with_backoff = "examples.retries_with_backoff.worker:main" +pydantic = "examples.pydantic.worker:main"