diff --git a/pyproject.toml b/pyproject.toml
index 30ab036..930ad52 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tap-mongodb"
-version = "2.4.0"
+version = "2.5.0"
 description = "`tap-mongodb` is a Singer tap for MongoDB and AWS DocumentDB, built with the Meltano Singer SDK."
 readme = "README.md"
 authors = ["Matt Menzenski"]
diff --git a/tap_mongodb/streams.py b/tap_mongodb/streams.py
index 21d0bf0..66584cd 100644
--- a/tap_mongodb/streams.py
+++ b/tap_mongodb/streams.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import math
 from datetime import datetime
 from typing import Any, Generator, Iterable, Optional
 
@@ -18,12 +19,7 @@ from singer_sdk.helpers._state import increment_state
 from singer_sdk.helpers._typing import conform_record_data_types
 from singer_sdk.helpers._util import utc_now
-from singer_sdk.streams.core import (
-    REPLICATION_INCREMENTAL,
-    REPLICATION_LOG_BASED,
-    Stream,
-    TypeConformanceLevel,
-)
+from singer_sdk.streams.core import REPLICATION_INCREMENTAL, REPLICATION_LOG_BASED, Stream, TypeConformanceLevel
 
 from tap_mongodb.connector import MongoDBConnector
 from tap_mongodb.types import IncrementalId
@@ -31,9 +27,29 @@ DEFAULT_START_DATE: str = "1970-01-01"
 
 
+def recursive_replace_empty_in_dict(dct):
+    """
+    Recursively replace empty values with None in a dictionary.
+    NaN, inf, and -inf are unable to be parsed by the json library, so these values will be replaced with None.
+    """
+    for key, value in dct.items():
+        if value in [-math.inf, math.inf, math.nan]:
+            dct[key] = None
+        elif isinstance(value, list):
+            for i, item in enumerate(value):
+                if isinstance(item, dict):
+                    recursive_replace_empty_in_dict(item)
+                elif item in [-math.inf, math.inf, math.nan]:
+                    value[i] = None
+        elif isinstance(value, dict):
+            recursive_replace_empty_in_dict(value)
+    return dct
+
+
 def to_object_id(replication_key_value: str) -> ObjectId:
     """Converts an ISO-8601 date string into a BSON ObjectId."""
     incremental_id: IncrementalId = IncrementalId.from_string(replication_key_value)
+    return incremental_id.object_id
@@ -86,6 +102,7 @@ def primary_keys(self) -> Optional[list[str]]:
     def primary_keys(self, new_value: list[str]) -> None:
         """Set primary keys for the stream."""
         self._primary_keys = new_value
+        return self
 
     @property
     def is_sorted(self) -> bool:
@@ -95,6 +112,7 @@ def is_sorted(self) -> bool:
         string, and these are alphanumerically sortable.
 
         When the tap is running in log-based mode, it is not sorted - the replication key value here is
        a hex string."""
+        return self.replication_method == REPLICATION_INCREMENTAL
 
     def _increment_stream_state(self, latest_record: dict[str, Any], *, context: dict | None = None) -> None:
@@ -109,7 +127,9 @@ def _increment_stream_state(self, latest_record: dict[str, Any], *, context: dic
         Raises:
             ValueError: if configured replication method is unsupported, or if replication key is absent
+
         """
+
         # This also creates a state entry if one does not yet exist:
         state_dict = self.get_context_state(context)
 
@@ -142,6 +162,8 @@ def _increment_stream_state(self, latest_record: dict[str, Any], *, context: dic
                 check_sorted=self.check_sorted,
             )
 
+        return self
+
     def _generate_record_messages(self, record: dict) -> Generator[singer.RecordMessage, None, None]:
         """Write out a RECORD message.
 
@@ -165,26 +187,21 @@ def _generate_record_messages(self, record: dict) -> Generator[singer.RecordMess
             level=self.TYPE_CONFORMANCE_LEVEL,
             logger=self.logger,
         )
+
         for stream_map in self.stream_maps:
             mapped_record = stream_map.transform(record)
             # Emit record if not filtered
             if mapped_record is not None:
                 record_message = singer.RecordMessage(
-                    stream=stream_map.stream_alias,
-                    record=mapped_record,
-                    version=None,
-                    time_extracted=extracted_at,
+                    stream=stream_map.stream_alias, record=mapped_record, version=None, time_extracted=extracted_at
                 )
-
                 yield record_message
 
     def get_records(self, context: dict | None) -> Iterable[dict]:
         """Return a generator of record-type dictionary objects."""
         # pylint: disable=too-many-locals,too-many-branches,too-many-statements
         bookmark: str = self.get_starting_replication_key_value(context)
-
         should_add_metadata: bool = self.config.get("add_record_metadata", False)
-
         collection: Collection = self._connector.database[self._collection_name]
 
         if self.replication_method == REPLICATION_INCREMENTAL:
@@ -199,6 +216,9 @@ def get_records(self, context: dict | None) -> Iterable[dict]:
             for record in collection.find({"_id": {"$gt": start_date}}).sort([("_id", ASCENDING)]):
                 object_id: ObjectId = record["_id"]
                 incremental_id: IncrementalId = IncrementalId.from_object_id(object_id)
+
+                recursive_replace_empty_in_dict(record)
+
                 parsed_record = {
                     "replication_key": str(incremental_id),
                     "object_id": str(object_id),
@@ -261,6 +281,7 @@ def get_records(self, context: dict | None) -> Iterable[dict]:
                 else:
                     self.logger.critical(f"operation_failure on collection.watch: {operation_failure}")
                     raise operation_failure
+
             except Exception as exception:
                 self.logger.critical(exception)
                 raise exception
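Not part of the patch above: a minimal usage sketch of the new `recursive_replace_empty_in_dict` helper, illustrating why `get_records` now runs it on each document before building the emitted record. The sample document and field names are hypothetical; the only assumption is that the helper is importable from `tap_mongodb.streams` as added in this diff.

```python
import json
import math

from tap_mongodb.streams import recursive_replace_empty_in_dict

# Hypothetical document as decoded from MongoDB: +inf/-inf (and NaN) are legal BSON
# doubles, but json.dumps(..., allow_nan=False) refuses to serialize them.
record = {
    "temperature": math.inf,
    "readings": [1.5, -math.inf, {"delta": math.inf}],
    "metadata": {"score": -math.inf, "label": "ok"},
}

# The helper mutates the dict in place (and also returns it), replacing values equal
# to +/-inf (and math.nan) with None at the top level, inside lists, and in nested dicts.
recursive_replace_empty_in_dict(record)

print(json.dumps(record, allow_nan=False))
# {"temperature": null, "readings": [1.5, null, {"delta": null}], "metadata": {"score": null, "label": "ok"}}
```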