Skip to content

Commit

Permalink
Merge pull request #24 from melgazar9/main
Browse files Browse the repository at this point in the history
  • Loading branch information
menzenski authored Oct 20, 2023
2 parents d716052 + ed108d6 commit b84a1e0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tap-mongodb"
version = "2.4.0"
version = "2.5.0"
description = "`tap-mongodb` is a Singer tap for MongoDB and AWS DocumentDB, built with the Meltano Singer SDK."
readme = "README.md"
authors = ["Matt Menzenski"]
Expand Down
47 changes: 34 additions & 13 deletions tap_mongodb/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import math
from datetime import datetime
from typing import Any, Generator, Iterable, Optional

Expand All @@ -18,22 +19,37 @@
from singer_sdk.helpers._state import increment_state
from singer_sdk.helpers._typing import conform_record_data_types
from singer_sdk.helpers._util import utc_now
from singer_sdk.streams.core import (
REPLICATION_INCREMENTAL,
REPLICATION_LOG_BASED,
Stream,
TypeConformanceLevel,
)
from singer_sdk.streams.core import REPLICATION_INCREMENTAL, REPLICATION_LOG_BASED, Stream, TypeConformanceLevel

from tap_mongodb.connector import MongoDBConnector
from tap_mongodb.types import IncrementalId

DEFAULT_START_DATE: str = "1970-01-01"


def recursive_replace_empty_in_dict(dct):
    """Replace non-finite floats (NaN, inf, -inf) with None, in place.

    NaN, inf, and -inf cannot be represented in strict JSON, so every such value found in the
    dictionary, in nested dictionaries, and in (arbitrarily nested) lists is replaced with None.

    Args:
        dct: Dictionary to sanitize. It is mutated in place.

    Returns:
        The same dictionary object that was passed in.
    """
    for key, value in dct.items():
        cleaned = _replace_non_finite(value)
        # Only values of existing keys are reassigned, so mutating while iterating is safe.
        if cleaned is not value:
            dct[key] = cleaned
    return dct


def _replace_non_finite(value):
    """Return None for a non-finite float; otherwise sanitize containers in place and return value."""
    if isinstance(value, float):
        # math.isfinite is used rather than `value in [-math.inf, math.inf, math.nan]`: NaN never
        # compares equal to itself, so a membership test only matches the math.nan object itself
        # and misses every other NaN instance.
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return recursive_replace_empty_in_dict(value)
    if isinstance(value, list):
        for index, item in enumerate(value):
            value[index] = _replace_non_finite(item)
    return value


def to_object_id(replication_key_value: str) -> ObjectId:
    """Convert a replication key string (an ISO-8601 date string) into a BSON ObjectId.

    The string is parsed with ``IncrementalId.from_string`` and the ``object_id`` of the
    resulting IncrementalId is returned.
    """
    return IncrementalId.from_string(replication_key_value).object_id


Expand Down Expand Up @@ -86,6 +102,7 @@ def primary_keys(self) -> Optional[list[str]]:
def primary_keys(self, new_value: list[str]) -> None:
    """Set primary keys for the stream.

    Args:
        new_value: Names of the fields that make up the stream's primary key.
    """
    # Nothing is returned, matching the declared `-> None`; when this is invoked through
    # attribute assignment a returned value would be discarded anyway.
    self._primary_keys = new_value

@property
def is_sorted(self) -> bool:
Expand All @@ -95,6 +112,7 @@ def is_sorted(self) -> bool:
string, and these are alphanumerically sortable.
When the tap is running in log-based mode, it is not sorted - the replication key value here is a hex string."""

return self.replication_method == REPLICATION_INCREMENTAL

def _increment_stream_state(self, latest_record: dict[str, Any], *, context: dict | None = None) -> None:
Expand All @@ -109,7 +127,9 @@ def _increment_stream_state(self, latest_record: dict[str, Any], *, context: dic
Raises:
ValueError: if configured replication method is unsupported, or if replication key is absent
"""

# This also creates a state entry if one does not yet exist:
state_dict = self.get_context_state(context)

Expand Down Expand Up @@ -142,6 +162,8 @@ def _increment_stream_state(self, latest_record: dict[str, Any], *, context: dic
check_sorted=self.check_sorted,
)

return self

def _generate_record_messages(self, record: dict) -> Generator[singer.RecordMessage, None, None]:
"""Write out a RECORD message.
Expand All @@ -165,26 +187,21 @@ def _generate_record_messages(self, record: dict) -> Generator[singer.RecordMess
level=self.TYPE_CONFORMANCE_LEVEL,
logger=self.logger,
)

for stream_map in self.stream_maps:
mapped_record = stream_map.transform(record)
# Emit record if not filtered
if mapped_record is not None:
record_message = singer.RecordMessage(
stream=stream_map.stream_alias,
record=mapped_record,
version=None,
time_extracted=extracted_at,
stream=stream_map.stream_alias, record=mapped_record, version=None, time_extracted=extracted_at
)

yield record_message

def get_records(self, context: dict | None) -> Iterable[dict]:
"""Return a generator of record-type dictionary objects."""
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
bookmark: str = self.get_starting_replication_key_value(context)

should_add_metadata: bool = self.config.get("add_record_metadata", False)

collection: Collection = self._connector.database[self._collection_name]

if self.replication_method == REPLICATION_INCREMENTAL:
Expand All @@ -199,6 +216,9 @@ def get_records(self, context: dict | None) -> Iterable[dict]:
for record in collection.find({"_id": {"$gt": start_date}}).sort([("_id", ASCENDING)]):
object_id: ObjectId = record["_id"]
incremental_id: IncrementalId = IncrementalId.from_object_id(object_id)

recursive_replace_empty_in_dict(record)

parsed_record = {
"replication_key": str(incremental_id),
"object_id": str(object_id),
Expand Down Expand Up @@ -261,6 +281,7 @@ def get_records(self, context: dict | None) -> Iterable[dict]:
else:
self.logger.critical(f"operation_failure on collection.watch: {operation_failure}")
raise operation_failure

except Exception as exception:
self.logger.critical(exception)
raise exception
Expand Down

0 comments on commit b84a1e0

Please sign in to comment.