From 74cce32551a1503977fc97933ec430ebb1a42f2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= Date: Thu, 21 Dec 2023 20:54:13 -0600 Subject: [PATCH] refactor: Use a `TypedDict` to annotate state dictionaries --- singer_sdk/helpers/_state.py | 50 +++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/singer_sdk/helpers/_state.py b/singer_sdk/helpers/_state.py index f3a8311082..7f714c5c14 100644 --- a/singer_sdk/helpers/_state.py +++ b/singer_sdk/helpers/_state.py @@ -6,6 +6,7 @@ import typing as t from singer_sdk.exceptions import InvalidStreamSortException +from singer_sdk.helpers._compat import TypedDict from singer_sdk.helpers._typing import to_json_compatible if t.TYPE_CHECKING: @@ -21,12 +22,28 @@ logger = logging.getLogger("singer_sdk") +# class StreamStateDict(TypedDict, total=False): +# context: dict[str, t.Any] + +StreamStateDict = t.Dict[str, t.Any] + + +class PartitionsStateDict(TypedDict, total=False): + partitions: list[StreamStateDict] + + +class TapStateDict(TypedDict, total=False): + """State dictionary type.""" + + bookmarks: dict[str, StreamStateDict | PartitionsStateDict] + + def get_state_if_exists( - tap_state: dict, + tap_state: TapStateDict, tap_stream_id: str, - state_partition_context: dict | None = None, + state_partition_context: dict[str, t.Any] | None = None, key: str | None = None, -) -> t.Any | None: # noqa: ANN401 +) -> StreamStateDict | None: """Return the stream or partition state, creating a new one if it does not exist. Args: @@ -44,34 +61,49 @@ def get_state_if_exists( ValueError: Raised if state is invalid or cannot be parsed. """ if "bookmarks" not in tap_state: + # Not a valid state, e.g. {} return None + if tap_stream_id not in tap_state["bookmarks"]: + # Stream not present in state, e.g. {"bookmarks": {}} return None + # At this point state looks like {"bookmarks": {"my_stream": {"key": "value""}}} + + # stream_state: {"key": "value", "partitions"?: ...} stream_state = tap_state["bookmarks"][tap_stream_id] if not state_partition_context: - return stream_state.get(key, None) if key else stream_state + # Either 'value' if key is specified, or {} + return stream_state.get(key, None) if key else stream_state # type: ignore[return-value] + if "partitions" not in stream_state: return None # No partitions defined + # stream_state: {"partitions": [{"context": {"key": "value"}}]} # noqa: ERA001 + matched_partition = _find_in_partitions_list( stream_state["partitions"], state_partition_context, ) + if matched_partition is None: return None # Partition definition not present + return matched_partition.get(key, None) if key else matched_partition -def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict] | None: +def get_state_partitions_list( + tap_state: TapStateDict, + tap_stream_id: str, +) -> list[StreamStateDict] | None: """Return a list of partitions defined in the state, or None if not defined.""" - return (get_state_if_exists(tap_state, tap_stream_id) or {}).get("partitions", None) # type: ignore[no-any-return] + return (get_state_if_exists(tap_state, tap_stream_id) or {}).get("partitions", None) def _find_in_partitions_list( - partitions: list[dict], - state_partition_context: dict, -) -> dict | None: + partitions: list[StreamStateDict], + state_partition_context: dict[str, t.Any], +) -> StreamStateDict | None: found = [ partition_state for partition_state in partitions