diff --git a/singer_sdk/helpers/capabilities.py b/singer_sdk/helpers/capabilities.py index 3445c5bc6..f8c6bad9b 100644 --- a/singer_sdk/helpers/capabilities.py +++ b/singer_sdk/helpers/capabilities.py @@ -159,6 +159,13 @@ description="Maximum number of rows in each batch.", ), ).to_dict() +TAP_MAX_PARALLELISM_CONFIG = PropertiesList( + Property( + "max_parallelism", + IntegerType, + description="Max number of streams that can sync in parallel.", + ), +).to_dict() class TargetLoadMethods(str, Enum): diff --git a/singer_sdk/tap_base.py b/singer_sdk/tap_base.py index d69fa5f38..ae80d3cd5 100644 --- a/singer_sdk/tap_base.py +++ b/singer_sdk/tap_base.py @@ -4,10 +4,15 @@ import abc import contextlib +import logging +import sys import typing as t from enum import Enum +from logging.handlers import QueueHandler, QueueListener +from multiprocessing import Manager, Queue import click +from joblib import Parallel, delayed, parallel_config from singer_sdk._singerlib import Catalog, StateMessage from singer_sdk.configuration._dict_config import merge_missing_config_jsonschema @@ -22,6 +27,7 @@ from singer_sdk.helpers._util import dump_json, read_json_file from singer_sdk.helpers.capabilities import ( BATCH_CONFIG, + TAP_MAX_PARALLELISM_CONFIG, CapabilitiesEnum, PluginCapabilities, TapCapabilities, @@ -94,6 +100,7 @@ def __init__( self._input_catalog: Catalog | None = None self._state: dict[str, Stream] = {} self._catalog: Catalog | None = None # Tap's working catalog + self._max_parallelism: int | None = self.config.get("max_parallelism") # Process input catalog if isinstance(catalog, Catalog): @@ -178,6 +185,20 @@ def setup_mapper(self) -> None: super().setup_mapper() self.mapper.register_raw_streams_from_catalog(self.catalog) + @property + def max_parallelism(self) -> int | None: + """Get max parallel streams. + + The default is None if not overridden. + + Returns: + Max number of streams that can be synced in parallel.
+ """ + if self._max_parallelism in {0, 1}: + self._max_parallelism = None + + return self._max_parallelism + @classproperty def capabilities(self) -> list[CapabilitiesEnum]: # noqa: PLR6301 """Get tap capabilities. @@ -216,6 +237,9 @@ def append_builtin_config(cls: type[PluginBase], config_jsonschema: dict) -> Non capabilities = cls.capabilities if PluginCapabilities.BATCH in capabilities: merge_missing_config_jsonschema(BATCH_CONFIG, config_jsonschema) + merge_missing_config_jsonschema( + TAP_MAX_PARALLELISM_CONFIG, config_jsonschema + ) # Connection and sync tests: @@ -440,6 +464,47 @@ def _set_compatible_replication_methods(self) -> None: # Sync methods + @t.final + def sync_one( + self, + stream: Stream, + log_level: int | None = None, + log_queue: Queue | None = None, + ) -> None: + """Sync a single stream. + + Args: + stream: The stream that you would like to sync. + log_level: The logging level used by Tap.logger. + log_queue: Multiprocess Queue used by the listener. + + This is a link to a logging example for joblib. + https://github.com/joblib/joblib/issues/1017 + """ + if self.max_parallelism is not None and not self.logger.hasHandlers(): + queue_handler = QueueHandler(log_queue) + self.logger.addHandler(queue_handler) + self.logger.setLevel(log_level) + self.metrics_logger.addHandler(queue_handler) + self.metrics_logger.setLevel(log_level) + + if not stream.selected and not stream.has_selected_descendents: + self.logger.info("Skipping deselected stream '%s'.", stream.name) + return + + if stream.parent_stream_type: + self.logger.debug( + "Child stream '%s' is expected to be called " + "by parent stream '%s'. 
" + "Skipping direct invocation.", + type(stream).__name__, + stream.parent_stream_type.__name__, + ) + return + + stream.sync() + stream.finalize_state_progress_markers() + @t.final def sync_all(self) -> None: """Sync all streams.""" @@ -447,24 +512,37 @@ def sync_all(self) -> None: self._set_compatible_replication_methods() self.write_message(StateMessage(value=self.state)) - stream: Stream - for stream in self.streams.values(): - if not stream.selected and not stream.has_selected_descendents: - self.logger.info("Skipping deselected stream '%s'.", stream.name) - continue - - if stream.parent_stream_type: - self.logger.debug( - "Child stream '%s' is expected to be called " - "by parent stream '%s'. " - "Skipping direct invocation.", - type(stream).__name__, - stream.parent_stream_type.__name__, + if self.max_parallelism is None: + stream: Stream + for stream in self.streams.values(): + self.sync_one(stream=stream) + else: + with Manager() as manager: + # Prepare logger for parallel processes + console_handler = logging.StreamHandler(sys.stderr) + console_formatter = logging.Formatter( + fmt="{asctime:23s} | {levelname:8s} | {name:20s} | {message}", + style="{", ) - continue - - stream.sync() - stream.finalize_state_progress_markers() + console_handler.setFormatter(console_formatter) + self.logger.addHandler(console_handler) + log_queue = manager.Queue() + listener = QueueListener(log_queue, *self.logger.handlers) + listener.start() + with parallel_config( + backend="loky", + prefer="processes", + n_jobs=self.max_parallelism, + ), Parallel() as parallel: + parallel( + delayed(self.sync_one)( + stream, + log_queue=log_queue, + log_level=self.logger.getEffectiveLevel(), + ) + for stream in self.streams.values() + ) + listener.stop() # this second loop is needed for all streams to print out their costs # including child streams which are otherwise skipped in the loop above