Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Use a TypedDict to annotate state dictionaries #2419

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 49 additions & 15 deletions singer_sdk/helpers/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
from __future__ import annotations

import logging
import sys
import typing as t

from singer_sdk.exceptions import InvalidStreamSortException
from singer_sdk.helpers._typing import to_json_compatible

if sys.version_info < (3, 10):
from typing_extensions import TypeAlias
else:
from typing import TypeAlias # noqa: ICN003

if t.TYPE_CHECKING:
import datetime

Expand All @@ -21,14 +27,25 @@
STARTING_MARKER = "starting_replication_value"

logger = logging.getLogger("singer_sdk")
StreamStateDict: TypeAlias = t.Dict[str, t.Any]


class PartitionsStateDict(t.TypedDict, total=False):
partitions: list[StreamStateDict]


class TapStateDict(t.TypedDict, total=False):
"""State dictionary type."""

bookmarks: dict[str, StreamStateDict | PartitionsStateDict]


def get_state_if_exists(
tap_state: dict,
tap_state: TapStateDict,
tap_stream_id: str,
state_partition_context: dict | None = None,
state_partition_context: dict[str, t.Any] | None = None,
key: str | None = None,
) -> t.Any | None: # noqa: ANN401
) -> StreamStateDict | None:
"""Return the stream or partition state, creating a new one if it does not exist.

Args:
Expand All @@ -46,34 +63,49 @@ def get_state_if_exists(
ValueError: Raised if state is invalid or cannot be parsed.
"""
if "bookmarks" not in tap_state:
# Not a valid state, e.g. {}
return None

if tap_stream_id not in tap_state["bookmarks"]:
# Stream not present in state, e.g. {"bookmarks": {}}
return None

# At this point state looks like {"bookmarks": {"my_stream": {"key": "value"}}}

# stream_state: {"key": "value", "partitions"?: ...}
stream_state = tap_state["bookmarks"][tap_stream_id]
if not state_partition_context:
return stream_state.get(key, None) if key else stream_state
# Either 'value' if key is specified, or the whole stream state dict otherwise
return stream_state.get(key, None) if key else stream_state # type: ignore[return-value]

if "partitions" not in stream_state:
return None # No partitions defined

# stream_state: {"partitions": [{"context": {"key": "value"}}]} # noqa: ERA001

matched_partition = _find_in_partitions_list(
stream_state["partitions"],
state_partition_context,
)

if matched_partition is None:
return None # Partition definition not present

return matched_partition.get(key, None) if key else matched_partition


def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict] | None:
def get_state_partitions_list(
tap_state: TapStateDict,
tap_stream_id: str,
) -> list[StreamStateDict] | None:
"""Return a list of partitions defined in the state, or None if not defined."""
return (get_state_if_exists(tap_state, tap_stream_id) or {}).get("partitions", None) # type: ignore[no-any-return]


