Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tap): Utilize Joblib to run parallel streams during sync_all #2295

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
7 changes: 7 additions & 0 deletions singer_sdk/helpers/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@
description="Maximum number of rows in each batch.",
),
).to_dict()
TAP_MAX_PARALLELISM_CONFIG = PropertiesList(
Property(
"max_parallelism",
IntegerType,
description="Max number of streams that can sync in parallel.",
),
).to_dict()


class TargetLoadMethods(str, Enum):
Expand Down
112 changes: 95 additions & 17 deletions singer_sdk/tap_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@

import abc
import contextlib
import logging
import sys
import typing as t
from enum import Enum
from logging.handlers import QueueHandler, QueueListener
from multiprocessing import Manager, Queue

import click
from joblib import Parallel, delayed, parallel_config

from singer_sdk._singerlib import Catalog, StateMessage
from singer_sdk.configuration._dict_config import merge_missing_config_jsonschema
Expand All @@ -22,6 +27,7 @@
from singer_sdk.helpers._util import dump_json, read_json_file
from singer_sdk.helpers.capabilities import (
BATCH_CONFIG,
TAP_MAX_PARALLELISM_CONFIG,
CapabilitiesEnum,
PluginCapabilities,
TapCapabilities,
Expand Down Expand Up @@ -94,6 +100,7 @@ def __init__(
self._input_catalog: Catalog | None = None
self._state: dict[str, Stream] = {}
self._catalog: Catalog | None = None # Tap's working catalog
self._max_parallelism: int | None = self.config.get("max_parallelism")

# Process input catalog
if isinstance(catalog, Catalog):
Expand Down Expand Up @@ -178,6 +185,20 @@ def setup_mapper(self) -> None:
super().setup_mapper()
self.mapper.register_raw_streams_from_catalog(self.catalog)

@property
def max_parallelism(self) -> int:
"""Get max parallel sinks.

The default is None if not overridden.

Returns:
Max number of streams that can be synced in parallel.
"""
if self._max_parallelism in {0, 1}:
self._max_parallelism = None

return self._max_parallelism

@classproperty
def capabilities(self) -> list[CapabilitiesEnum]: # noqa: PLR6301
"""Get tap capabilities.
Expand Down Expand Up @@ -216,6 +237,9 @@ def append_builtin_config(cls: type[PluginBase], config_jsonschema: dict) -> Non
capabilities = cls.capabilities
if PluginCapabilities.BATCH in capabilities:
merge_missing_config_jsonschema(BATCH_CONFIG, config_jsonschema)
merge_missing_config_jsonschema(
TAP_MAX_PARALLELISM_CONFIG, config_jsonschema
)

# Connection and sync tests:

Expand Down Expand Up @@ -440,31 +464,85 @@ def _set_compatible_replication_methods(self) -> None:

# Sync methods

@t.final
def sync_one(
self,
stream: Stream,
log_level: logging.Logger | None = None,
log_queue: Queue | None = None,
) -> None:
"""Sync a single stream.

Args:
stream: The stream that your would like to sync.
log_level: The logging level used by Tap.logger.
log_queue: Multiprocess Queue used by the listener.

This is a link to a logging example for joblib.
https://github.com/joblib/joblib/issues/1017
"""
if self.max_parallelism is not None and not self.logger.hasHandlers():
queue_handler = QueueHandler(log_queue)
self.logger.addHandler(queue_handler)
self.logger.setLevel(log_level)
self.metrics_logger.addHandler(queue_handler)
self.metrics_logger.setLevel(log_level)

if not stream.selected and not stream.has_selected_descendents:
self.logger.info("Skipping deselected stream '%s'.", stream.name)
return

if stream.parent_stream_type:
self.logger.debug(
"Child stream '%s' is expected to be called "
"by parent stream '%s'. "
"Skipping direct invocation.",
type(stream).__name__,
stream.parent_stream_type.__name__,
)
return

stream.sync()
stream.finalize_state_progress_markers()

@t.final
def sync_all(self) -> None:
"""Sync all streams."""
self._reset_state_progress_markers()
self._set_compatible_replication_methods()
self.write_message(StateMessage(value=self.state))

stream: Stream
for stream in self.streams.values():
if not stream.selected and not stream.has_selected_descendents:
self.logger.info("Skipping deselected stream '%s'.", stream.name)
continue

if stream.parent_stream_type:
self.logger.debug(
"Child stream '%s' is expected to be called "
"by parent stream '%s'. "
"Skipping direct invocation.",
type(stream).__name__,
stream.parent_stream_type.__name__,
if self.max_parallelism is None:
stream: Stream
for stream in self.streams.values():
self.sync_one(stream=stream)
else:
with Manager() as manager:
# Prepare logger for parallel processes
console_handler = logging.StreamHandler(sys.stderr)
console_formatter = logging.Formatter(
fmt="{asctime:23s} | {levelname:8s} | {name:20s} | {message}",
style="{",
)
continue

stream.sync()
stream.finalize_state_progress_markers()
console_handler.setFormatter(console_formatter)
self.logger.addHandler(console_handler)
log_queue = manager.Queue()
listener = QueueListener(log_queue, *self.logger.handlers)
listener.start()
with parallel_config(
backend="loky",
prefer="processes",
n_jobs=self.max_parallelism,
), Parallel() as parallel:
parallel(
delayed(self.sync_one)(
stream,
log_queue=log_queue,
log_level=self.logger.getEffectiveLevel(),
)
for stream in self.streams.values()
)
listener.stop()

# this second loop is needed for all streams to print out their costs
# including child streams which are otherwise skipped in the loop above
Expand Down
Loading