Add Pydantic Input Support #151

Closed · wants to merge 3 commits
24 changes: 24 additions & 0 deletions examples/pydantic/event.py
@@ -0,0 +1,24 @@
from dotenv import load_dotenv
from pydantic import BaseModel

from hatchet_sdk import PushEventOptions, new_client


class ClientPushPayload(BaseModel):
    """Example Pydantic model."""

    test: str


load_dotenv()

client = new_client()
options_model = PushEventOptions(additional_metadata={"hello": "moon"})
payload_model = ClientPushPayload(test="test")

# The original dict-based call still works:
# client.event.push("user:create", {"test": "test"})
client.event.push(
    "user:create",
    payload_model,
    options=options_model,
)
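
The updated push signature (see the hatchet_sdk/clients/events.py diff below) also accepts plain dicts for both payload and options, coercing a dict options value into PushEventOptions internally. A minimal sketch of the equivalent dict-only call, assuming the same "user:create" event key as above:

# Equivalent call using plain dicts for both payload and options;
# the client converts the options dict into a PushEventOptions model internally.
client.event.push(
    "user:create",
    {"test": "test"},
    options={"additional_metadata": {"hello": "moon"}},
)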
36 changes: 36 additions & 0 deletions examples/pydantic/worker.py
@@ -0,0 +1,36 @@
import time

from dotenv import load_dotenv
from pydantic import BaseModel

from hatchet_sdk import Context, Hatchet

load_dotenv()

hatchet = Hatchet(debug=True)


class WorkflowPayload(BaseModel):
    """Example Pydantic model."""

    step1: str = "step1"


@hatchet.workflow(on_events=["user:create"])
class MyWorkflow:
    @hatchet.step(timeout="15s", retries=3)
    def step1(self, context: Context):
        print("executed step1")
        time.sleep(5)
        return WorkflowPayload(step1="step1")


def main():
    workflow = MyWorkflow()
    worker = hatchet.worker("test-worker-pydantic", max_runs=1)
    worker.register_workflow(workflow)
    worker.start()


if __name__ == "__main__":
    main()
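
If a downstream step needs the typed output of step1, it can re-validate the serialized output with the same model. A minimal sketch of such a step added inside MyWorkflow, assuming Context exposes a step_output("step1") accessor returning the step's output as a dict and that @hatchet.step accepts a parents argument (verify both against your SDK version; step2 is illustrative and not part of this PR):

    # Hypothetical second step: re-hydrate step1's output into the Pydantic model.
    # Assumes context.step_output("step1") returns the deserialized output dict.
    @hatchet.step(parents=["step1"], timeout="15s")
    def step2(self, context: Context):
        upstream = WorkflowPayload(**context.step_output("step1"))
        print("step1 returned:", upstream.step1)
        return {"step2": "done"}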
32 changes: 23 additions & 9 deletions hatchet_sdk/clients/events.py
@@ -1,9 +1,10 @@
 import datetime
 import json
-from typing import Dict, TypedDict
+from typing import Optional

 import grpc
 from google.protobuf import timestamp_pb2
+from pydantic import BaseModel

 from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
 from hatchet_sdk.contracts.events_pb2 import (
@@ -33,8 +34,8 @@ def proto_timestamp_now():
     return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)


-class PushEventOptions(TypedDict):
-    additional_metadata: Dict[str, str] | None = None
+class PushEventOptions(BaseModel):
+    additional_metadata: dict[str, str] | None = None


 class EventClient:
@@ -44,20 +45,33 @@ def __init__(self, client: EventsServiceStub, config: ClientConfig):
         self.namespace = config.namespace

     @tenacity_retry
-    def push(self, event_key, payload, options: PushEventOptions = None) -> Event:
+    def push(
+        self,
+        event_key,
+        payload: dict | BaseModel,
+        options: Optional[dict | PushEventOptions] = {},
+    ) -> Event:

         namespaced_event_key = self.namespace + event_key
+        if isinstance(options, dict):
+            options = PushEventOptions(**options)

         try:
-            meta = None if options is None else options["additional_metadata"]
+            meta = options.additional_metadata
             meta_bytes = None if meta is None else json.dumps(meta).encode("utf-8")
         except Exception as e:
             raise ValueError(f"Error encoding meta: {e}")

-        try:
-            payload_bytes = json.dumps(payload).encode("utf-8")
-        except json.UnicodeEncodeError as e:
-            raise ValueError(f"Error encoding payload: {e}")
+        if isinstance(payload, BaseModel):
+            try:
+                payload_bytes = payload.model_dump_json().encode("utf-8")
+            except json.UnicodeEncodeError as e:
+                raise ValueError(f"Error encoding Pydantic model: {e}")
+        else:
+            try:
+                payload_bytes = json.dumps(payload).encode("utf-8")
+            except json.UnicodeEncodeError as e:
+                raise ValueError(f"Error encoding dict payload: {e}")

         request = PushEventRequest(
             key=namespaced_event_key,
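
For illustration, the new payload branch in push() can be exercised in isolation. A standalone sketch under the assumption of Pydantic v2 (model_dump_json); the helper name serialize_payload is illustrative and not part of the SDK:

import json

from pydantic import BaseModel


def serialize_payload(payload: dict | BaseModel) -> bytes:
    # Mirrors the branch added in push(): Pydantic models are serialized with
    # model_dump_json(), plain dicts fall back to json.dumps().
    if isinstance(payload, BaseModel):
        return payload.model_dump_json().encode("utf-8")
    return json.dumps(payload).encode("utf-8")


class Example(BaseModel):
    test: str


if __name__ == "__main__":
    print(serialize_payload(Example(test="test")))  # JSON bytes from the model
    print(serialize_payload({"test": "test"}))      # JSON bytes from the dict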