diff --git a/build/lib/tap_snowflake/__init__.py b/build/lib/tap_snowflake/__init__.py new file mode 100644 index 0000000..6a0b005 --- /dev/null +++ b/build/lib/tap_snowflake/__init__.py @@ -0,0 +1,499 @@ +#!/usr/bin/env python3 +# pylint: disable=missing-docstring,not-an-iterable,too-many-locals,too-many-arguments,too-many-branches,invalid-name,duplicate-code,too-many-statements + +import collections +import copy +import itertools +import re +import sys +import logging + +import singer +import singer.metrics as metrics +import singer.schema +import snowflake.connector +from singer import metadata +from singer import utils +from singer.catalog import Catalog, CatalogEntry +from singer.schema import Schema + +import tap_snowflake.sync_strategies.common as common +import tap_snowflake.sync_strategies.full_table as full_table +import tap_snowflake.sync_strategies.incremental as incremental +from tap_snowflake.connection import SnowflakeConnection + +LOGGER = singer.get_logger('tap_snowflake') + +# Max number of rows that a SHOW SCHEMAS|TABLES|COLUMNS can return. +# If more than this number of rows returned then tap-snowflake will raise TooManyRecordsException +SHOW_COMMAND_MAX_ROWS = 9999 + + +# Tone down snowflake connector logs noise +logging.getLogger('snowflake.connector').setLevel(logging.WARNING) + +Column = collections.namedtuple('Column', [ + 'table_catalog', + 'table_schema', + 'table_name', + 'column_name', + 'data_type', + 'character_maximum_length', + 'numeric_precision', + 'numeric_scale']) + +REQUIRED_CONFIG_KEYS = [ + 'account', + 'dbname', + 'user', + 'warehouse', + 'tables' +] + +# Snowflake data types +STRING_TYPES = set(['varchar', 'char', 'character', 'string', 'text']) +NUMBER_TYPES = set(['number', 'decimal', 'numeric']) +INTEGER_TYPES = set(['int', 'integer', 'bigint', 'smallint']) +FLOAT_TYPES = set(['float', 'float4', 'float8', 'real', 'double', 'double precision']) +DATETIME_TYPES = set(['datetime', 'timestamp', 'date', 'timestamp_ltz', 'timestamp_ntz', 'timestamp_tz']) +BINARY_TYPE = set(['binary', 'varbinary']) + + +def schema_for_column(c): + '''Returns the Schema object for the given Column.''' + data_type = c.data_type.lower() + + inclusion = 'available' + result = Schema(inclusion=inclusion) + + if data_type == 'boolean': + result.type = ['null', 'boolean'] + + elif data_type in INTEGER_TYPES: + result.type = ['null', 'number'] + + elif data_type in FLOAT_TYPES: + result.type = ['null', 'number'] + + elif data_type in NUMBER_TYPES: + result.type = ['null', 'number'] + + elif data_type in STRING_TYPES: + result.type = ['null', 'string'] + result.maxLength = c.character_maximum_length + + elif data_type in DATETIME_TYPES: + result.type = ['null', 'string'] + result.format = 'date-time' + + elif data_type == 'time': + result.type = ['null', 'string'] + result.format = 'time' + + elif data_type in BINARY_TYPE: + result.type = ['null', 'string'] + result.format = 'binary' + + else: + result = Schema(None, + inclusion='unsupported', + description='Unsupported data type {}'.format(data_type)) + return result + + +def create_column_metadata(cols): + mdata = {} + mdata = metadata.write(mdata, (), 'selected-by-default', False) + for c in cols: + schema = schema_for_column(c) + mdata = metadata.write(mdata, + ('properties', c.column_name), + 'selected-by-default', + schema.inclusion != 'unsupported') + mdata = metadata.write(mdata, + ('properties', c.column_name), + 'sql-datatype', + c.data_type.lower()) + + return metadata.to_list(mdata) + + +def 
get_table_columns(snowflake_conn, tables): + """Get column definitions of a list of tables + + It's using SHOW commands instead of INFORMATION_SCHEMA views because information_schemas views are slow + and can cause unexpected exception of: + Information schema query returned too much data. Please repeat query with more selective predicates. + """ + table_columns = [] + for table in tables: + queries = [] + + LOGGER.info('Getting column information for %s...', table) + + # Get column data types by SHOW commands + show_columns = f'SHOW COLUMNS IN TABLE {table}' + + # Convert output of SHOW commands to tables and use SQL joins to get every required information + select = """ + WITH + show_columns AS (SELECT * FROM TABLE(RESULT_SCAN(%(LAST_QID)s))) + SELECT show_columns."database_name" AS table_catalog + ,show_columns."schema_name" AS table_schema + ,show_columns."table_name" AS table_name + ,show_columns."column_name" AS column_name + -- ---------------------------------------------------------------------------------------- + -- Character and numeric columns display their generic data type rather than their defined + -- data type (i.e. TEXT for all character types, FIXED for all fixed-point numeric types, + -- and REAL for all floating-point numeric types). + -- + -- Further info at https://docs.snowflake.net/manuals/sql-reference/sql/show-columns.html + -- ---------------------------------------------------------------------------------------- + ,CASE PARSE_JSON(show_columns."data_type"):type::varchar + WHEN 'FIXED' THEN 'NUMBER' + WHEN 'REAL' THEN 'FLOAT' + ELSE PARSE_JSON("data_type"):type::varchar + END data_type + ,PARSE_JSON(show_columns."data_type"):length::number AS character_maximum_length + ,PARSE_JSON(show_columns."data_type"):precision::number AS numeric_precision + ,PARSE_JSON(show_columns."data_type"):scale::number AS numeric_scale + FROM show_columns + """ + queries.extend([show_columns, select]) + + # Run everything in one transaction + columns = snowflake_conn.query(queries, max_records=SHOW_COMMAND_MAX_ROWS) + table_columns.extend(columns) + + return table_columns + + +def discover_catalog(snowflake_conn, config): + """Returns a Catalog describing the structure of the database.""" + tables = config.get('tables').split(',') + sql_columns = get_table_columns(snowflake_conn, tables) + + table_info = {} + columns = [] + for sql_col in sql_columns: + catalog = sql_col['TABLE_CATALOG'] + schema = sql_col['TABLE_SCHEMA'] + table_name = sql_col['TABLE_NAME'] + + if catalog not in table_info: + table_info[catalog] = {} + + if schema not in table_info[catalog]: + table_info[catalog][schema] = {} + + table_info[catalog][schema][table_name] = { + 'row_count': sql_col.get('ROW_COUNT'), + 'is_view': sql_col.get('TABLE_TYPE') == 'VIEW' + } + + columns.append(Column( + table_catalog=catalog, + table_schema=schema, + table_name=table_name, + column_name=sql_col['COLUMN_NAME'], + data_type=sql_col['DATA_TYPE'], + character_maximum_length=sql_col['CHARACTER_MAXIMUM_LENGTH'], + numeric_precision=sql_col['NUMERIC_PRECISION'], + numeric_scale=sql_col['NUMERIC_SCALE'] + )) + + entries = [] + for (k, cols) in itertools.groupby(columns, lambda c: (c.table_catalog, c.table_schema, c.table_name)): + cols = list(cols) + (table_catalog, table_schema, table_name) = k + schema = Schema(type='object', + properties={c.column_name: schema_for_column(c) for c in cols}) + md = create_column_metadata(cols) + md_map = metadata.to_map(md) + + md_map = metadata.write(md_map, (), 'database-name', table_catalog) + 
md_map = metadata.write(md_map, (), 'schema-name', table_schema) + + if ( + table_catalog in table_info and + table_schema in table_info[table_catalog] and + table_name in table_info[table_catalog][table_schema] + ): + # Row Count of views returns NULL - Transform it to not null integer by defaults to 0 + row_count = table_info[table_catalog][table_schema][table_name].get('row_count', 0) or 0 + is_view = table_info[table_catalog][table_schema][table_name]['is_view'] + + md_map = metadata.write(md_map, (), 'row-count', row_count) + md_map = metadata.write(md_map, (), 'is-view', is_view) + + entry = CatalogEntry( + table=table_name, + stream=table_name, + metadata=metadata.to_list(md_map), + tap_stream_id=common.generate_tap_stream_id(table_catalog, table_schema, table_name), + schema=schema) + + entries.append(entry) + + return Catalog(entries) + + +def do_discover(snowflake_conn, config): + discover_catalog(snowflake_conn, config).dump() + + +# pylint: disable=fixme +# TODO: Maybe put in a singer-db-utils library. +def desired_columns(selected, table_schema): + """Return the set of column names we need to include in the SELECT. + + selected - set of column names marked as selected in the input catalog + table_schema - the most recently discovered Schema for the table + """ + all_columns = set() + available = set() + automatic = set() + unsupported = set() + + for column, column_schema in table_schema.properties.items(): + all_columns.add(column) + inclusion = column_schema.inclusion + if inclusion == 'automatic': + automatic.add(column) + elif inclusion == 'available': + available.add(column) + elif inclusion == 'unsupported': + unsupported.add(column) + else: + raise Exception('Unknown inclusion ' + inclusion) + + selected_but_unsupported = selected.intersection(unsupported) + if selected_but_unsupported: + LOGGER.warning( + 'Columns %s were selected but are not supported. Skipping them.', + selected_but_unsupported) + + selected_but_nonexistent = selected.difference(all_columns) + if selected_but_nonexistent: + LOGGER.warning( + 'Columns %s were selected but do not exist.', + selected_but_nonexistent) + + not_selected_but_automatic = automatic.difference(selected) + if not_selected_but_automatic: + LOGGER.warning( + 'Columns %s are primary keys but were not selected. Adding them.', + not_selected_but_automatic) + + return selected.intersection(available).union(automatic) + + +def resolve_catalog(discovered_catalog, streams_to_sync): + result = Catalog(streams=[]) + + # Iterate over the streams in the input catalog and match each one up + # with the same stream in the discovered catalog. 
+ for catalog_entry in streams_to_sync: + catalog_metadata = metadata.to_map(catalog_entry.metadata) + replication_key = catalog_metadata.get((), {}).get('replication-key') + + discovered_table = discovered_catalog.get_stream(catalog_entry.tap_stream_id) + database_name = common.get_database_name(catalog_entry) + + if not discovered_table: + LOGGER.warning('Database %s table %s was selected but does not exist', + database_name, catalog_entry.table) + continue + + selected = {k for k, v in catalog_entry.schema.properties.items() + if common.property_is_selected(catalog_entry, k) or k == replication_key} + + # These are the columns we need to select + columns = desired_columns(selected, discovered_table.schema) + + result.streams.append(CatalogEntry( + tap_stream_id=catalog_entry.tap_stream_id, + metadata=catalog_entry.metadata, + stream=catalog_entry.tap_stream_id, + table=catalog_entry.table, + schema=Schema( + type='object', + properties={col: discovered_table.schema.properties[col] + for col in columns} + ) + )) + + return result + + +def get_streams(snowflake_conn, catalog, config, state): + """Returns the Catalog of data we're going to sync for all SELECT-based + streams (i.e. INCREMENTAL and FULL_TABLE that require a historical + sync). + + Using the Catalog provided from the input file, this function will return a + Catalog representing exactly which tables and columns that will be emitted + by SELECT-based syncs. This is achieved by comparing the input Catalog to a + freshly discovered Catalog to determine the resulting Catalog. + + The resulting Catalog will include the following any streams marked as + "selected" that currently exist in the database. Columns marked as "selected" + and those labled "automatic" (e.g. primary keys and replication keys) will be + included. Streams will be prioritized in the following order: + 1. currently_syncing if it is SELECT-based + 2. any streams that do not have state + 3. any streams that do not have a replication method of LOG_BASED + """ + discovered = discover_catalog(snowflake_conn, config) + + # Filter catalog to include only selected streams + # pylint: disable=unnecessary-lambda + selected_streams = list(filter(lambda s: common.stream_is_selected(s), catalog.streams)) + streams_with_state = [] + streams_without_state = [] + + for stream in selected_streams: + stream_state = state.get('bookmarks', {}).get(stream.tap_stream_id) + + if not stream_state: + streams_without_state.append(stream) + else: + streams_with_state.append(stream) + + # If the state says we were in the middle of processing a stream, skip + # to that stream. Then process streams without prior state and finally + # move onto streams with state (i.e. 
have been synced in the past) + currently_syncing = singer.get_currently_syncing(state) + + # prioritize streams that have not been processed + ordered_streams = streams_without_state + streams_with_state + + if currently_syncing: + currently_syncing_stream = list(filter( + lambda s: s.tap_stream_id == currently_syncing, streams_with_state)) + + non_currently_syncing_streams = list(filter(lambda s: s.tap_stream_id != currently_syncing, ordered_streams)) + + streams_to_sync = currently_syncing_stream + non_currently_syncing_streams + else: + # prioritize streams that have not been processed + streams_to_sync = ordered_streams + + return resolve_catalog(discovered, streams_to_sync) + + +def write_schema_message(catalog_entry, bookmark_properties=None): + key_properties = common.get_key_properties(catalog_entry) + + singer.write_message(singer.SchemaMessage( + stream=catalog_entry.stream, + schema=catalog_entry.schema.to_dict(), + key_properties=key_properties, + bookmark_properties=bookmark_properties + )) + + +def do_sync_incremental(snowflake_conn, catalog_entry, state, columns): + LOGGER.info('Stream %s is using incremental replication', catalog_entry.stream) + + md_map = metadata.to_map(catalog_entry.metadata) + replication_key = md_map.get((), {}).get('replication-key') + + if not replication_key: + raise Exception(f'Cannot use INCREMENTAL replication for table ({catalog_entry.stream}) without a replication ' + f'key.') + + write_schema_message(catalog_entry=catalog_entry, + bookmark_properties=[replication_key]) + + incremental.sync_table(snowflake_conn, catalog_entry, state, columns) + + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + +def do_sync_full_table(snowflake_conn, catalog_entry, state, columns): + LOGGER.info('Stream %s is using full table replication', catalog_entry.stream) + + write_schema_message(catalog_entry) + + stream_version = common.get_stream_version(catalog_entry.tap_stream_id, state) + + full_table.sync_table(snowflake_conn, catalog_entry, state, columns, stream_version) + + # Prefer initial_full_table_complete going forward + singer.clear_bookmark(state, catalog_entry.tap_stream_id, 'version') + + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'initial_full_table_complete', + True) + + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + +def sync_streams(snowflake_conn, catalog, state): + for catalog_entry in catalog.streams: + columns = list(catalog_entry.schema.properties.keys()) + + if not columns: + LOGGER.warning('There are no columns selected for stream %s, skipping it.', catalog_entry.stream) + continue + + state = singer.set_currently_syncing(state, catalog_entry.tap_stream_id) + + # Emit a state message to indicate that we've started this stream + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + md_map = metadata.to_map(catalog_entry.metadata) + + replication_method = md_map.get((), {}).get('replication-method') + + database_name = common.get_database_name(catalog_entry) + schema_name = common.get_schema_name(catalog_entry) + + with metrics.job_timer('sync_table') as timer: + timer.tags['database'] = database_name + timer.tags['table'] = catalog_entry.table + + LOGGER.info('Beginning to sync %s.%s.%s', database_name, schema_name, catalog_entry.table) + + if replication_method == 'INCREMENTAL': + do_sync_incremental(snowflake_conn, catalog_entry, state, columns) + elif replication_method == 'FULL_TABLE': + do_sync_full_table(snowflake_conn, catalog_entry, 
state, columns) + else: + raise Exception('Only INCREMENTAL and FULL TABLE replication methods are supported') + + state = singer.set_currently_syncing(state, None) + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + +def do_sync(snowflake_conn, config, catalog, state): + catalog = get_streams(snowflake_conn, catalog, config, state) + sync_streams(snowflake_conn, catalog, state) + + +def main_impl(): + args = utils.parse_args(REQUIRED_CONFIG_KEYS) + + snowflake_conn = SnowflakeConnection(args.config) + + if args.discover: + do_discover(snowflake_conn, args.config) + elif args.catalog: + state = args.state or {} + do_sync(snowflake_conn, args.config, args.catalog, state) + elif args.properties: + catalog = Catalog.from_dict(args.properties) + state = args.state or {} + do_sync(snowflake_conn, args.config, catalog, state) + else: + LOGGER.info('No properties were selected') + + +def main(): + try: + main_impl() + except Exception as exc: + LOGGER.critical(exc) + raise exc diff --git a/build/lib/tap_snowflake/connection.py b/build/lib/tap_snowflake/connection.py new file mode 100644 index 0000000..191ac2e --- /dev/null +++ b/build/lib/tap_snowflake/connection.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +from typing import Union, List, Dict + +import backoff +import singer +import sys +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +import snowflake.connector + +LOGGER = singer.get_logger('tap_snowflake') + + +class TooManyRecordsException(Exception): + """Exception to raise when query returns more records than max_records""" + + +def retry_pattern(): + """Retry pattern decorator used when connecting to snowflake + """ + return backoff.on_exception(backoff.expo, + snowflake.connector.errors.OperationalError, + max_tries=5, + on_backoff=log_backoff_attempt, + factor=2) + + +def log_backoff_attempt(details): + """Log backoff attempts used by retry_pattern + """ + LOGGER.info('Error detected communicating with Snowflake, triggering backoff: %d try', details.get('tries')) + + +def validate_config(config): + """Validate configuration dictionary""" + errors = [] + required_config_keys = [ + 'account', + 'dbname', + 'user', + 'warehouse', + 'tables' + ] + + # Check if mandatory keys exist + for k in required_config_keys: + if not config.get(k, None): + errors.append(f'Required key is missing from config: [{k}]') + + possible_authentication_keys = [ + 'password', + 'private_key_path' + ] + if not any(config.get(k, None) for k in possible_authentication_keys): + errors.append( + f'Required authentication key missing. 
Existing methods: {",".join(possible_authentication_keys)}') + + return errors + + +class SnowflakeConnection: + """Class to manage connection to snowflake data warehouse""" + + def __init__(self, connection_config): + """ + connection_config: Snowflake connection details + """ + self.connection_config = connection_config + config_errors = validate_config(connection_config) + if len(config_errors) == 0: + self.connection_config = connection_config + else: + LOGGER.error('Invalid configuration:\n * %s', '\n * '.join(config_errors)) + sys.exit(1) + + def get_private_key(self): + """ + Get private key from the right location + """ + if self.connection_config.get('private_key_path'): + try: + encoded_passphrase = self.connection_config['private_key_passphrase'].encode() + except KeyError: + encoded_passphrase = None + + with open(self.connection_config['private_key_path'], 'rb') as key: + p_key= serialization.load_pem_private_key( + key.read(), + password=encoded_passphrase, + backend=default_backend() + ) + + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption()) + return pkb + + return None + + def open_connection(self): + """Connect to snowflake database""" + return snowflake.connector.connect( + user=self.connection_config['user'], + password=self.connection_config.get('password', None), + private_key=self.get_private_key(), + account=self.connection_config['account'], + database=self.connection_config['dbname'], + warehouse=self.connection_config['warehouse'], + insecure_mode=self.connection_config.get('insecure_mode', False) + # Use insecure mode to avoid "Failed to get OCSP response" warnings + # insecure_mode=True + ) + + @retry_pattern() + def connect_with_backoff(self): + """Connect to snowflake database and retry automatically a few times if fails""" + return self.open_connection() + + def query(self, query: Union[List[str], str], params: Dict = None, max_records=0): + """Run a query in snowflake""" + result = [] + + if params is None: + params = {} + else: + if 'LAST_QID' in params: + LOGGER.warning('LAST_QID is a reserved prepared statement parameter name, ' + 'it will be overridden with each executed query!') + + with self.connect_with_backoff() as connection: + with connection.cursor(snowflake.connector.DictCursor) as cur: + + # Run every query in one transaction if query is a list of SQL + if isinstance(query, list): + cur.execute('START TRANSACTION') + queries = query + else: + queries = [query] + + qid = None + + for sql in queries: + LOGGER.debug('Running query: %s', sql) + + # update the LAST_QID + params['LAST_QID'] = qid + + cur.execute(sql, params) + qid = cur.sfqid + + # Raise exception if returned rows greater than max allowed records + if 0 < max_records < cur.rowcount: + raise TooManyRecordsException( + f'Query returned too many records. 
This query can return max {max_records} records') + + if cur.rowcount > 0: + result = cur.fetchall() + + return result diff --git a/build/lib/tap_snowflake/sync_strategies/__init__.py b/build/lib/tap_snowflake/sync_strategies/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/tap_snowflake/sync_strategies/common.py b/build/lib/tap_snowflake/sync_strategies/common.py new file mode 100644 index 0000000..c0ba297 --- /dev/null +++ b/build/lib/tap_snowflake/sync_strategies/common.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +# pylint: disable=too-many-arguments,duplicate-code,too-many-locals + +import copy +import datetime +import singer +import time + +import singer.metrics as metrics +from singer import metadata +from singer import utils + +LOGGER = singer.get_logger('tap_snowflake') + +def escape(string): + """Escape strings to be SQL safe""" + if '"' in string: + raise Exception("Can't escape identifier {} because it contains a backtick" + .format(string)) + return '"{}"'.format(string) + + +def generate_tap_stream_id(catalog_name, schema_name, table_name): + """Generate tap stream id as appears in properties.json""" + return catalog_name + '-' + schema_name + '-' + table_name + + +def get_stream_version(tap_stream_id, state): + """Get stream version from bookmark""" + stream_version = singer.get_bookmark(state, tap_stream_id, 'version') + + if stream_version is None: + stream_version = int(time.time() * 1000) + + return stream_version + + +def stream_is_selected(stream): + """Detect if stream is selected to sync""" + md_map = metadata.to_map(stream.metadata) + selected_md = metadata.get(md_map, (), 'selected') + + return selected_md + + +def property_is_selected(stream, property_name): + """Detect if field is selected to sync""" + md_map = metadata.to_map(stream.metadata) + return singer.should_sync_field( + metadata.get(md_map, ('properties', property_name), 'inclusion'), + metadata.get(md_map, ('properties', property_name), 'selected'), + True) + + +def get_is_view(catalog_entry): + """Detect if stream is a view""" + md_map = metadata.to_map(catalog_entry.metadata) + + return md_map.get((), {}).get('is-view') + + +def get_database_name(catalog_entry): + """Get database name from catalog""" + md_map = metadata.to_map(catalog_entry.metadata) + + return md_map.get((), {}).get('database-name') + + +def get_schema_name(catalog_entry): + """Get schema name from catalog""" + md_map = metadata.to_map(catalog_entry.metadata) + + return md_map.get((), {}).get('schema-name') + + +def get_key_properties(catalog_entry): + """Get key properties from catalog""" + catalog_metadata = metadata.to_map(catalog_entry.metadata) + stream_metadata = catalog_metadata.get((), {}) + + is_view = get_is_view(catalog_entry) + + if is_view: + key_properties = stream_metadata.get('view-key-properties', []) + else: + key_properties = stream_metadata.get('table-key-properties', []) + + return key_properties + + +def generate_select_sql(catalog_entry, columns): + """Generate SQL to extract data froom snowflake""" + database_name = get_database_name(catalog_entry) + schema_name = get_schema_name(catalog_entry) + escaped_db = escape(database_name) + escaped_schema = escape(schema_name) + escaped_table = escape(catalog_entry.table) + escaped_columns = [] + + for col_name in columns: + escaped_col = escape(col_name) + + # fetch the column type format from the json schema alreay built + property_format = catalog_entry.schema.properties[col_name].format + + # if the column format is binary, fetch the 
hexified value + if property_format == 'binary': + escaped_columns.append(f'hex_encode({escaped_col}) as {escaped_col}') + else: + escaped_columns.append(escaped_col) + + select_sql = f'SELECT {",".join(escaped_columns)} FROM {escaped_db}.{escaped_schema}.{escaped_table}' + + # escape percent signs + select_sql = select_sql.replace('%', '%%') + return select_sql + + +# pylint: disable=too-many-branches +def row_to_singer_record(catalog_entry, version, row, columns, time_extracted): + """Transform SQL row to singer compatible record message""" + row_to_persist = () + for idx, elem in enumerate(row): + property_type = catalog_entry.schema.properties[columns[idx]].type + if isinstance(elem, datetime.datetime): + row_to_persist += (elem.isoformat() + '+00:00',) + + elif isinstance(elem, datetime.date): + row_to_persist += (elem.isoformat() + 'T00:00:00+00:00',) + + elif isinstance(elem, datetime.timedelta): + epoch = datetime.datetime.utcfromtimestamp(0) + timedelta_from_epoch = epoch + elem + row_to_persist += (timedelta_from_epoch.isoformat() + '+00:00',) + + elif isinstance(elem, datetime.time): + row_to_persist += (str(elem),) + + elif isinstance(elem, bytes): + # for BIT value, treat 0 as False and anything else as True + if 'boolean' in property_type: + boolean_representation = elem != b'\x00' + row_to_persist += (boolean_representation,) + else: + row_to_persist += (elem.hex(),) + + elif 'boolean' in property_type or property_type == 'boolean': + if elem is None: + boolean_representation = None + elif elem == 0: + boolean_representation = False + else: + boolean_representation = True + row_to_persist += (boolean_representation,) + + else: + row_to_persist += (elem,) + rec = dict(zip(columns, row_to_persist)) + + return singer.RecordMessage( + stream=catalog_entry.stream, + record=rec, + version=version, + time_extracted=time_extracted) + + +def whitelist_bookmark_keys(bookmark_key_set, tap_stream_id, state): + """...""" + for bookmark_key in [non_whitelisted_bookmark_key + for non_whitelisted_bookmark_key + in state.get('bookmarks', {}).get(tap_stream_id, {}).keys() + if non_whitelisted_bookmark_key not in bookmark_key_set]: + singer.clear_bookmark(state, tap_stream_id, bookmark_key) + + +def sync_query(cursor, catalog_entry, state, select_sql, columns, stream_version, params): + """...""" + replication_key = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key') + + time_extracted = utils.now() + + LOGGER.info('Running %s', select_sql) + cursor.execute(select_sql, params) + + row = cursor.fetchone() + rows_saved = 0 + + database_name = get_database_name(catalog_entry) + + with metrics.record_counter(None) as counter: + counter.tags['database'] = database_name + counter.tags['table'] = catalog_entry.table + + while row: + counter.increment() + rows_saved += 1 + record_message = row_to_singer_record(catalog_entry, + stream_version, + row, + columns, + time_extracted) + singer.write_message(record_message) + + md_map = metadata.to_map(catalog_entry.metadata) + stream_metadata = md_map.get((), {}) + replication_method = stream_metadata.get('replication-method') + + if replication_method == 'FULL_TABLE': + key_properties = get_key_properties(catalog_entry) + + max_pk_values = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'max_pk_values') + + if max_pk_values: + last_pk_fetched = {k:v for k, v in record_message.record.items() + if k in key_properties} + + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'last_pk_fetched', + 
last_pk_fetched) + + elif replication_method == 'INCREMENTAL': + if replication_key is not None: + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key', + replication_key) + + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key_value', + record_message.record[replication_key]) + if rows_saved % 1000 == 0: + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + row = cursor.fetchone() + + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) diff --git a/build/lib/tap_snowflake/sync_strategies/full_table.py b/build/lib/tap_snowflake/sync_strategies/full_table.py new file mode 100644 index 0000000..73703c9 --- /dev/null +++ b/build/lib/tap_snowflake/sync_strategies/full_table.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +# pylint: disable=duplicate-code,too-many-locals,simplifiable-if-expression + +import singer +import tap_snowflake.sync_strategies.common as common + +LOGGER = singer.get_logger('tap_snowflake') + +BOOKMARK_KEYS = {'last_pk_fetched', 'max_pk_values', 'version', 'initial_full_table_complete'} + + +def get_max_pk_values(cursor, catalog_entry): + """Get actual max primary key values from database""" + database_name = common.get_database_name(catalog_entry) + escaped_db = common.escape(database_name) + escaped_table = common.escape(catalog_entry.table) + + key_properties = common.get_key_properties(catalog_entry) + escaped_columns = [common.escape(c) for c in key_properties] + + sql = """SELECT {} + FROM {}.{} + ORDER BY {} + LIMIT 1 + """ + + select_column_clause = ', '.join(escaped_columns) + order_column_clause = ', '.join([pk + ' DESC' for pk in escaped_columns]) + + cursor.execute(sql.format(select_column_clause, + escaped_db, + escaped_table, + order_column_clause)) + result = cursor.fetchone() + + if result: + max_pk_values = dict(zip(key_properties, result)) + else: + max_pk_values = {} + + return max_pk_values + + +def generate_pk_clause(catalog_entry, state): + """Generate primary key where clause to SQL select""" + key_properties = common.get_key_properties(catalog_entry) + escaped_columns = [common.escape(c) for c in key_properties] + + max_pk_values = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'max_pk_values') + + last_pk_fetched = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'last_pk_fetched') + + if last_pk_fetched: + pk_comparisons = ['({} > {} AND {} <= {})'.format(common.escape(pk), + last_pk_fetched[pk], + common.escape(pk), + max_pk_values[pk]) + for pk in key_properties] + else: + pk_comparisons = [f'{common.escape(pk)} <= {max_pk_values[pk]}' for pk in key_properties] + + sql = ' WHERE {} ORDER BY {} ASC'.format(' AND '.join(pk_comparisons), + ', '.join(escaped_columns)) + + return sql + + +def sync_table(snowflake_conn, catalog_entry, state, columns, stream_version): + """Sync table with FULL_TABLE""" + common.whitelist_bookmark_keys(BOOKMARK_KEYS, catalog_entry.tap_stream_id, state) + + bookmark = state.get('bookmarks', {}).get(catalog_entry.tap_stream_id, {}) + version_exists = True if 'version' in bookmark else False + + initial_full_table_complete = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'initial_full_table_complete') + + state_version = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'version') + + activate_version_message = singer.ActivateVersionMessage( + stream=catalog_entry.stream, + version=stream_version + ) + + # For the initial replication, emit an ACTIVATE_VERSION message + 
# at the beginning so the records show up right away. + if not initial_full_table_complete and not (version_exists and state_version is None): + singer.write_message(activate_version_message) + + with snowflake_conn.connect_with_backoff() as open_conn: + with open_conn.cursor() as cur: + select_sql = common.generate_select_sql(catalog_entry, columns) + params = {} + + common.sync_query(cur, + catalog_entry, + state, + select_sql, + columns, + stream_version, + params) + + # clear max pk value and last pk fetched upon successful sync + singer.clear_bookmark(state, catalog_entry.tap_stream_id, 'max_pk_values') + singer.clear_bookmark(state, catalog_entry.tap_stream_id, 'last_pk_fetched') + + singer.write_message(activate_version_message) diff --git a/build/lib/tap_snowflake/sync_strategies/incremental.py b/build/lib/tap_snowflake/sync_strategies/incremental.py new file mode 100644 index 0000000..4cd25aa --- /dev/null +++ b/build/lib/tap_snowflake/sync_strategies/incremental.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# pylint: disable=duplicate-code + +import pendulum +import singer +from singer import metadata +import tap_snowflake.sync_strategies.common as common + +LOGGER = singer.get_logger('tap_snowflake') + +BOOKMARK_KEYS = {'replication_key', 'replication_key_value', 'version'} + +def sync_table(snowflake_conn, catalog_entry, state, columns): + """Sync table incrementally""" + common.whitelist_bookmark_keys(BOOKMARK_KEYS, catalog_entry.tap_stream_id, state) + + catalog_metadata = metadata.to_map(catalog_entry.metadata) + stream_metadata = catalog_metadata.get((), {}) + + replication_key_metadata = stream_metadata.get('replication-key') + replication_key_state = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key') + + replication_key_value = None + + if replication_key_metadata == replication_key_state: + replication_key_value = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key_value') + else: + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key', + replication_key_metadata) + state = singer.clear_bookmark(state, catalog_entry.tap_stream_id, 'replication_key_value') + + stream_version = common.get_stream_version(catalog_entry.tap_stream_id, state) + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'version', + stream_version) + + activate_version_message = singer.ActivateVersionMessage( + stream=catalog_entry.stream, + version=stream_version + ) + + singer.write_message(activate_version_message) + + select_sql = common.generate_select_sql(catalog_entry, columns) + params = {} + + with snowflake_conn.connect_with_backoff() as open_conn: + with open_conn.cursor() as cur: + select_sql = common.generate_select_sql(catalog_entry, columns) + params = {} + + if replication_key_value is not None: + if catalog_entry.schema.properties[replication_key_metadata].format == 'date-time': + replication_key_value = pendulum.parse(replication_key_value) + + # pylint: disable=duplicate-string-formatting-argument + select_sql += ' WHERE "{}" >= \'{}\' ORDER BY "{}" ASC'.format( + replication_key_metadata, + replication_key_value, + replication_key_metadata) + + elif replication_key_metadata is not None: + select_sql += ' ORDER BY "{}" ASC'.format(replication_key_metadata) + + common.sync_query(cur, + catalog_entry, + state, + select_sql, + columns, + stream_version, + params) diff --git a/setup.py b/setup.py index 948e499..73b70ef 100644 --- a/setup.py +++ b/setup.py @@ -18,8 +18,8 @@ ], 
py_modules=['tap_snowflake'], install_requires=[ - 'pipelinewise-singer-python==1.*', - 'snowflake-connector-python[pandas]==2.4.*', + 'pipelinewise-singer-python>=1.0', + 'snowflake-connector-python[pandas]>=2.4', 'pendulum==1.2.0' ], extras_require={
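
For orientation, here is a minimal sketch of exercising the discovery path added in this diff end to end. Every credential, account and table value below is a placeholder (none of them come from this repository); it assumes a reachable Snowflake account.

# Minimal discovery smoke test (all values are placeholders).
# 'tables' is the comma-separated list of fully qualified tables the config requires.
from tap_snowflake.connection import SnowflakeConnection
import tap_snowflake

config = {
    'account': 'xy12345.eu-central-1',
    'dbname': 'MY_DB',
    'user': 'TAP_USER',
    'password': '******',   # or 'private_key_path' (+ optional 'private_key_passphrase') for key-pair auth
    'warehouse': 'LOAD_WH',
    'tables': 'MY_DB.MY_SCHEMA.ORDERS,MY_DB.MY_SCHEMA.CUSTOMERS',
}

# validate_config() runs in the constructor and exits on missing keys
snowflake_conn = SnowflakeConnection(config)

# discover_catalog() issues SHOW COLUMNS per table and reads the output back
# through TABLE(RESULT_SCAN(%(LAST_QID)s)) in a single transaction
catalog = tap_snowflake.discover_catalog(snowflake_conn, config)
for stream in catalog.streams:
    print(stream.tap_stream_id)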
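
The SHOW COLUMNS output is turned into JSON Schema by schema_for_column(). A small illustration of that mapping, using an invented TIMESTAMP_NTZ column:

# Illustration of the type mapping in schema_for_column(); the column values
# are made up for the example.
from tap_snowflake import Column, schema_for_column

col = Column(
    table_catalog='MY_DB', table_schema='MY_SCHEMA', table_name='ORDERS',
    column_name='UPDATED_AT', data_type='TIMESTAMP_NTZ',
    character_maximum_length=None, numeric_precision=None, numeric_scale=None)

print(schema_for_column(col).to_dict())
# Expected shape: {'inclusion': 'available', 'format': 'date-time', 'type': ['null', 'string']}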
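
Finally, a sketch of the SELECT statement the sync strategies build from a catalog entry. The stream, column and bookmark values are hypothetical; generate_select_sql() double-quotes identifiers, and incremental.sync_table() appends the replication-key predicate.

# Sketch of the generated extraction SQL (hypothetical stream/column names).
from singer.catalog import CatalogEntry
from singer.schema import Schema
from singer import metadata
import tap_snowflake.sync_strategies.common as common

entry = CatalogEntry(
    table='ORDERS',
    stream='ORDERS',
    tap_stream_id='MY_DB-MY_SCHEMA-ORDERS',
    schema=Schema(type='object', properties={
        'ID': Schema(type=['null', 'number']),
        'UPDATED_AT': Schema(type=['null', 'string'], format='date-time'),
    }),
    metadata=metadata.to_list(
        metadata.write(metadata.write({}, (), 'database-name', 'MY_DB'),
                       (), 'schema-name', 'MY_SCHEMA')),
)

select_sql = common.generate_select_sql(entry, ['ID', 'UPDATED_AT'])
# -> SELECT "ID","UPDATED_AT" FROM "MY_DB"."MY_SCHEMA"."ORDERS"
# For a date-time replication key, incremental.sync_table() then appends
# something like:
#   WHERE "UPDATED_AT" >= '2023-01-01T00:00:00+00:00' ORDER BY "UPDATED_AT" ASC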