Skip to content

Commit

Permalink
feat: Stream sync context is now available to all instances methods a…
Browse files Browse the repository at this point in the history
…s a `Stream.context` attribute
  • Loading branch information
edgarrmondragon committed Jul 10, 2024
1 parent 6256fe5 commit 500265e
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Iterable
from typing import TYPE_CHECKING, Iterable

import requests # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream
Expand All @@ -12,6 +12,9 @@
from {{ cookiecutter.library_name }}.auth import {{ cookiecutter.source_name }}Authenticator
{%- endif %}

if TYPE_CHECKING:
from singer_sdk.helpers.types import Context


class {{ cookiecutter.source_name }}Stream({{ cookiecutter.stream_type }}Stream):
"""{{ cookiecutter.source_name }} stream class."""
Expand Down Expand Up @@ -67,7 +70,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]:
def post_process(
self,
row: dict,
context: dict | None = None, # noqa: ARG002
context: Context | None = None, # noqa: ARG002
) -> dict | None:
"""As needed, append or transform raw data to match expected structure.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@

from __future__ import annotations

from typing import Iterable
from typing import TYPE_CHECKING, Iterable

from singer_sdk.streams import Stream

if TYPE_CHECKING:
from singer_sdk.helpers.types import Context


class {{ cookiecutter.source_name }}Stream(Stream):
"""Stream class for {{ cookiecutter.source_name }} streams."""

def get_records(
self,
context: dict | None, # noqa: ARG002
context: Context | None, # noqa: ARG002
) -> Iterable[dict]:
"""Return a generator of record-type dictionary objects.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
{%- if cookiecutter.auth_method in ("OAuth2", "JWT") %}
from functools import cached_property
{%- endif %}
from typing import Any, Callable, Iterable
from typing import TYPE_CHECKING, Any, Callable, Iterable

import requests
{% if cookiecutter.auth_method == "API Key" -%}
Expand Down Expand Up @@ -46,6 +46,10 @@
else:
import importlib_resources

if TYPE_CHECKING:
from singer_sdk.helpers.types import Context


_Auth = Callable[[requests.PreparedRequest], requests.PreparedRequest]

# TODO: Delete this is if not using json files for schema definition
Expand Down Expand Up @@ -157,7 +161,7 @@ def get_new_paginator(self) -> BaseAPIPaginator:

def get_url_params(
self,
context: dict | None, # noqa: ARG002
context: Context | None, # noqa: ARG002
next_page_token: Any | None, # noqa: ANN401
) -> dict[str, Any]:
"""Return a dictionary of values to be used in URL parameterization.
Expand All @@ -179,7 +183,7 @@ def get_url_params(

def prepare_request_payload(
self,
context: dict | None, # noqa: ARG002
context: Context | None, # noqa: ARG002
next_page_token: Any | None, # noqa: ARG002, ANN401
) -> dict | None:
"""Prepare the data payload for the REST API request.
Expand Down Expand Up @@ -211,7 +215,7 @@ def parse_response(self, response: requests.Response) -> Iterable[dict]:
def post_process(
self,
row: dict,
context: dict | None = None, # noqa: ARG002
context: Context | None = None, # noqa: ARG002
) -> dict | None:
"""As needed, append or transform raw data to match expected structure.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ omit = [
"tests/*",
"samples/*",
"singer_sdk/helpers/_compat.py",
"singer_sdk/helpers/types.py",
]

[tool.coverage.report]
Expand Down
12 changes: 7 additions & 5 deletions singer_sdk/helpers/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
if t.TYPE_CHECKING:
import datetime

from singer_sdk.helpers import types

_T = t.TypeVar("_T", datetime.datetime, str, int, float)

PROGRESS_MARKERS = "progress_markers"
Expand Down Expand Up @@ -70,7 +72,7 @@ def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict]

def _find_in_partitions_list(
partitions: list[dict],
state_partition_context: dict,
state_partition_context: types.Context,
) -> dict | None:
found = [
partition_state
Expand All @@ -88,7 +90,7 @@ def _find_in_partitions_list(

def _create_in_partitions_list(
partitions: list[dict],
state_partition_context: dict,
state_partition_context: types.Context,
) -> dict:
# Existing partition not found. Creating new state entry in partitions list...
new_partition_state = {"context": state_partition_context}
Expand All @@ -99,7 +101,7 @@ def _create_in_partitions_list(
def get_writeable_state_dict(
tap_state: dict,
tap_stream_id: str,
state_partition_context: dict | None = None,
state_partition_context: types.Context | None = None,
) -> dict:
"""Return the stream or partition state, creating a new one if it does not exist.
Expand Down Expand Up @@ -283,8 +285,8 @@ def log_sort_error(
ex: Exception,
log_fn: t.Callable,
stream_name: str,
current_context: dict | None,
state_partition_context: dict | None,
current_context: types.Context | None,
state_partition_context: types.Context | None,
record_count: int,
partition_record_count: int,
) -> None:
Expand Down
23 changes: 23 additions & 0 deletions singer_sdk/helpers/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Type aliases for use in the SDK."""

from __future__ import annotations

import sys
import typing as t

if sys.version_info < (3, 9):
from typing import Mapping # noqa: ICN003
else:
from collections.abc import Mapping

if sys.version_info < (3, 10):
from typing_extensions import TypeAlias
else:
from typing import TypeAlias # noqa: ICN003

__all__ = [
"Context",
]

Context: TypeAlias = Mapping
Record: TypeAlias = t.Dict[str, t.Any]
6 changes: 4 additions & 2 deletions singer_sdk/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
if t.TYPE_CHECKING:
from types import TracebackType

from singer_sdk.helpers import types
from singer_sdk.helpers._compat import Traversable


DEFAULT_LOG_INTERVAL = 60.0
METRICS_LOGGER_NAME = __name__
METRICS_LOG_LEVEL_SETTING = "metrics_log_level"
Expand Down Expand Up @@ -117,7 +119,7 @@ def __init__(self, metric: Metric, tags: dict | None = None) -> None:
self.logger = get_metrics_logger()

@property
def context(self) -> dict | None:
def context(self) -> types.Context | None:
"""Get the context for this meter.
Returns:
Expand All @@ -126,7 +128,7 @@ def context(self) -> dict | None:
return self.tags.get(Tag.CONTEXT)

@context.setter
def context(self, value: dict | None) -> None:
def context(self, value: types.Context | None) -> None:
"""Set the context for this meter.
Args:
Expand Down
Loading

0 comments on commit 500265e

Please sign in to comment.