From 60768f8505085d3f8d40f43b48eae45a6dde3294 Mon Sep 17 00:00:00 2001 From: Peter Kosztolanyi Date: Sun, 2 Jun 2019 16:07:48 +0100 Subject: [PATCH] initial commit --- .gitignore | 32 ++ LICENSE | 9 + README.md | 93 ++++ requirements.txt | 4 + setup.py | 23 + tap_snowflake/__init__.py | 473 +++++++++++++++++++ tap_snowflake/connection.py | 85 ++++ tap_snowflake/sync_strategies/__init__.py | 0 tap_snowflake/sync_strategies/common.py | 217 +++++++++ tap_snowflake/sync_strategies/full_table.py | 118 +++++ tap_snowflake/sync_strategies/incremental.py | 78 +++ tests/test_tap_snowflake.py | 162 +++++++ tests/utils.py | 68 +++ 13 files changed, 1362 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 requirements.txt create mode 100644 setup.py create mode 100644 tap_snowflake/__init__.py create mode 100644 tap_snowflake/connection.py create mode 100644 tap_snowflake/sync_strategies/__init__.py create mode 100644 tap_snowflake/sync_strategies/common.py create mode 100644 tap_snowflake/sync_strategies/full_table.py create mode 100644 tap_snowflake/sync_strategies/incremental.py create mode 100644 tests/test_tap_snowflake.py create mode 100644 tests/utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2661658 --- /dev/null +++ b/.gitignore @@ -0,0 +1,32 @@ +# IDE +.vscode +.idea/* + + +# Python +__pycache__/ +*.py[cod] +*$py.class +.virtualenvs +*.egg-info/ +*__pycache__/ +*~ +dist/ + +# Singer JSON files +properties.json +config.json +state.json + +*.db +.DS_Store +venv +env +blog_old.md +node_modules +*.pyc +tmp + +# Docs +docs/_build/ +docs/_templates/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6675fc3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2019 TransferWise Ltd. (https://transferwise.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..119dcea --- /dev/null +++ b/README.md @@ -0,0 +1,93 @@ +# pipelinewise-tap-snowflake + +[![PyPI version](https://badge.fury.io/py/pipelinewise-tap-snowflake.svg)](https://badge.fury.io/py/pipelinewise-tap-snowflake) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pipelinewise-tap-snowflake.svg)](https://pypi.org/project/pipelinewise-tap-snowflake/) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + +[Singer](https://www.singer.io/) tap that extracts data from a [Snowflake](https://www.snowflake.com/) database and produces JSON-formatted data following the [Singer spec](https://github.com/singer-io/getting-started/blob/master/docs/SPEC.md). + +This is a [PipelineWise](https://transferwise.github.io/pipelinewise) compatible tap connector. + +## How to use it + +The recommended method of running this tap is to use it from [PipelineWise](https://transferwise.github.io/pipelinewise). When running it from PipelineWise you don't need to configure this tap with JSON files and most of things are automated. Please check the related documentation at [Tap Snowflake](https://transferwise.github.io/pipelinewise/connectors/taps/snowflake.html) + +If you want to run this [Singer Tap](https://singer.io) independently please read further. + +### Install and Run + +First, make sure Python 3 is installed on your system or follow these +installation instructions for [Mac](http://docs.python-guide.org/en/latest/starting/install3/osx/) or +[Ubuntu](https://www.digitalocean.com/community/tutorials/how-to-install-python-3-and-set-up-a-local-programming-environment-on-ubuntu-16-04). + +It's recommended to use a virtualenv: + +```bash + python3 -m venv venv + pip install pipelinewise-tap-snowflake +``` + +or + +```bash + python3 -m venv venv + . venv/bin/activate + pip install --upgrade pip + pip install . +``` + +### Configuration + +1. Create a `config.json` file with connection details to snowflake. + + ```json + { + "account": "rtxxxxx.eu-central-1", + "dbname": "database_name", + "user": "my_user", + "password": "password", + "warehouse": "my_virtual_warehouse", + "filter_dbs": "database_name", + "filter_schemas": "schema1,schema2" + } + ``` + +`filter_dbs` and `filter_schemas` are optional. + +2. Run it in discovery mode to generate a `properties.json` + +3. Edit the `properties.json` and select the streams to replicate + +4. Run the tap like any other singer compatible tap: + +``` + tap-snowflake --config config.json --properties properties.json --state state.json +``` + +### Discovery mode + +The tap can be invoked in discovery mode to find the available tables and +columns in the database: + +```bash +$ tap-snowflake --config config.json --discover + +``` + +A discovered catalog is output, with a JSON-schema description of each table. A +source table directly corresponds to a Singer stream. + +## Replication methods + +The two ways to replicate a given table are `FULL_TABLE` and `INCREMENTAL`. + +### Full Table + +Full-table replication extracts all data from the source table each time the tap +is invoked. + +### Incremental + +Incremental replication works in conjunction with a state file to only extract +new records each time the tap is invoked. This requires a replication key to be +specified in the table's metadata as well. diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..36ca67f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +singer-python==5.3.1 +snowflake-connector-python==1.7.4 +backoff==1.3.2 +pendulum==1.2.0 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..e472e58 --- /dev/null +++ b/setup.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python + +from setuptools import setup + +setup(name='pipelinewise-tap-snowflake', + version='1.0.0', + description='Singer.io tap for extracting data from Snowflake', + author="TransferWise", + url='https://github.com/transferwise/pipelinewise-tap-postgres', + classifiers=['Programming Language :: Python :: 3 :: Only'], + py_modules=['tap_snowflake'], + install_requires=[ + 'singer-python==5.3.1', + 'snowflake-connector-python==1.7.4', + 'backoff==1.3.2', + 'pendulum==1.2.0' + ], + entry_points=''' + [console_scripts] + tap-snowflake=tap_snowflake:main + ''', + packages=['tap_snowflake', 'tap_snowflake.sync_strategies'], +) diff --git a/tap_snowflake/__init__.py b/tap_snowflake/__init__.py new file mode 100644 index 0000000..fd3c897 --- /dev/null +++ b/tap_snowflake/__init__.py @@ -0,0 +1,473 @@ +#!/usr/bin/env python3 +# pylint: disable=missing-docstring,not-an-iterable,too-many-locals,too-many-arguments,too-many-branches,invalid-name,duplicate-code,too-many-statements + +import datetime +import collections +import itertools +from itertools import dropwhile +import copy + + +import singer +import singer.metrics as metrics +import singer.schema + +from singer import bookmarks +from singer import metadata +from singer import utils +from singer.schema import Schema +from singer.catalog import Catalog, CatalogEntry + +import tap_snowflake.sync_strategies.common as common +import tap_snowflake.sync_strategies.full_table as full_table +import tap_snowflake.sync_strategies.incremental as incremental + +from tap_snowflake.connection import SnowflakeConnection + + +LOGGER = singer.get_logger() + +Column = collections.namedtuple('Column', [ + "table_catalog", + "table_schema", + "table_name", + "column_name", + "data_type", + "character_maximum_length", + "numeric_precision", + "numeric_scale"]) + +REQUIRED_CONFIG_KEYS = [ + 'account', + 'dbname', + 'user', + 'password', + 'warehouse' +] + +# Snowflake data types +STRING_TYPES = set(['varchar', 'char', 'character', 'string', 'text']) +NUMBER_TYPES = set(['number', 'decimal', 'numeric']) +INTEGER_TYPES = set(['int', 'integer', 'bigint', 'smallint']) +FLOAT_TYPES = set(['float', 'float4', 'float8', 'real', 'double', 'double precision']) +DATETIME_TYPES = set(['datetime', 'timestamp', 'date', 'time', 'timestamp_ltz', 'timestamp_ntz', 'timestamp_tz']) +BINARY_TYPE = set(['binary', 'varbinary']) + + +def schema_for_column(c): + '''Returns the Schema object for the given Column.''' + data_type = c.data_type.lower() + + inclusion = 'available' + result = Schema(inclusion=inclusion) + + if data_type == 'boolean': + result.type = ['null', 'boolean'] + + elif data_type in INTEGER_TYPES: + result.type = ['null', 'number'] + + elif data_type in FLOAT_TYPES: + result.type = ['null', 'number'] + + elif data_type in NUMBER_TYPES: + result.type = ['null', 'number'] + + elif data_type in STRING_TYPES: + result.type = ['null', 'string'] + result.maxLength = c.character_maximum_length + + elif data_type in DATETIME_TYPES: + result.type = ['null', 'string'] + result.format = 'date-time' + + elif data_type in BINARY_TYPE: + result.type = ['null', 'string'] + + else: + result = Schema(None, + inclusion='unsupported', + description='Unsupported data type {}'.format(data_type)) + return result + + +def create_column_metadata(cols): + mdata = {} + mdata = metadata.write(mdata, (), 'selected-by-default', False) + for c in cols: + schema = schema_for_column(c) + mdata = metadata.write(mdata, + ('properties', c.column_name), + 'selected-by-default', + schema.inclusion != 'unsupported') + mdata = metadata.write(mdata, + ('properties', c.column_name), + 'sql-datatype', + c.data_type.lower()) + + return metadata.to_list(mdata) + + +def discover_catalog(snowflake_conn, config): + '''Returns a Catalog describing the structure of the database.''' + filter_dbs_config = config.get('filter_dbs') + filter_schemas_config = config.get('filter_schemas') + + if filter_dbs_config: + filter_dbs_clause = ",".join("LOWER('{}')".format(db) + for db in filter_dbs_config.split(",")) + + table_db_clause = "LOWER(t.table_catalog) IN ({})".format(filter_dbs_clause) + else: + table_db_clause = "1 = 1" + + if filter_schemas_config: + filter_schemas_clause = ",".join(["LOWER('{}')".format(schema) + for schema in filter_schemas_config.split(",")]) + + table_schema_clause = "LOWER(t.table_schema) IN ({})".format(filter_schemas_clause) + else: + table_schema_clause = "LOWER(t.table_schema) NOT IN ('information_schema')" + + table_info = {} + sql_columns = snowflake_conn.query(""" + SELECT t.table_catalog, + t.table_schema, + t.table_name, + t.table_type, + t.row_count, + c.column_name, + c.data_type, + c.character_maximum_length, + c.numeric_precision, + c.numeric_scale + FROM information_schema.tables t, + information_schema.columns c + WHERE t.table_catalog = c.table_catalog + AND t.table_schema = c.table_schema + AND t.table_name = c.table_name + AND {} + AND {} + """.format(table_db_clause, table_schema_clause)) + + columns = [] + for sql_col in sql_columns: + catalog = sql_col['TABLE_CATALOG'] + schema = sql_col['TABLE_SCHEMA'] + table_name = sql_col['TABLE_NAME'] + + if catalog not in table_info: + table_info[catalog] = {} + + if schema not in table_info[catalog]: + table_info[catalog][schema] = {} + + table_info[catalog][schema][table_name] = { + 'row_count': sql_col.get('ROW_COUNT'), + 'is_view': sql_col.get('TABLE_TYPE') == 'VIEW' + } + + columns.append(Column( + table_catalog=catalog, + table_schema=schema, + table_name=table_name, + column_name=sql_col['COLUMN_NAME'], + data_type=sql_col['DATA_TYPE'], + character_maximum_length=sql_col['CHARACTER_MAXIMUM_LENGTH'], + numeric_precision=sql_col['NUMERIC_PRECISION'], + numeric_scale=sql_col['NUMERIC_SCALE'] + )) + + entries = [] + for (k, cols) in itertools.groupby(columns, lambda c: (c.table_catalog, c.table_schema, c.table_name)): + cols = list(cols) + (table_catalog, table_schema, table_name) = k + schema = Schema(type='object', + properties={c.column_name: schema_for_column(c) for c in cols}) + md = create_column_metadata(cols) + md_map = metadata.to_map(md) + + md_map = metadata.write(md_map, (), 'database-name', table_catalog) + md_map = metadata.write(md_map, (), 'schema-name', table_schema) + + if ( + table_catalog in table_info and + table_schema in table_info[table_catalog] and + table_name in table_info[table_catalog][table_schema] + ): + row_count = table_info[table_catalog][table_schema][table_name].get('row_count') + is_view = table_info[table_catalog][table_schema][table_name]['is_view'] + + md_map = metadata.write(md_map, (), 'row-count', row_count) + md_map = metadata.write(md_map, (), 'is-view', is_view) + + entry = CatalogEntry( + table=table_name, + stream=table_name, + metadata=metadata.to_list(md_map), + tap_stream_id=common.generate_tap_stream_id(table_catalog, table_schema, table_name), + schema=schema) + + entries.append(entry) + + return Catalog(entries) + + +def do_discover(snowflake_conn, config): + discover_catalog(snowflake_conn, config).dump() + + +# TODO: Maybe put in a singer-db-utils library. +def desired_columns(selected, table_schema): + + '''Return the set of column names we need to include in the SELECT. + + selected - set of column names marked as selected in the input catalog + table_schema - the most recently discovered Schema for the table + ''' + all_columns = set() + available = set() + automatic = set() + unsupported = set() + + for column, column_schema in table_schema.properties.items(): + all_columns.add(column) + inclusion = column_schema.inclusion + if inclusion == 'automatic': + automatic.add(column) + elif inclusion == 'available': + available.add(column) + elif inclusion == 'unsupported': + unsupported.add(column) + else: + raise Exception('Unknown inclusion ' + inclusion) + + selected_but_unsupported = selected.intersection(unsupported) + if selected_but_unsupported: + LOGGER.warning( + 'Columns %s were selected but are not supported. Skipping them.', + selected_but_unsupported) + + selected_but_nonexistent = selected.difference(all_columns) + if selected_but_nonexistent: + LOGGER.warning( + 'Columns %s were selected but do not exist.', + selected_but_nonexistent) + + not_selected_but_automatic = automatic.difference(selected) + if not_selected_but_automatic: + LOGGER.warning( + 'Columns %s are primary keys but were not selected. Adding them.', + not_selected_but_automatic) + + return selected.intersection(available).union(automatic) + + +def resolve_catalog(discovered_catalog, streams_to_sync): + result = Catalog(streams=[]) + + # Iterate over the streams in the input catalog and match each one up + # with the same stream in the discovered catalog. + for catalog_entry in streams_to_sync: + catalog_metadata = metadata.to_map(catalog_entry.metadata) + replication_key = catalog_metadata.get((), {}).get('replication-key') + + discovered_table = discovered_catalog.get_stream(catalog_entry.tap_stream_id) + database_name = common.get_database_name(catalog_entry) + + if not discovered_table: + LOGGER.warning('Database %s table %s was selected but does not exist', + database_name, catalog_entry.table) + continue + + selected = {k for k, v in catalog_entry.schema.properties.items() + if common.property_is_selected(catalog_entry, k) or k == replication_key} + + # These are the columns we need to select + columns = desired_columns(selected, discovered_table.schema) + + result.streams.append(CatalogEntry( + tap_stream_id=catalog_entry.tap_stream_id, + metadata=catalog_entry.metadata, + stream=catalog_entry.tap_stream_id, + table=catalog_entry.table, + schema=Schema( + type='object', + properties={col: discovered_table.schema.properties[col] + for col in columns} + ) + )) + + return result + + +def get_streams(snowflake_conn, catalog, config, state): + '''Returns the Catalog of data we're going to sync for all SELECT-based + streams (i.e. INCREMENTAL and FULL_TABLE that require a historical + sync). + + Using the Catalog provided from the input file, this function will return a + Catalog representing exactly which tables and columns that will be emitted + by SELECT-based syncs. This is achieved by comparing the input Catalog to a + freshly discovered Catalog to determine the resulting Catalog. + + The resulting Catalog will include the following any streams marked as + "selected" that currently exist in the database. Columns marked as "selected" + and those labled "automatic" (e.g. primary keys and replication keys) will be + included. Streams will be prioritized in the following order: + 1. currently_syncing if it is SELECT-based + 2. any streams that do not have state + 3. any streams that do not have a replication method of LOG_BASED + + ''' + discovered = discover_catalog(snowflake_conn, config) + + # Filter catalog to include only selected streams + selected_streams = list(filter(lambda s: common.stream_is_selected(s), catalog.streams)) + streams_with_state = [] + streams_without_state = [] + + for stream in selected_streams: + stream_state = state.get('bookmarks', {}).get(stream.tap_stream_id) + + if not stream_state: + streams_without_state.append(stream) + else: + streams_with_state.append(stream) + + # If the state says we were in the middle of processing a stream, skip + # to that stream. Then process streams without prior state and finally + # move onto streams with state (i.e. have been synced in the past) + currently_syncing = singer.get_currently_syncing(state) + + # prioritize streams that have not been processed + ordered_streams = streams_without_state + streams_with_state + + if currently_syncing: + currently_syncing_stream = list(filter( + lambda s: s.tap_stream_id == currently_syncing, streams_with_state)) + + non_currently_syncing_streams = list(filter(lambda s: s.tap_stream_id != currently_syncing, ordered_streams)) + + streams_to_sync = currently_syncing_stream + non_currently_syncing_streams + else: + # prioritize streams that have not been processed + streams_to_sync = ordered_streams + + return resolve_catalog(discovered, streams_to_sync) + + +def write_schema_message(catalog_entry, bookmark_properties=[]): + key_properties = common.get_key_properties(catalog_entry) + + singer.write_message(singer.SchemaMessage( + stream=catalog_entry.stream, + schema=catalog_entry.schema.to_dict(), + key_properties=key_properties, + bookmark_properties=bookmark_properties + )) + + +def do_sync_incremental(snowflake_conn, catalog_entry, state, columns): + LOGGER.info("Stream %s is using incremental replication", catalog_entry.stream) + + md_map = metadata.to_map(catalog_entry.metadata) + replication_key = md_map.get((), {}).get('replication-key') + + if not replication_key: + raise Exception("Cannot use INCREMENTAL replication for table ({}) without a replication key.".format(catalog_entry.stream)) + + write_schema_message(catalog_entry=catalog_entry, + bookmark_properties=[replication_key]) + + incremental.sync_table(snowflake_conn, catalog_entry, state, columns) + + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + +def do_sync_full_table(snowflake_conn, catalog_entry, state, columns): + LOGGER.info("Stream %s is using full table replication", catalog_entry.stream) + + write_schema_message(catalog_entry) + + stream_version = common.get_stream_version(catalog_entry.tap_stream_id, state) + + full_table.sync_table(snowflake_conn, catalog_entry, state, columns, stream_version) + + # Prefer initial_full_table_complete going forward + singer.clear_bookmark(state, catalog_entry.tap_stream_id, 'version') + + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'initial_full_table_complete', + True) + + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + +def sync_streams(snowflake_conn, catalog, config, state): + for catalog_entry in catalog.streams: + columns = list(catalog_entry.schema.properties.keys()) + + if not columns: + LOGGER.warning('There are no columns selected for stream %s, skipping it.', catalog_entry.stream) + continue + + state = singer.set_currently_syncing(state, catalog_entry.tap_stream_id) + + # Emit a state message to indicate that we've started this stream + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + md_map = metadata.to_map(catalog_entry.metadata) + + replication_method = md_map.get((), {}).get('replication-method') + + database_name = common.get_database_name(catalog_entry) + schema_name = common.get_schema_name(catalog_entry) + + with metrics.job_timer('sync_table') as timer: + timer.tags['database'] = database_name + timer.tags['table'] = catalog_entry.table + + LOGGER.info("Beginning to sync %s.%s.%s", database_name, schema_name, catalog_entry.table) + + if replication_method == 'INCREMENTAL': + do_sync_incremental(snowflake_conn, catalog_entry, state, columns) + elif replication_method == 'FULL_TABLE': + do_sync_full_table(snowflake_conn, catalog_entry, state, columns) + else: + raise Exception("only INCREMENTAL and FULL TABLE replication methods are supported") + + state = singer.set_currently_syncing(state, None) + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + +def do_sync(snowflake_conn, config, catalog, state): + catalog = get_streams(snowflake_conn, catalog, config, state) + sync_streams(snowflake_conn, catalog, config, state) + + +def main_impl(): + args = utils.parse_args(REQUIRED_CONFIG_KEYS) + + snowflake_conn = SnowflakeConnection(args.config) + + if args.discover: + do_discover(snowflake_conn, args.config) + elif args.catalog: + state = args.state or {} + do_sync(snowflake_conn, args.config, args.catalog, state) + elif args.properties: + catalog = Catalog.from_dict(args.properties) + state = args.state or {} + do_sync(snowflake_conn, args.config, catalog, state) + else: + LOGGER.info("No properties were selected") + + +def main(): + try: + main_impl() + except Exception as exc: + LOGGER.critical(exc) + raise exc diff --git a/tap_snowflake/connection.py b/tap_snowflake/connection.py new file mode 100644 index 0000000..4562f3d --- /dev/null +++ b/tap_snowflake/connection.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 + +import backoff +import singer +import snowflake.connector + +LOGGER = singer.get_logger() + + +def retry_pattern(): + return backoff.on_exception(backoff.expo, + (snowflake.connector.errors.OperationalError), + max_tries=5, + on_backoff=log_backoff_attempt, + factor=2) + + +def log_backoff_attempt(details): + LOGGER.info("Error detected communicating with Snowflake, triggering backoff: %d try", details.get("tries")) + + +def validate_config(config): + errors = [] + required_config_keys = [ + 'account', + 'dbname', + 'user', + 'password', + 'warehouse' + ] + + # Check if mandatory keys exist + for k in required_config_keys: + if not config.get(k, None): + errors.append("Required key is missing from config: [{}]".format(k)) + + return errors + + +class SnowflakeConnection: + def __init__(self, connection_config): + """ + connection_config: Snowflake connection details + """ + self.connection_config = connection_config + config_errors = validate_config(connection_config) + if len(config_errors) == 0: + self.connection_config = connection_config + else: + LOGGER.error("Invalid configuration:\n * {}".format('\n * '.join(config_errors))) + exit(1) + + + def open_connection(self): + return snowflake.connector.connect( + user=self.connection_config['user'], + password=self.connection_config['password'], + account=self.connection_config['account'], + database=self.connection_config['dbname'], + warehouse=self.connection_config['warehouse'], + insecure_mode=self.connection_config.get('insecure_mode', False) + # Use insecure mode to avoid "Failed to get OCSP response" warnings + #insecure_mode=True + ) + + + @retry_pattern() + def connect_with_backoff(self): + return self.open_connection() + + + def query(self, query, params=None): + LOGGER.info("SNOWFLAKE - Running query: {}".format(query)) + with self.connect_with_backoff() as connection: + with connection.cursor(snowflake.connector.DictCursor) as cur: + cur.execute( + query, + params + ) + + if cur.rowcount > 0: + return cur.fetchall() + + return [] + diff --git a/tap_snowflake/sync_strategies/__init__.py b/tap_snowflake/sync_strategies/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tap_snowflake/sync_strategies/common.py b/tap_snowflake/sync_strategies/common.py new file mode 100644 index 0000000..b8d66ce --- /dev/null +++ b/tap_snowflake/sync_strategies/common.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +# pylint: disable=too-many-arguments,duplicate-code,too-many-locals + +import copy +import datetime +import singer +import time + +import singer.metrics as metrics +from singer import metadata +from singer import utils + +LOGGER = singer.get_logger() + +def escape(string): + if '"' in string: + raise Exception("Can't escape identifier {} because it contains a backtick" + .format(string)) + return '"{}"'.format(string) + + +def generate_tap_stream_id(catalog_name, schema_name, table_name): + return catalog_name + '-' + schema_name + '-' + table_name + + +def get_stream_version(tap_stream_id, state): + stream_version = singer.get_bookmark(state, tap_stream_id, 'version') + + if stream_version is None: + stream_version = int(time.time() * 1000) + + return stream_version + + +def stream_is_selected(stream): + md_map = metadata.to_map(stream.metadata) + selected_md = metadata.get(md_map, (), 'selected') + + return selected_md + + +def property_is_selected(stream, property_name): + md_map = metadata.to_map(stream.metadata) + return singer.should_sync_field( + metadata.get(md_map, ('properties', property_name), 'inclusion'), + metadata.get(md_map, ('properties', property_name), 'selected'), + True) + + +def get_is_view(catalog_entry): + md_map = metadata.to_map(catalog_entry.metadata) + + return md_map.get((), {}).get('is-view') + + +def get_database_name(catalog_entry): + md_map = metadata.to_map(catalog_entry.metadata) + + return md_map.get((), {}).get('database-name') + + +def get_schema_name(catalog_entry): + md_map = metadata.to_map(catalog_entry.metadata) + + return md_map.get((), {}).get('schema-name') + + +def get_key_properties(catalog_entry): + catalog_metadata = metadata.to_map(catalog_entry.metadata) + stream_metadata = catalog_metadata.get((), {}) + + is_view = get_is_view(catalog_entry) + + if is_view: + key_properties = stream_metadata.get('view-key-properties', []) + else: + key_properties = stream_metadata.get('table-key-properties', []) + + return key_properties + + +def generate_select_sql(catalog_entry, columns): + database_name = get_database_name(catalog_entry) + schema_name = get_schema_name(catalog_entry) + escaped_db = escape(database_name) + escaped_schema = escape(schema_name) + escaped_table = escape(catalog_entry.table) + escaped_columns = [escape(c) for c in columns] + + select_sql = 'SELECT {} FROM {}.{}.{}'.format( + ','.join(escaped_columns), + escaped_db, + escaped_schema, + escaped_table) + + # escape percent signs + select_sql = select_sql.replace('%', '%%') + return select_sql + + +def row_to_singer_record(catalog_entry, version, row, columns, time_extracted): + row_to_persist = () + for idx, elem in enumerate(row): + property_type = catalog_entry.schema.properties[columns[idx]].type + if isinstance(elem, datetime.datetime): + row_to_persist += (elem.isoformat() + '+00:00',) + + elif isinstance(elem, datetime.date): + row_to_persist += (elem.isoformat() + 'T00:00:00+00:00',) + + elif isinstance(elem, datetime.timedelta): + epoch = datetime.datetime.utcfromtimestamp(0) + timedelta_from_epoch = epoch + elem + row_to_persist += (timedelta_from_epoch.isoformat() + '+00:00',) + + elif isinstance(elem, bytes): + # for BIT value, treat 0 as False and anything else as True + if 'boolean' in property_type: + boolean_representation = elem != b'\x00' + row_to_persist += (boolean_representation,) + else: + row_to_persist += (elem.hex(),) + + elif 'boolean' in property_type or property_type == 'boolean': + if elem is None: + boolean_representation = None + elif elem == 0: + boolean_representation = False + else: + boolean_representation = True + row_to_persist += (boolean_representation,) + + else: + row_to_persist += (elem,) + rec = dict(zip(columns, row_to_persist)) + + return singer.RecordMessage( + stream=catalog_entry.stream, + record=rec, + version=version, + time_extracted=time_extracted) + + +def whitelist_bookmark_keys(bookmark_key_set, tap_stream_id, state): + for bk in [non_whitelisted_bookmark_key + for non_whitelisted_bookmark_key + in state.get('bookmarks', {}).get(tap_stream_id, {}).keys() + if non_whitelisted_bookmark_key not in bookmark_key_set]: + singer.clear_bookmark(state, tap_stream_id, bk) + + +def sync_query(cursor, catalog_entry, state, select_sql, columns, stream_version, params): + replication_key = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key') + + time_extracted = utils.now() + + LOGGER.info("Running {}".format(select_sql)) + cursor.execute(select_sql, params) + + row = cursor.fetchone() + rows_saved = 0 + + database_name = get_database_name(catalog_entry) + + with metrics.record_counter(None) as counter: + counter.tags['database'] = database_name + counter.tags['table'] = catalog_entry.table + + while row: + counter.increment() + rows_saved += 1 + record_message = row_to_singer_record(catalog_entry, + stream_version, + row, + columns, + time_extracted) + singer.write_message(record_message) + + md_map = metadata.to_map(catalog_entry.metadata) + stream_metadata = md_map.get((), {}) + replication_method = stream_metadata.get('replication-method') + + if replication_method == 'FULL_TABLE': + key_properties = get_key_properties(catalog_entry) + + max_pk_values = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'max_pk_values') + + if max_pk_values: + last_pk_fetched = {k:v for k,v in record_message.record.items() + if k in key_properties} + + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'last_pk_fetched', + last_pk_fetched) + + elif replication_method == 'INCREMENTAL': + if replication_key is not None: + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key', + replication_key) + + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key_value', + record_message.record[replication_key]) + if rows_saved % 1000 == 0: + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + row = cursor.fetchone() + + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) diff --git a/tap_snowflake/sync_strategies/full_table.py b/tap_snowflake/sync_strategies/full_table.py new file mode 100644 index 0000000..bbc59fc --- /dev/null +++ b/tap_snowflake/sync_strategies/full_table.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +# pylint: disable=duplicate-code,too-many-locals,simplifiable-if-expression + +import copy +import singer +from singer import metadata + +import tap_snowflake.sync_strategies.common as common +from tap_snowflake.connection import SnowflakeConnection + +LOGGER = singer.get_logger() + +BOOKMARK_KEYS = {'last_pk_fetched', 'max_pk_values', 'version', 'initial_full_table_complete'} + +def get_max_pk_values(cursor, catalog_entry): + database_name = common.get_database_name(catalog_entry) + escaped_db = common.escape(database_name) + escaped_table = common.escape(catalog_entry.table) + + key_properties = common.get_key_properties(catalog_entry) + escaped_columns = [common.escape(c) for c in key_properties] + + sql = """SELECT {} + FROM {}.{} + ORDER BY {} + LIMIT 1 + """ + + select_column_clause = ", ".join(escaped_columns) + order_column_clause = ", ".join([pk + " DESC" for pk in escaped_columns]) + + cursor.execute(sql.format(select_column_clause, + escaped_db, + escaped_table, + order_column_clause)) + result = cursor.fetchone() + + if result: + max_pk_values = dict(zip(key_properties, result)) + else: + max_pk_values = {} + + return max_pk_values + +def generate_pk_clause(catalog_entry, state): + key_properties = common.get_key_properties(catalog_entry) + escaped_columns = [common.escape(c) for c in key_properties] + + where_clause = " AND ".join([pk + " > `{}`" for pk in escaped_columns]) + order_by_clause = ", ".join(['`{}`, ' for pk in escaped_columns]) + + max_pk_values = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'max_pk_values') + + last_pk_fetched = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'last_pk_fetched') + + if last_pk_fetched: + pk_comparisons = ["({} > {} AND {} <= {})".format(common.escape(pk), + last_pk_fetched[pk], + common.escape(pk), + max_pk_values[pk]) + for pk in key_properties] + else: + pk_comparisons = ["{} <= {}".format(common.escape(pk), max_pk_values[pk]) + for pk in key_properties] + + sql = " WHERE {} ORDER BY {} ASC".format(" AND ".join(pk_comparisons), + ", ".join(escaped_columns)) + + return sql + + + +def sync_table(snowflake_conn, catalog_entry, state, columns, stream_version): + common.whitelist_bookmark_keys(BOOKMARK_KEYS, catalog_entry.tap_stream_id, state) + + bookmark = state.get('bookmarks', {}).get(catalog_entry.tap_stream_id, {}) + version_exists = True if 'version' in bookmark else False + + initial_full_table_complete = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'initial_full_table_complete') + + state_version = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'version') + + activate_version_message = singer.ActivateVersionMessage( + stream=catalog_entry.stream, + version=stream_version + ) + + # For the initial replication, emit an ACTIVATE_VERSION message + # at the beginning so the records show up right away. + if not initial_full_table_complete and not (version_exists and state_version is None): + singer.write_message(activate_version_message) + + with snowflake_conn.connect_with_backoff() as open_conn: + with open_conn.cursor() as cur: + select_sql = common.generate_select_sql(catalog_entry, columns) + params = {} + + common.sync_query(cur, + catalog_entry, + state, + select_sql, + columns, + stream_version, + params) + + # clear max pk value and last pk fetched upon successful sync + singer.clear_bookmark(state, catalog_entry.tap_stream_id, 'max_pk_values') + singer.clear_bookmark(state, catalog_entry.tap_stream_id, 'last_pk_fetched') + + singer.write_message(activate_version_message) diff --git a/tap_snowflake/sync_strategies/incremental.py b/tap_snowflake/sync_strategies/incremental.py new file mode 100644 index 0000000..f6c420e --- /dev/null +++ b/tap_snowflake/sync_strategies/incremental.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# pylint: disable=duplicate-code + +import pendulum +import singer +from singer import metadata + +from tap_snowflake.connection import SnowflakeConnection +import tap_snowflake.sync_strategies.common as common + +LOGGER = singer.get_logger() + +BOOKMARK_KEYS = {'replication_key', 'replication_key_value', 'version'} + +def sync_table(snowflake_conn, catalog_entry, state, columns): + common.whitelist_bookmark_keys(BOOKMARK_KEYS, catalog_entry.tap_stream_id, state) + + catalog_metadata = metadata.to_map(catalog_entry.metadata) + stream_metadata = catalog_metadata.get((), {}) + + replication_key_metadata = stream_metadata.get('replication-key') + replication_key_state = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key') + + replication_key_value = None + + if replication_key_metadata == replication_key_state: + replication_key_value = singer.get_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key_value') + else: + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'replication_key', + replication_key_metadata) + state = singer.clear_bookmark(state, catalog_entry.tap_stream_id, 'replication_key_value') + + stream_version = common.get_stream_version(catalog_entry.tap_stream_id, state) + state = singer.write_bookmark(state, + catalog_entry.tap_stream_id, + 'version', + stream_version) + + activate_version_message = singer.ActivateVersionMessage( + stream=catalog_entry.stream, + version=stream_version + ) + + singer.write_message(activate_version_message) + + select_sql = common.generate_select_sql(catalog_entry, columns) + params = {} + + with snowflake_conn.connect_with_backoff() as open_conn: + with open_conn.cursor() as cur: + select_sql = common.generate_select_sql(catalog_entry, columns) + params = {} + + if replication_key_value is not None: + if catalog_entry.schema.properties[replication_key_metadata].format == 'date-time': + replication_key_value = pendulum.parse(replication_key_value) + + select_sql += ' WHERE "{}" >= \'{}\' ORDER BY "{}" ASC'.format( + replication_key_metadata, + replication_key_value, + replication_key_metadata) + + elif replication_key_metadata is not None: + select_sql += ' ORDER BY "{}" ASC'.format(replication_key_metadata) + + common.sync_query(cur, + catalog_entry, + state, + select_sql, + columns, + stream_version, + params) diff --git a/tests/test_tap_snowflake.py b/tests/test_tap_snowflake.py new file mode 100644 index 0000000..06bc5f9 --- /dev/null +++ b/tests/test_tap_snowflake.py @@ -0,0 +1,162 @@ +import unittest +import singer + +import tap_snowflake + +from singer.schema import Schema + + +try: + import tests.utils as test_utils +except ImportError: + import utils as test_utils + +LOGGER = singer.get_logger() + +SCHEMA_NAME='tap_snowflake_test' + +SINGER_MESSAGES = [] + +def accumulate_singer_messages(message): + SINGER_MESSAGES.append(message) + +singer.write_message = accumulate_singer_messages + +class TestTypeMapping(unittest.TestCase): + + @classmethod + def setUpClass(cls): + snowflake_conn = test_utils.get_test_connection() + + with snowflake_conn.open_connection() as open_conn: + with open_conn.cursor() as cur: + cur.execute(''' + CREATE TABLE {}.test_type_mapping ( + c_pk INTEGER PRIMARY KEY, + c_decimal DECIMAL, + c_decimal_2 DECIMAL(11, 2), + c_smallint SMALLINT, + c_int INT, + c_bigint BIGINT, + c_float FLOAT, + c_double DOUBLE, + c_date DATE, + c_time TIME, + c_binary BINARY, + c_varbinary VARBINARY(16) + )'''.format(SCHEMA_NAME)) + + catalog = test_utils.discover_catalog(snowflake_conn) + cls.schema = catalog.streams[0].schema + cls.metadata = catalog.streams[0].metadata + + def get_metadata_for_column(self, colName): + return next(md for md in self.metadata if md['breadcrumb'] == ('properties', colName))['metadata'] + + def test_decimal(self): + self.assertEqual(self.schema.properties['C_DECIMAL'], + Schema(['null', 'number'], + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_DECIMAL'), + {'selected-by-default': True, + 'sql-datatype': 'number'}) + + def test_decimal_with_defined_scale_and_precision(self): + self.assertEqual(self.schema.properties['C_DECIMAL_2'], + Schema(['null', 'number'], + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_DECIMAL_2'), + {'selected-by-default': True, + 'sql-datatype': 'number'}) + + def test_smallint(self): + self.assertEqual(self.schema.properties['C_SMALLINT'], + Schema(['null', 'number'], + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_SMALLINT'), + {'selected-by-default': True, + 'sql-datatype': 'number'}) + + def test_int(self): + self.assertEqual(self.schema.properties['C_INT'], + Schema(['null', 'number'], + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_INT'), + {'selected-by-default': True, + 'sql-datatype': 'number'}) + + def test_bigint(self): + self.assertEqual(self.schema.properties['C_BIGINT'], + Schema(['null', 'number'], + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_BIGINT'), + {'selected-by-default': True, + 'sql-datatype': 'number'}) + + def test_float(self): + self.assertEqual(self.schema.properties['C_FLOAT'], + Schema(['null', 'number'], + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_FLOAT'), + {'selected-by-default': True, + 'sql-datatype': 'float'}) + + def test_double(self): + self.assertEqual(self.schema.properties['C_DOUBLE'], + Schema(['null', 'number'], + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_DOUBLE'), + {'selected-by-default': True, + 'sql-datatype': 'float'}) + + def test_date(self): + self.assertEqual(self.schema.properties['C_DATE'], + Schema(['null', 'string'], + format='date-time', + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_DATE'), + {'selected-by-default': True, + 'sql-datatype': 'date'}) + + def test_time(self): + self.assertEqual(self.schema.properties['C_TIME'], + Schema(['null', 'string'], + format='date-time', + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_TIME'), + {'selected-by-default': True, + 'sql-datatype': 'time'}) + + def test_binary(self): + self.assertEqual(self.schema.properties['C_BINARY'], + Schema(['null', 'string'], + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_BINARY'), + {'selected-by-default': True, + 'sql-datatype': 'binary'}) + + def test_varbinary(self): + self.assertEqual(self.schema.properties['C_VARBINARY'], + Schema(['null', 'string'], + inclusion='available')) + self.assertEqual(self.get_metadata_for_column('C_VARBINARY'), + {'selected-by-default': True, + 'sql-datatype': 'binary'}) + + +class TestSelectsAppropriateColumns(unittest.TestCase): + + def runTest(self): + selected_cols = set(['a', 'b', 'd']) + table_schema = Schema(type='object', + properties={ + 'a': Schema(None, inclusion='available'), + 'b': Schema(None, inclusion='unsupported'), + 'c': Schema(None, inclusion='automatic')}) + + got_cols = tap_snowflake.desired_columns(selected_cols, table_schema) + + self.assertEqual(got_cols, + set(['a', 'c']), + 'Keep automatic as well as selected, available columns.') + diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..c604f6b --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,68 @@ +import os +import singer +import snowflake.connector + +import tap_snowflake +import tap_snowflake.sync_strategies.common as common +from tap_snowflake.connection import SnowflakeConnection + +SCHEMA_NAME='tap_snowflake_test' + +def get_db_config(): + config = {} + config['account'] = os.environ.get('TAP_SNOWFLAKE_ACCOUNT') + config['dbname'] = os.environ.get('TAP_SNOWFLAKE_DBNAME') + config['user'] = os.environ.get('TAP_SNOWFLAKE_USER') + config['password'] = os.environ.get('TAP_SNOWFLAKE_PASSWORD') + config['warehouse'] = os.environ.get('TAP_SNOWFLAKE_WAREHOUSE') + + return config + + +def get_tap_config(): + config = {} + config['filter_dbs'] = os.environ.get('TAP_SNOWFLAKE_DBNAME') + config['filter_schemas'] = SCHEMA_NAME + + return config + + +def get_test_connection(): + db_config = get_db_config() + snowflake_conn = SnowflakeConnection(db_config) + + with snowflake_conn.open_connection() as open_conn: + with open_conn.cursor() as cur: + try: + cur.execute('DROP SCHEMA IF EXISTS {}'.format(SCHEMA_NAME)) + except: + pass + cur.execute('CREATE SCHEMA {}'.format(SCHEMA_NAME)) + + return snowflake_conn + + +def discover_catalog(connection): + tap_config = get_tap_config() + catalog = tap_snowflake.discover_catalog(connection, tap_config) + streams = [] + + for stream in catalog.streams: + streams.append(stream) + + catalog.streams = streams + + return catalog + + +def set_replication_method_and_key(stream, r_method, r_key): + new_md = singer.metadata.to_map(stream.metadata) + old_md = new_md.get(()) + if r_method: + old_md.update({'replication-method': r_method}) + + if r_key: + old_md.update({'replication-key': r_key}) + + stream.metadata = singer.metadata.to_list(new_md) + return stream