MyPy Fixes for Worker, Workflow metaclass, Context, and some others #273

Status: Open. Wants to merge 8 commits into main.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
@@ -26,7 +26,7 @@ jobs:
           virtualenvs-in-project: true

       - name: Install linting tools
-        run: poetry install --no-root --only lint
+        run: poetry install --no-root

       - name: Run Black
         run: poetry run black . --check --verbose --diff --color
18 changes: 10 additions & 8 deletions hatchet_sdk/clients/admin.py
@@ -54,16 +54,10 @@ class ChildTriggerWorkflowOptions(TypedDict):
     sticky: bool | None = None


-class WorkflowRunDict(TypedDict):
-    workflow_name: str
-    input: Any
-    options: Optional[dict]
-
-
 class ChildWorkflowRunDict(TypedDict):
     workflow_name: str
     input: Any
-    options: ChildTriggerWorkflowOptions[dict]
+    options: ChildTriggerWorkflowOptions
     key: str | None = None


@@ -73,6 +67,12 @@ class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, TypedDict):
     namespace: str | None = None


+class WorkflowRunDict(TypedDict):
+    workflow_name: str
+    input: Any
+    options: TriggerWorkflowOptions | None
+
+
 class DedupeViolationErr(Exception):
     """Raised by the Hatchet library to indicate that a workflow has already been run with this deduplication value."""

@@ -260,7 +260,9 @@ async def run_workflow(

     @tenacity_retry
     async def run_workflows(
-        self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None
+        self,
+        workflows: list[WorkflowRunDict],
+        options: TriggerWorkflowOptions | None = None,
     ) -> List[WorkflowRunRef]:
         if len(workflows) == 0:
             raise ValueError("No workflows to run")
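For context, a minimal sketch of how the reordered definitions compose after this change. The workflow names, inputs, and client handle below are made up for illustration and are not part of this PR:

```python
from hatchet_sdk.clients.admin import WorkflowRunDict

# Hypothetical payloads, purely illustrative.
runs: list[WorkflowRunDict] = [
    {
        "workflow_name": "resize-image",
        "input": {"url": "https://example.com/a.png"},
        "options": None,  # now TriggerWorkflowOptions | None, not a bare dict
    },
    {
        "workflow_name": "resize-image",
        "input": {"url": "https://example.com/b.png"},
        "options": None,
    },
]

# refs = await admin_client.run_workflows(runs)  # admin_client is assumed
```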
2 changes: 1 addition & 1 deletion hatchet_sdk/clients/dispatcher/action_listener.py
@@ -61,7 +61,7 @@ class Action:
     worker_id: str
     tenant_id: str
     workflow_run_id: str
-    get_group_key_run_id: Optional[str]
+    get_group_key_run_id: str
     job_id: str
     job_name: str
     job_run_id: str
4 changes: 2 additions & 2 deletions hatchet_sdk/clients/dispatcher/dispatcher.py
@@ -137,14 +137,14 @@ def put_overrides_data(self, data: OverridesData):

         return response

-    def release_slot(self, step_run_id: str):
+    def release_slot(self, step_run_id: str) -> None:
         self.client.ReleaseSlot(
             ReleaseSlotRequest(stepRunId=step_run_id),
             timeout=DEFAULT_REGISTER_TIMEOUT,
             metadata=get_metadata(self.token),
         )

-    def refresh_timeout(self, step_run_id: str, increment_by: str):
+    def refresh_timeout(self, step_run_id: str, increment_by: str) -> None:
         self.client.RefreshTimeout(
             RefreshTimeoutRequest(
                 stepRunId=step_run_id,
7 changes: 6 additions & 1 deletion hatchet_sdk/clients/rest/tenacity_utils.py
@@ -1,10 +1,15 @@
+from typing import Callable, ParamSpec, TypeVar
+
 import grpc
 import tenacity

 from hatchet_sdk.logger import logger

+P = ParamSpec("P")
+R = TypeVar("R")
+

-def tenacity_retry(func):
+def tenacity_retry(func: Callable[P, R]) -> Callable[P, R]:
     return tenacity.retry(
         reraise=True,
         wait=tenacity.wait_exponential_jitter(),
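This is the standard PEP 612 pattern for typing decorators so the wrapped function keeps its exact signature. A self-contained sketch of the same idea, separate from the SDK:

```python
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def traced(func: Callable[P, R]) -> Callable[P, R]:
    # The wrapper preserves the exact parameter list and return type of
    # `func`, so mypy can flag bad call sites of the decorated function.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


@traced
def add(x: int, y: int) -> int:
    return x + y


add(1, 2)      # OK: mypy infers int
add("a", "b")  # mypy error: arguments must be int
```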
101 changes: 55 additions & 46 deletions hatchet_sdk/context/context.py
@@ -2,16 +2,18 @@
 import json
 import traceback
 from concurrent.futures import Future, ThreadPoolExecutor
-from typing import List
+from typing import Any, cast
+
+from pydantic import StrictStr

 from hatchet_sdk.clients.events import EventClient
 from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
 from hatchet_sdk.clients.rest_client import RestApi
 from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
 from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
 from hatchet_sdk.context.worker_context import WorkerContext
-from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData
-from hatchet_sdk.contracts.workflows_pb2 import (
+from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData  # type: ignore
+from hatchet_sdk.contracts.workflows_pb2 import (  # type: ignore[attr-defined]
     BulkTriggerWorkflowRequest,
     TriggerWorkflowRequest,
 )
@@ -24,25 +26,32 @@
     TriggerWorkflowOptions,
     WorkflowRunDict,
 )
-from ..clients.dispatcher.dispatcher import Action, DispatcherClient
+from ..clients.dispatcher.dispatcher import (  # type: ignore[attr-defined]
+    Action,
+    DispatcherClient,
+)
 from ..logger import logger

 DEFAULT_WORKFLOW_POLLING_INTERVAL = 5  # Seconds


-def get_caller_file_path():
+def get_caller_file_path() -> str:
     caller_frame = inspect.stack()[2]

     return caller_frame.filename


 class BaseContext:

+    action: Action
+    spawn_index: int
+
     def _prepare_workflow_options(
         self,
-        key: str = None,
+        key: str | None = None,
         options: ChildTriggerWorkflowOptions | None = None,
-        worker_id: str = None,
-    ):
+        worker_id: str | None = None,
+    ) -> TriggerWorkflowOptions:
         workflow_run_id = self.action.workflow_run_id
         step_run_id = self.action.step_run_id
@@ -54,7 +63,7 @@ def _prepare_workflow_options(
         if options is not None and "additional_metadata" in options:
             meta = options["additional_metadata"]

-        trigger_options: TriggerWorkflowOptions = {
+        trigger_options: TriggerWorkflowOptions = {  # type: ignore[typeddict-item]
[PR author comment] this is here because namespace isn't set - if we do set it, it causes problems downstream. I think Pydantic here would be helpful instead of using typed dicts
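As a rough illustration of the Pydantic idea mentioned above, a hypothetical sketch (not part of this PR; fields beyond those visible in this diff are assumptions, and the real field set lives in admin.py). Unset fields simply default to None instead of tripping typeddict-item errors:

```python
# Hypothetical sketch, NOT part of this PR: a Pydantic model in place of the
# TriggerWorkflowOptions TypedDict, so "namespace" can be left unset.
from pydantic import BaseModel


class TriggerWorkflowOptionsModel(BaseModel):
    parent_id: str | None = None
    parent_step_run_id: str | None = None
    child_key: str | None = None
    additional_metadata: dict[str, str] | None = None  # value type assumed
    namespace: str | None = None


opts = TriggerWorkflowOptionsModel(parent_id="wf-run-123", child_key="child-0")
assert opts.namespace is None  # simply unset; no type: ignore needed
```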

"parent_id": workflow_run_id,
"parent_step_run_id": step_run_id,
"child_key": key,
@@ -95,9 +104,9 @@ def __init__(
     async def spawn_workflow(
         self,
         workflow_name: str,
-        input: dict = {},
-        key: str = None,
-        options: ChildTriggerWorkflowOptions = None,
+        input: dict[str, Any] = {},
+        key: str | None = None,
+        options: ChildTriggerWorkflowOptions | None = None,
     ) -> WorkflowRunRef:
         worker_id = self.worker.id()
         # if (
@@ -118,23 +127,24 @@ async def spawn_workflow(

     @tenacity_retry
     async def spawn_workflows(
-        self, child_workflow_runs: List[ChildWorkflowRunDict]
-    ) -> List[WorkflowRunRef]:
+        self, child_workflow_runs: list[ChildWorkflowRunDict]
+    ) -> list[WorkflowRunRef]:

         if len(child_workflow_runs) == 0:
             raise Exception("no child workflows to spawn")

         worker_id = self.worker.id()

-        bulk_trigger_workflow_runs: WorkflowRunDict = []
+        bulk_trigger_workflow_runs: list[WorkflowRunDict] = []
         for child_workflow_run in child_workflow_runs:
             workflow_name = child_workflow_run["workflow_name"]
             input = child_workflow_run["input"]

             key = child_workflow_run.get("key")
             options = child_workflow_run.get("options", {})

-            trigger_options = self._prepare_workflow_options(key, options, worker_id)
+            ## TODO: figure out why this is failing
+            trigger_options = self._prepare_workflow_options(key, options, worker_id)  # type: ignore[arg-type]
[PR author comment] weird type error here on the options that I'll figure out later
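One hedged guess at the error, as a suggestion sketch: `ChildWorkflowRunDict` declares `options` as a required key, so `.get("options", {})` is likely inferred as a union of `ChildTriggerWorkflowOptions` with an empty-dict type, which the `ChildTriggerWorkflowOptions | None` parameter rejects. Indexing would keep the precise type:

```python
# Hedged sketch of a possible fix for the [arg-type] ignore above, assuming
# the diagnosis is right; mirrors the surrounding loop body.
options = child_workflow_run["options"]  # required key, no default needed
trigger_options = self._prepare_workflow_options(key, options, worker_id)
```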


             bulk_trigger_workflow_runs.append(
                 WorkflowRunDict(
@@ -179,11 +189,11 @@ def __init__(
         # Check the type of action.action_payload before attempting to load it as JSON
         if isinstance(action.action_payload, (str, bytes, bytearray)):
             try:
-                self.data = json.loads(action.action_payload)
+                self.data = cast(dict[str, Any], json.loads(action.action_payload))
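A side note on the `cast` pattern used here and in the methods below: `typing.cast` has no runtime effect; it only re-labels the `Any` that `json.loads` returns for the type checker. A minimal self-contained illustration:

```python
import json
from typing import Any, cast

raw: Any = json.loads('{"key": "value"}')  # json.loads returns Any
data = cast(dict[str, Any], raw)           # type info only, no conversion
assert data is raw                         # same object; cast never checks
```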
             except Exception as e:
                 logger.error(f"Error parsing action payload: {e}")
                 # Assign an empty dictionary if parsing fails
-                self.data = {}
+                self.data: dict[str, Any] = {}  # type: ignore[no-redef]
         else:
             # Directly assign the payload to self.data if it's already a dict
             self.data = (
@@ -218,33 +228,34 @@
         else:
             self.input = self.data.get("input", {})

-    def step_output(self, step: str):
+    def step_output(self, step: str) -> dict[str, Any]:
         try:
-            return self.data["parents"][step]
+            return cast(dict[str, Any], self.data["parents"][step])
         except KeyError:
             raise ValueError(f"Step output for '{step}' not found")

     def triggered_by_event(self) -> bool:
-        return self.data.get("triggered_by", "") == "event"
+        return cast(str, self.data.get("triggered_by", "")) == "event"

-    def workflow_input(self):
+    def workflow_input(self) -> dict[str, Any]:
         return self.input

-    def workflow_run_id(self):
+    def workflow_run_id(self) -> str:
         return self.action.workflow_run_id

-    def cancel(self):
+    def cancel(self) -> None:
         logger.debug("cancelling step...")
         self.exit_flag = True

     # done returns true if the context has been cancelled
-    def done(self):
+    def done(self) -> bool:
         return self.exit_flag

-    def playground(self, name: str, default: str = None):
+    def playground(self, name: str, default: str | None = None) -> str | None:
         # if the key exists in the overrides_data field, return the value
         if name in self.overrides_data:
-            return self.overrides_data[name]
+            ## TODO: Check if this is the right type
+            return cast(str, self.overrides_data[name])

         caller_file = get_caller_file_path()

@@ -259,15 +270,15 @@ def playground(self, name: str, default: str = None):

         return default

-    def _log(self, line: str) -> (bool, Exception):  # type: ignore
+    def _log(self, line: str) -> tuple[bool, Exception | None]:
         try:
             self.event_client.log(message=line, step_run_id=self.stepRunId)
             return True, None
         except Exception as e:
             # we don't want to raise an exception here, as it will kill the log thread
             return False, e

-    def log(self, line, raise_on_error: bool = False):
+    def log(self, line: Any, raise_on_error: bool = False) -> None:
         if self.stepRunId == "":
             return

@@ -277,9 +288,9 @@ def log(self, line, raise_on_error: bool = False):
         except Exception:
             line = str(line)

-        future: Future = self.logger_thread_pool.submit(self._log, line)
+        future = self.logger_thread_pool.submit(self._log, line)

-        def handle_result(future: Future):
+        def handle_result(future: Future[tuple[bool, Exception | None]]) -> None:
             success, exception = future.result()
             if not success and exception:
                 if raise_on_error:
@@ -297,51 +308,51 @@ def handle_result(future: Future):

         future.add_done_callback(handle_result)

-    def release_slot(self):
+    def release_slot(self) -> None:
         return self.dispatcher_client.release_slot(self.stepRunId)

-    def _put_stream(self, data: str | bytes):
+    def _put_stream(self, data: str | bytes) -> None:
         try:
             self.event_client.stream(data=data, step_run_id=self.stepRunId)
         except Exception as e:
             logger.error(f"Error putting stream event: {e}")

-    def put_stream(self, data: str | bytes):
+    def put_stream(self, data: str | bytes) -> None:
         if self.stepRunId == "":
             return

         self.stream_event_thread_pool.submit(self._put_stream, data)

-    def refresh_timeout(self, increment_by: str):
+    def refresh_timeout(self, increment_by: str) -> None:
         try:
             return self.dispatcher_client.refresh_timeout(
                 step_run_id=self.stepRunId, increment_by=increment_by
             )
         except Exception as e:
             logger.error(f"Error refreshing timeout: {e}")

-    def retry_count(self):
+    def retry_count(self) -> int:
         return self.action.retry_count

-    def additional_metadata(self):
+    def additional_metadata(self) -> dict[str, Any] | None:
         return self.action.additional_metadata

-    def child_index(self):
+    def child_index(self) -> int | None:
         return self.action.child_workflow_index

-    def child_key(self):
+    def child_key(self) -> str | None:
         return self.action.child_workflow_key

-    def parent_workflow_run_id(self):
+    def parent_workflow_run_id(self) -> str | None:
         return self.action.parent_workflow_run_id

-    def fetch_run_failures(self):
+    def fetch_run_failures(self) -> list[dict[str, StrictStr]]:
         data = self.rest_client.workflow_run_get(self.action.workflow_run_id)
         other_job_runs = [
-            run for run in data.job_runs if run.job_id != self.action.job_id
+            run for run in (data.job_runs or []) if run.job_id != self.action.job_id
         ]
-        failed_step_runs = [
+        # TODO: Parse Step Runs using a Pydantic Model rather than a hand crafted dictionary
+        return [
             {
                 "step_id": step_run.step_id,
                 "step_run_action_name": step_run.step.action,
@@ -350,7 +361,5 @@ def fetch_run_failures(self):
             for job_run in other_job_runs
             if job_run.step_runs
             for step_run in job_run.step_runs
-            if step_run.error
+            if step_run.error and step_run.step
         ]
-
-        return failed_step_runs
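On the TODO above: a hypothetical sketch of the Pydantic model it suggests. The names are illustrative and not part of this PR:

```python
# Hypothetical follow-up for the TODO in fetch_run_failures, NOT in this PR:
# a Pydantic model replacing the hand-crafted dict.
from pydantic import BaseModel


class StepRunFailure(BaseModel):
    step_id: str
    step_run_action_name: str
    error: str


# fetch_run_failures would then be annotated -> list[StepRunFailure] and
# build StepRunFailure(step_id=..., step_run_action_name=..., error=...)
# instances instead of dicts of StrictStr.
```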
2 changes: 1 addition & 1 deletion hatchet_sdk/context/worker_context.py
@@ -21,7 +21,7 @@ async def async_upsert_labels(self, labels: dict[str, str | int]):
         await self.client.async_upsert_worker_labels(self._worker_id, labels)
         self._labels.update(labels)

-    def id(self):
+    def id(self) -> str:
         return self._worker_id

     # def has_workflow(self, workflow_name: str):