Skip to content

Commit

Permalink
various lint and typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
menzenski committed Apr 23, 2024
1 parent d42a173 commit 0207c76
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions tap_mongodb/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations

import math
from datetime import datetime
from typing import Any, Dict, Generator, Iterable, Optional, Union
from datetime import datetime, timezone
from typing import Any, Dict, Generator, Iterable, List, Optional, Union

from bson.objectid import ObjectId
from loguru import logger
def recursive_replace_empty_in_dict(dct: Dict) -> Dict:
    """Recursively replace non-finite numeric values (NaN, +/-inf) with None.

    Mutates *dct* in place and returns it for convenience. Nested dicts are
    recursed into. List elements are checked one level deep: a dict found
    inside a list is recursed, but a list inside a list is not (matching the
    original behavior).

    Args:
        dct: dictionary to sanitize in place.

    Returns:
        The same dictionary object, with non-finite leaves replaced by None.
    """
    for key, value in dct.items():
        if _is_non_finite(value):
            # BUG FIX: the previous check `value in [-math.inf, math.inf,
            # math.nan]` silently missed NaN, because NaN compares unequal
            # to itself and the `in` identity fast-path only matches the
            # math.nan singleton object — not NaN values parsed from data.
            dct[key] = None
        elif isinstance(value, list):  # builtins, not typing.List (PEP 585)
            for i, item in enumerate(value):
                if isinstance(item, dict):
                    recursive_replace_empty_in_dict(item)
                elif _is_non_finite(item):
                    value[i] = None
        elif isinstance(value, dict):
            recursive_replace_empty_in_dict(value)
    return dct


def _is_non_finite(value: Any) -> bool:
    """Return True if *value* is a numeric NaN or +/- infinity."""
    try:
        return math.isinf(value) or math.isnan(value)
    except TypeError:
        # Non-numeric values (str, dict, None, ...) are never "non-finite".
        return False

Expand All @@ -64,7 +64,7 @@ class MongoDBCollectionStream(Stream):
def __init__(
self,
tap: TapBaseClass,
catalog_entry: dict,
catalog_entry: Dict,
connector: MongoDBConnector,
) -> None:
"""Initialize the database stream.
Expand All @@ -86,15 +86,15 @@ def __init__(
)

@property
def primary_keys(self) -> Optional[List[str]]:
    """Primary key column(s) for the stream.

    Log-based replication keys each record by the Change Event ID
    ("replication_key"); incremental replication keys records by the
    document's ObjectId ("object_id").
    """
    log_based: bool = self.replication_method == REPLICATION_LOG_BASED
    return ["replication_key"] if log_based else ["object_id"]

@primary_keys.setter
def primary_keys(self, new_value: List[str]) -> None:
    """Store *new_value* as the stream's primary keys."""
    self._primary_keys = new_value

Expand All @@ -109,7 +109,7 @@ def is_sorted(self) -> bool:

return self.replication_method == REPLICATION_INCREMENTAL

def _increment_stream_state(self, latest_record: dict[str, Any], *, context: dict | None = None) -> None:
def _increment_stream_state(self, latest_record: Dict[str, Any], *, context: Optional[Dict] = None) -> None:
"""Update state of stream or partition with data from the provided record.
Raises `InvalidStreamSortException` if `self.is_sorted = True` and unsorted data
Expand Down Expand Up @@ -157,7 +157,7 @@ def _increment_stream_state(self, latest_record: dict[str, Any], *, context: dic
check_sorted=self.check_sorted,
)

def _generate_record_messages(self, record: dict) -> Generator[singer.RecordMessage, None, None]:
def _generate_record_messages(self, record: Dict) -> Generator[singer.RecordMessage, None, None]:
"""Write out a RECORD message.
We are overriding the default implementation of this (private) method because the default behavior is to set
Expand Down Expand Up @@ -192,7 +192,7 @@ def _generate_record_messages(self, record: dict) -> Generator[singer.RecordMess

def _get_records_incremental(
self, bookmark: str, should_add_metadata: bool, collection: Collection
) -> Iterable[dict]:
) -> Iterable[Dict]:
"""Return a generator of record-type dictionary objects when running in incremental replication mode."""
if bookmark:
logger.info(f"using existing bookmark: {bookmark}")
Expand Down Expand Up @@ -222,12 +222,12 @@ def _get_records_incremental(
"to": None,
}
if should_add_metadata:
parsed_record["_sdc_batched_at"] = datetime.utcnow()
parsed_record["_sdc_batched_at"] = datetime.now(timezone.utc)
yield parsed_record

def _get_records_log_based(
self, bookmark: str, should_add_metadata: bool, collection: Collection
) -> Iterable[dict]:
) -> Iterable[Dict]:
"""Return a generator of record-type dictionary objects when running in log-based replication mode."""
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
change_stream_options: Dict[str, Union[str, Dict[str, str]]] = {"full_document": "updateLookup"}
Expand Down Expand Up @@ -336,16 +336,24 @@ def _get_records_log_based(
if record is not None:
operation_type = record["operationType"]
if operation_type not in operation_types_allowlist:
logger.info(f"Skipping record of operationType {operation_type} which is not in allowlist")
logger.warning(f"Skipping record of operationType {operation_type} which is not in allowlist")
continue
cluster_time: datetime = record["clusterTime"].as_datetime()
# fullDocument key is not present on delete events - if it is missing, fall back to documentKey
# instead. If that is missing, pass None/null to avoid raising an error.
document = record.get("fullDocument", record.get("documentKey", None))
# document: Dict = record.get("fullDocument", record.get("documentKey", None))
document: Optional[Dict]
if "fullDocument" in record:
document = record["fullDocument"]
elif "documentKey" in record:
document = record["documentKey"]
else:
document = None

object_id: Optional[ObjectId] = document["_id"] if document and "_id" in document else None
update_description: Optional[Dict] = None
if "updateDescription" in record:
update_description = record.get("updateDescription")
update_description = record["updateDescription"]
to_obj: Optional[Dict] = None
if "to" in record:
to_obj = {
Expand All @@ -367,13 +375,13 @@ def _get_records_log_based(
}
if should_add_metadata:
parsed_record["_sdc_extracted_at"] = cluster_time
parsed_record["_sdc_batched_at"] = datetime.utcnow()
parsed_record["_sdc_batched_at"] = datetime.now(timezone.utc)
if operation_type == "delete":
parsed_record["_sdc_deleted_at"] = cluster_time
yield parsed_record
has_seen_a_record = True

def get_records(self, context: dict | None) -> Iterable[dict]:
def get_records(self, context: Dict | None) -> Iterable[Dict]:
"""Return a generator of record-type dictionary objects."""
bookmark: str = self.get_starting_replication_key_value(context)
should_add_metadata: bool = self.config.get("add_record_metadata", False)
Expand Down

0 comments on commit 0207c76

Please sign in to comment.