Skip to content

Commit

Permalink
feat: SQLTarget connector instance shared with sinks (#1864)
Browse files (browse the repository at this point in the history)
Co-authored-by: Edgar R. M <[email protected]>
  • Loading branch information
BuzzCutNorman and edgarrmondragon authored Jul 28, 2023
1 parent b1b3bd2 commit 759c77b
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 6 deletions.
133 changes: 131 additions & 2 deletions singer_sdk/target_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
if t.TYPE_CHECKING:
from pathlib import PurePath

from singer_sdk.connectors import SQLConnector
from singer_sdk.mapper import PluginMapper
from singer_sdk.sinks import Sink
from singer_sdk.sinks import Sink, SQLSink

_MAX_PARALLELISM = 8

Expand All @@ -48,7 +49,7 @@ class Target(PluginBase, SingerReader, metaclass=abc.ABCMeta):

# Default class to use for creating new sink objects.
# Required if `Target.get_sink_class()` is not defined.
default_sink_class: type[Sink] | None = None
default_sink_class: type[Sink]

def __init__(
self,
Expand Down Expand Up @@ -574,6 +575,23 @@ def get_singer_command(cls: type[Target]) -> click.Command:
class SQLTarget(Target):
"""Target implementation for SQL destinations."""

_target_connector: SQLConnector | None = None

default_sink_class: type[SQLSink]

@property
def target_connector(self) -> SQLConnector:
    """Return the target's connector, creating it on first access.

    The connector is built by calling ``connector_class`` of
    ``default_sink_class`` with a ``dict`` copy of ``self.config``. The result
    is cached in ``_target_connector`` so later accesses return the same
    object.

    Returns:
        The connector object.
    """
    connector = self._target_connector
    if connector is None:
        build_connector = self.default_sink_class.connector_class
        connector = build_connector(dict(self.config))
        self._target_connector = connector
    return connector

@classproperty
def capabilities(self) -> list[CapabilitiesEnum]:
"""Get target capabilities.
Expand Down Expand Up @@ -617,3 +635,114 @@ def _merge_missing(source_jsonschema: dict, target_jsonschema: dict) -> None:
super().append_builtin_config(config_jsonschema)

pass

@final
def add_sqlsink(
    self,
    stream_name: str,
    schema: dict,
    key_properties: list[str] | None = None,
) -> Sink:
    """Build a sink for a stream, set it up and register it as active.

    The sink class comes from ``get_sink_class()`` and is instantiated with
    this target's ``target_connector`` passed as its ``connector`` argument.

    This method is internal to the SDK and should not need to be overridden.

    Args:
        stream_name: Name of the stream.
        schema: Schema of the stream.
        key_properties: Primary key of the stream.

    Returns:
        A new sink for the stream.
    """
    self.logger.info("Initializing '%s' target sink...", self.name)

    new_sink = self.get_sink_class(stream_name=stream_name)(
        target=self,
        stream_name=stream_name,
        schema=schema,
        key_properties=key_properties,
        connector=self.target_connector,
    )
    new_sink.setup()

    # Registering under the stream name replaces any previous active entry.
    self._sinks_active[stream_name] = new_sink
    return new_sink

def get_sink_class(self, stream_name: str) -> type[SQLSink]:
    """Get sink for a stream.

    Developers can override this method to return a custom Sink type depending
    on the value of `stream_name`. Optional when `default_sink_class` is set.

    Args:
        stream_name: Name of the stream.

    Raises:
        ValueError: If no :class:`singer_sdk.sinks.Sink` class is defined.

    Returns:
        The sink class to be used with the stream.
    """
    # `default_sink_class` is declared as a bare annotation on the class, so a
    # subclass that never assigns it has no such attribute at all. Looking it
    # up with a default keeps the documented ValueError instead of leaking an
    # AttributeError to the caller.
    sink_class = getattr(self, "default_sink_class", None)
    if sink_class:
        return sink_class

    msg = (
        f"No sink class defined for '{stream_name}' and no default sink class "
        "available."
    )
    raise ValueError(msg)

def get_sink(
    self,
    stream_name: str,
    *,
    record: dict | None = None,
    schema: dict | None = None,
    key_properties: list[str] | None = None,
) -> Sink:
    """Look up, or create, the sink that handles a stream.

    Without a `schema` the already-active sink for the stream is returned
    (after `_assert_sink_exists()` has checked that there is one). With a
    `schema`, a new sink is created through `add_sqlsink()` when the stream has
    no active sink yet, or when `schema` or `key_properties` differ from the
    active sink's values. In the latter case the old sink is moved to
    `_sinks_to_clear` first.

    Developers only need to override this method if they want to provide a
    different sink depending on the values within the `record` object.
    Otherwise, please see `default_sink_class` property and/or the
    `get_sink_class()` method.

    Raises :class:`singer_sdk.exceptions.RecordsWithoutSchemaException` if sink
    does not exist and schema is not sent.

    Args:
        stream_name: Name of the stream.
        record: Record being processed.
        schema: Stream schema.
        key_properties: Primary key of the stream.

    Returns:
        The sink used for this target.
    """
    _ = record  # Custom implementations may use record in sink selection.

    # No schema supplied: only an already-registered sink can be returned.
    if schema is None:
        self._assert_sink_exists(stream_name)
        return self._sinks_active[stream_name]

    current = self._sinks_active.get(stream_name)
    if not current:
        return self.add_sqlsink(stream_name, schema, key_properties)

    has_changed = (
        current.schema != schema or current.key_properties != key_properties
    )
    if not has_changed:
        return current

    self.logger.info(
        "Schema or key properties for '%s' stream have changed. "
        "Initializing a new '%s' sink...",
        stream_name,
        stream_name,
    )
    # Park the outdated sink for later clearing before registering its
    # replacement.
    retired = self._sinks_active.pop(stream_name)
    self._sinks_to_clear.append(retired)
    return self.add_sqlsink(stream_name, schema, key_properties)
65 changes: 63 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import pytest
from sqlalchemy import __version__ as sqlalchemy_version

from singer_sdk import SQLConnector
from singer_sdk import typing as th
from singer_sdk.sinks import BatchSink
from singer_sdk.target_base import Target
from singer_sdk.sinks import BatchSink, SQLSink
from singer_sdk.target_base import SQLTarget, Target

if t.TYPE_CHECKING:
from _pytest.config import Config
Expand Down Expand Up @@ -116,3 +117,63 @@ def _write_state_message(self, state: dict):
"""Emit the stream's latest state."""
super()._write_state_message(state)
self.state_messages_written.append(state)


class SQLConnectorMock(SQLConnector):
    """A mock SQLConnector class.

    Adds nothing on top of ``SQLConnector``; it gives the test suite a distinct
    connector type to assign as a sink's ``connector_class``.
    """


class SQLSinkMock(SQLSink):
    """A mock SQL sink that records its activity on the owning target.

    Processed records and batches are counted and collected on the target
    object (``num_records_processed``, ``records_written``,
    ``num_batches_processed``) so tests can assert on them.
    """

    name = "sql-sink-mock"
    connector_class = SQLConnectorMock

    def __init__(
        self,
        target: SQLTargetMock,
        stream_name: str,
        schema: dict,
        key_properties: list[str] | None,
        connector: SQLConnector | None = None,
    ) -> None:
        """Create the mock SQL sink.

        Args:
            target: The target that owns this sink and receives the tracking
                counters.
            stream_name: Name of the stream.
            schema: Schema of the stream.
            key_properties: Primary key of the stream.
            connector: Connector to use. When omitted, one is built from
                ``connector_class`` and a copy of the target config.
        """
        self._connector: SQLConnector
        self._connector = connector or self.connector_class(dict(target.config))
        # Forward the resolved connector rather than the raw argument, so the
        # parent never receives ``None`` after one has already been built here.
        # NOTE(review): presumably the parent would otherwise create a second
        # connector of its own - confirm against SQLSink.__init__.
        super().__init__(
            target,
            stream_name,
            schema,
            key_properties,
            self._connector,
        )
        self.target = target

    def process_record(self, record: dict, context: dict) -> None:
        """Count the record on the target, then defer to the parent class.

        Args:
            record: The record being processed.
            context: The batch context passed through to the parent.
        """
        self.target.num_records_processed += 1
        super().process_record(record, context)

    def process_batch(self, context: dict) -> None:
        """Copy the batch's records to the target's trackers.

        Args:
            context: Batch context; its ``"records"`` entry is appended to the
                target's ``records_written`` list.
        """
        self.target.records_written.extend(context["records"])
        self.target.num_batches_processed += 1

    @property
    def key_properties(self) -> list[str]:
        """Return the parent's key properties converted to upper case.

        Returns:
            The upper-cased key property names.
        """
        return [key.upper() for key in super().key_properties]


class SQLTargetMock(SQLTarget):
    """A mock SQL target that keeps counters and lists for test assertions."""

    name = "sql-target-mock"
    config_jsonschema = th.PropertiesList().to_dict()
    default_sink_class = SQLSinkMock

    def __init__(self, *args, **kwargs):
        """Initialise the parent target, then reset all tracking state.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)

        # Counters incremented by the mock sink.
        self.num_records_processed: int = 0
        self.num_batches_processed: int = 0

        # Payloads captured for later inspection.
        self.records_written: list[dict] = []
        self.state_messages_written: list[dict] = []

    def _write_state_message(self, state: dict):
        """Emit the state through the parent class and remember it.

        Args:
            state: The state payload being written.
        """
        super()._write_state_message(state)
        self.state_messages_written.append(state)
72 changes: 70 additions & 2 deletions tests/core/test_target_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

import pytest

from singer_sdk.exceptions import MissingKeyPropertiesError
from tests.conftest import BatchSinkMock, TargetMock
from singer_sdk.exceptions import (
MissingKeyPropertiesError,
RecordsWithoutSchemaException,
)
from tests.conftest import BatchSinkMock, SQLSinkMock, SQLTargetMock, TargetMock


def test_get_sink():
Expand Down Expand Up @@ -53,3 +56,68 @@ def test_validate_record():
# Test invalid record
with pytest.raises(MissingKeyPropertiesError):
sink._singer_validate_message({"name": "test"})


def test_sql_get_sink():
    """An equal schema and equal key properties must return the active sink."""
    string_or_null = ["string", "null"]
    input_schema_1 = {
        "properties": {
            "id": {"type": list(string_or_null)},
            "col_ts": {"format": "date-time", "type": list(string_or_null)},
        },
    }
    # An equal but distinct schema object, so the match is by value.
    input_schema_2 = copy.deepcopy(input_schema_1)
    key_properties = []

    target = SQLTargetMock(config={"sqlalchemy_url": "sqlite:///"})
    registered = SQLSinkMock(
        target=target,
        stream_name="foo",
        schema=input_schema_1,
        key_properties=key_properties,
        connector=target.target_connector,
    )
    target._sinks_active["foo"] = registered

    looked_up = target.get_sink(
        "foo",
        schema=input_schema_2,
        key_properties=key_properties,
    )

    assert looked_up is registered


def test_add_sqlsink_and_get_sink():
    """A sink added via ``add_sqlsink`` is found by name; unknown names raise."""
    string_or_null = ["string", "null"]
    input_schema_1 = {
        "properties": {
            "id": {"type": list(string_or_null)},
            "col_ts": {"format": "date-time", "type": list(string_or_null)},
        },
    }
    input_schema_2 = copy.deepcopy(input_schema_1)
    key_properties = []

    target = SQLTargetMock(config={"sqlalchemy_url": "sqlite:///"})
    added = target.add_sqlsink(
        "foo",
        schema=input_schema_2,
        key_properties=key_properties,
    )

    # No schema is passed, so the lookup must hit the registered sink.
    assert target.get_sink("foo") is added

    # Test invalid call
    with pytest.raises(RecordsWithoutSchemaException):
        target.get_sink("bar")

0 comments on commit 759c77b

Please sign in to comment.