def _find_in_partitions_list(
partitions: list[dict],
partitions: list[StreamStateDict],
state_partition_context: types.Context,
) -> dict | None:
) -> StreamStateDict | None:
found = [
partition_state
for partition_state in partitions
Expand All @@ -99,10 +131,10 @@ def _create_in_partitions_list(


def get_writeable_state_dict(
tap_state: dict,
tap_state: TapStateDict,
tap_stream_id: str,
state_partition_context: types.Context | None = None,
) -> dict:
) -> StreamStateDict:
"""Return the stream or partition state, creating a new one if it does not exist.

Args:
Expand All @@ -125,13 +157,13 @@ def get_writeable_state_dict(
tap_state["bookmarks"] = {}
if tap_stream_id not in tap_state["bookmarks"]:
tap_state["bookmarks"][tap_stream_id] = {}
stream_state = t.cast(dict, tap_state["bookmarks"][tap_stream_id])
stream_state = tap_state["bookmarks"][tap_stream_id]
if not state_partition_context:
return stream_state
return stream_state # type: ignore[return-value]

if "partitions" not in stream_state:
stream_state["partitions"] = []
stream_state_partitions: list[dict] = stream_state["partitions"]
stream_state_partitions: list[StreamStateDict] = stream_state["partitions"]
if found := _find_in_partitions_list(
stream_state_partitions,
state_partition_context,
Expand All @@ -142,7 +174,7 @@ def get_writeable_state_dict(


def write_stream_state(
tap_state: dict,
tap_state: TapStateDict,
tap_stream_id: str,
key: str,
val: t.Any, # noqa: ANN401
Expand All @@ -158,12 +190,14 @@ def write_stream_state(
state_dict[key] = val


def reset_state_progress_markers(stream_or_partition_state: dict) -> dict | None:
def reset_state_progress_markers(
stream_or_partition_state: StreamStateDict | PartitionsStateDict,
) -> dict | None:
"""Wipe the state once sync is complete.

For logging purposes, return the wiped 'progress_markers' object if it existed.
"""
progress_markers = stream_or_partition_state.pop(PROGRESS_MARKERS, {})
progress_markers = stream_or_partition_state.pop(PROGRESS_MARKERS, {}) # type: ignore[misc]
# Remove auto-generated human-readable note:
progress_markers.pop(PROGRESS_MARKER_NOTE, None)
# Return remaining 'progress_markers' if any:
Expand Down
7 changes: 4 additions & 3 deletions singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

from singer_sdk.helpers import types
from singer_sdk.helpers._compat import Traversable
from singer_sdk.helpers._state import TapStateDict
from singer_sdk.tap_base import Tap

# Replication methods
Expand Down Expand Up @@ -147,7 +148,7 @@ def __init__(
self._mask: singer.SelectionMask | None = None
self._schema: dict
self._is_state_flushed: bool = True
self._last_emitted_state: dict | None = None
self._last_emitted_state: TapStateDict | None = None
self._sync_costs: dict[str, int] = {}
self.child_streams: list[Stream] = []
if schema:
Expand Down Expand Up @@ -645,7 +646,7 @@ def replication_method(self) -> str:
# State properties:

@property
def tap_state(self) -> dict:
def tap_state(self) -> TapStateDict:
"""Return a writeable state dict for the entire tap.

Note: This dictionary is shared (and writable) across all streams.
Expand Down Expand Up @@ -790,7 +791,7 @@ def _write_state_message(self) -> None:
if (not self._is_state_flushed) and (
self.tap_state != self._last_emitted_state
):
self._tap.write_message(singer.StateMessage(value=self.tap_state))
self._tap.write_message(singer.StateMessage(value=self.tap_state)) # type: ignore[arg-type]
self._last_emitted_state = copy.deepcopy(self.tap_state)
self._is_state_flushed = True

Expand Down
7 changes: 4 additions & 3 deletions singer_sdk/tap_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pathlib import PurePath

from singer_sdk.connectors import SQLConnector
from singer_sdk.helpers._state import TapStateDict
from singer_sdk.mapper import PluginMapper
from singer_sdk.streams import SQLStream, Stream

Expand Down Expand Up @@ -92,7 +93,7 @@ def __init__(
# Declare private members
self._streams: dict[str, Stream] | None = None
self._input_catalog: Catalog | None = None
self._state: dict[str, Stream] = {}
self._state: TapStateDict = {}
self._catalog: Catalog | None = None # Tap's working catalog

# Process input catalog
Expand Down Expand Up @@ -138,7 +139,7 @@ def streams(self) -> dict[str, Stream]:
return self._streams

@property
def state(self) -> dict:
def state(self) -> TapStateDict: # type: ignore[override]
"""Get tap state.

Returns:
Expand Down Expand Up @@ -445,7 +446,7 @@ def sync_all(self) -> None:
"""Sync all streams."""
self._reset_state_progress_markers()
self._set_compatible_replication_methods()
self.write_message(StateMessage(value=self.state))
self.write_message(StateMessage(value=self.state)) # type: ignore[arg-type]

stream: Stream
for stream in self.streams.values():
Expand Down
Loading