add skip_hot_partitions config
MindaugasN committed Jan 8, 2024
1 parent 3a272e8 commit 6835e8d
Showing 4 changed files with 69 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -52,6 +52,7 @@ pipx install git+https://github.com/datarts-tech/tap-cassandra.git@main
| max_attempts | False | 5 | The total number of attempts to be made before giving up. |
| protocol_version | False | 65 | The maximum version of the native protocol to use. |
| fetch_size | False | 10000 | The fetch size when syncing data from Cassandra. |
| skip_hot_partitions | False | False | When set to `True`, partitions are skipped when ReadTimeout or ReadFailure errors are encountered. |
| stream_maps | False | None | Config object for stream maps capability. For more information check out [Stream Maps](https://sdk.meltano.com/en/latest/stream_maps.html). |
| stream_map_config | False | None | User-defined config values to be used within map expressions. |
| flattening_enabled | False | None | 'True' to enable schema flattening and automatically expand nested properties. |
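For illustration, below is a minimal sketch (not part of this commit) of how the new option could be passed when instantiating the tap programmatically. Only `keyspace`, `fetch_size`, and `skip_hot_partitions` are keys confirmed by this diff; the remaining connection settings are assumed to come from your usual tap-cassandra configuration.

```python
# Hypothetical usage sketch; connection-related keys are omitted here and
# would need to be filled in from your existing tap-cassandra configuration.
from tap_cassandra.tap import TapCassandra

config = {
    "keyspace": "my_keyspace",       # placeholder keyspace name
    "fetch_size": 10000,             # README default
    "skip_hot_partitions": True,     # new option added by this commit
    # ... plus hosts/credentials as required by the tap
}

# The Singer SDK tap constructor accepts a config mapping; config validation
# stays enabled, so any required connection settings must be present.
tap = TapCassandra(config=config)
tap.sync_all()
```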
57 changes: 54 additions & 3 deletions tap_cassandra/client.py
@@ -1,10 +1,12 @@
"""Custom client handling, including CassandraStream base class."""

import time
import logging

from singer_sdk import typing as th
from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema

from cassandra import ReadFailure, ReadTimeout
from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile
from cassandra.auth import PlainTextAuthProvider
from cassandra.policies import (
@@ -187,8 +189,12 @@ def _disconnect(self):
        self.cluster.shutdown()

    def execute(self, query):
        """Method to execute the query and return the output."""

        """Execute the query and yield the resulting rows.

        Args:
            query: Cassandra CQL query to execute.
        """

        try:
            res = self.session.execute(self.query_statement(query, self.config.get('fetch_size')))
            while res.has_more_pages or res.current_rows:
@@ -202,6 +208,50 @@ def execute(self, query):
        finally:
            self._disconnect()

    def execute_with_skip(self, query, key_col):
        """Execute the query and yield the resulting rows.

        Handles ReadTimeout and ReadFailure errors by skipping hot partitions.

        Args:
            query: Cassandra CQL query to execute.
            key_col: first partition key of the table.
        """

        # Retry for ReadTimeout and ReadFailure
        sleep_time_seconds = 30
        retry = 0
        max_retries = 3
        while retry < max_retries:
            try:
                batch = None
                res = self.session.execute(self.query_statement(query, self.config.get('fetch_size')))
                while res.has_more_pages or res.current_rows:
                    batch = res.current_rows
                    self.logger.info(f'{len(batch)} row(s) fetched.')
                    for row in batch:
                        yield row
                    res.fetch_next_page()
                self._disconnect()
                break
            except (ReadTimeout, ReadFailure):
                retry += 1
                if not batch:
                    # Nothing fetched yet: read a single row to identify the failing partition.
                    res = self.session.execute(self.query_statement(query, 1))
                    batch = res.current_rows
                    self.logger.info(f'{len(batch)} row(s) fetched.')
                last_key = batch[-1][key_col]
                self.logger.info(f'Skipping {key_col} = {last_key}')
                # Remove any filters previously added to the query and resume past the hot partition.
                base_query = query.lower().split('where')[0].rstrip()
                query = base_query + f" where token({key_col}) > token({last_key})"
                self.logger.info(f'Sleeping for {sleep_time_seconds} seconds before retry {retry} out of {max_retries}.')
                time.sleep(sleep_time_seconds)
            except Exception:
                self._disconnect()
                raise

    def discover_catalog_entry(
        self,
        table_name: str
@@ -214,7 +264,7 @@ def discover_catalog_entry(
        Returns:
            `CatalogEntry` object for the given table
        """

        self.logger.info('discover_catalog_entry called.')
        table_schema = th.PropertiesList()
        partition_keys = list()
        clustering_keys = list()
@@ -306,6 +356,7 @@ def discover_catalog_entries(self) -> list[dict]:
                where keyspace_name = '{self.config.get('keyspace')}'
            '''
        for table in self.session.execute(self.query_statement(table_query, self.config.get('fetch_size'))):
            catalog_entry = self.discover_catalog_entry(table['table_name'])
            result.append(catalog_entry.to_dict())

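The heart of `execute_with_skip` is the query rewrite performed after a read error: the last fetched value of the first partition key is wrapped in `token()` so that the next attempt resumes past the hot partition in token order. Here is a standalone sketch of just that rewrite (no driver calls; the function name, table, and column values are illustrative only):

```python
# Standalone sketch of the token-based skip used in execute_with_skip above.
# `skip_past_partition`, `query`, `key_col`, and the example values are illustrative.
def skip_past_partition(query: str, key_col: str, last_key) -> str:
    """Rebuild `query` so the next read starts after the partition owning `last_key`."""
    # Drop any WHERE clause added by a previous retry, then filter by token order.
    base_query = query.lower().split('where')[0].rstrip()
    return base_query + f" where token({key_col}) > token({last_key})"

# After a ReadTimeout while reading the partition user_id = 42, the next
# attempt starts from the following token range:
print(skip_past_partition("select * from events", "user_id", 42))
# -> select * from events where token(user_id) > token(42)
```

Note that this mirrors the commit's logic, including its assumption that `last_key` can be interpolated directly into CQL; that works for numeric partition keys, while text keys would additionally need quoting.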
9 changes: 7 additions & 2 deletions tap_cassandra/streams.py
@@ -34,5 +34,10 @@ def get_records(self, context):
        selected_column_string = ','.join(selected_column_names) if selected_column_names else '*'

        cql = f"select {selected_column_string} from {self.name.split('-')[1]}"
        for record in self.connector.execute(cql):
            yield record

        if self.config.get('skip_hot_partitions'):
            for row in self.connector.execute_with_skip(cql, self.catalog_entry['key_properties'][0]):
                yield row
        else:
            for row in self.connector.execute(cql):
                yield row
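The `key_col` argument handed to `execute_with_skip` is `key_properties[0]`, which, per the docstring added in client.py, is assumed to be the table's first partition key. A toy illustration of that lookup follows; the stream id and column names are hypothetical:

```python
# Toy illustration only; this mirrors the key_properties[0] lookup above.
catalog_entry = {
    "tap_stream_id": "my_keyspace-events",        # hypothetical stream id
    "key_properties": ["user_id", "event_time"],  # partition key assumed to be listed first
}

key_col = catalog_entry["key_properties"][0]
assert key_col == "user_id"  # token() filtering targets this column
```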
7 changes: 7 additions & 0 deletions tap_cassandra/tap.py
@@ -92,6 +92,13 @@ class TapCassandra(SQLTap):
            default=10000,
            description="The fetch size when syncing data from Cassandra.",
        ),
        th.Property(
            "skip_hot_partitions",
            th.BoolType,
            required=False,
            default=False,
            description="When set to `True`, partitions are skipped when ReadTimeout or ReadFailure errors are encountered.",
        ),
    ).to_dict()

    @property
