From 69a1d8fd66ad1bbfc3bb39fffdbbf6f53659c09e Mon Sep 17 00:00:00 2001 From: Peter Kosztolanyi Date: Tue, 17 Mar 2020 18:45:33 +0000 Subject: [PATCH] [AP-591] Use SHOW SCHEMAS|TABLES|COLUMNS instead of INFORMATION_SCHEMA (#14) --- tap_snowflake/__init__.py | 161 +++++++++++++++++++----- tap_snowflake/connection.py | 35 ++++-- tests/integration/test_tap_snowflake.py | 45 +++---- 3 files changed, 181 insertions(+), 60 deletions(-) diff --git a/tap_snowflake/__init__.py b/tap_snowflake/__init__.py index 1f74790..89f446f 100644 --- a/tap_snowflake/__init__.py +++ b/tap_snowflake/__init__.py @@ -4,11 +4,14 @@ import collections import copy import itertools +import re +import sys import logging import singer import singer.metrics as metrics import singer.schema +import snowflake.connector from singer import metadata from singer import utils from singer.catalog import Catalog, CatalogEntry @@ -21,6 +24,11 @@ LOGGER = singer.get_logger('tap_snowflake') +# Max number of rows that a SHOW SCHEMAS|TABLES|COLUMNS can return. +# If more than this number of rows returned then tap-snowflake will raise TooManyRecordsException +SHOW_COMMAND_MAX_ROWS = 9999 + + # Tone down snowflake connector logs noise logging.getLogger('snowflake.connector').setLevel(logging.WARNING) @@ -110,46 +118,140 @@ def create_column_metadata(cols): return metadata.to_list(mdata) +def get_databases(snowflake_conn): + """Get snowflake databases""" + databases = snowflake_conn.query('SHOW DATABASES', max_records=SHOW_COMMAND_MAX_ROWS) + + # Return only the name of databases as a list + return [db['name'] for db in databases] + + +def get_schemas(snowflake_conn, database): + """Get schemas of a database""" + schemas = [] + try: + schemas = snowflake_conn.query(f'SHOW SCHEMAS IN DATABASE {database}', max_records=SHOW_COMMAND_MAX_ROWS) + + # Get only the name of schemas as a list + schemas = [schema['name'] for schema in schemas] + + # Catch exception when schema not exists and SHOW SCHEMAS throws a ProgrammingError + # Regexp to extract snowflake error code and message from the exception message + # Do nothing if schema not exists + except snowflake.connector.errors.ProgrammingError as exc: + # pylint: disable=anomalous-backslash-in-string + if re.match('.*\(02000\):.*\n.*does not exist.*', str(sys.exc_info()[1])): + pass + else: + raise exc + + return schemas + + +def get_table_columns(snowflake_conn, database, table_schemas=None, table_name=None): + """Get column definitions for every table in specific schemas(s) + + It's using SHOW commands instead of INFORMATION_SCHEMA views bucause information_schemas views are slow + and can cause unexpected exception of: + Information schema query returned too much data. Please repeat query with more selective predicates. + """ + table_columns = [] + if table_schemas or table_name: + for schema in table_schemas: + queries = [] + + LOGGER.info('Getting schema information for %s.%s...', database, schema) + + # Get column data types by SHOW commands + show_tables = f'SHOW TABLES IN SCHEMA {database}.{schema}' + show_views = f'SHOW TABLES IN SCHEMA {database}.{schema}' + show_columns = f'SHOW COLUMNS IN SCHEMA {database}.{schema}' + + # Convert output of SHOW commands to tables and use SQL joins to get every required information + select = f""" + WITH + show_tables AS (SELECT * FROM TABLE(RESULT_SCAN(LAST_QUERY_ID(-3)))), + show_views AS (SELECT * FROM TABLE(RESULT_SCAN(LAST_QUERY_ID(-2)))), + show_columns AS (SELECT * FROM TABLE(RESULT_SCAN(LAST_QUERY_ID(-1)))) + SELECT show_columns."database_name" AS table_catalog + ,show_columns."schema_name" AS table_schema + ,show_columns."table_name" AS table_name + ,CASE + WHEN show_tables."name" IS NOT NULL THEN 'BASE TABLE' + ELSE 'VIEW' + END table_type + ,show_tables."rows" AS row_count + ,show_columns."column_name" AS column_name + -- ---------------------------------------------------------------------------------------- + -- Character and numeric columns display their generic data type rather than their defined + -- data type (i.e. TEXT for all character types, FIXED for all fixed-point numeric types, + -- and REAL for all floating-point numeric types). + -- + -- Further info at https://docs.snowflake.net/manuals/sql-reference/sql/show-columns.html + -- ---------------------------------------------------------------------------------------- + ,CASE PARSE_JSON(show_columns."data_type"):type::varchar + WHEN 'FIXED' THEN 'NUMBER' + WHEN 'REAL' THEN 'FLOAT' + ELSE PARSE_JSON("data_type"):type::varchar + END data_type + ,PARSE_JSON(show_columns."data_type"):length::number AS character_maximum_length + ,PARSE_JSON(show_columns."data_type"):precision::number AS numeric_precision + ,PARSE_JSON(show_columns."data_type"):scale::number AS numeric_scale + FROM show_columns + LEFT JOIN show_tables + ON show_tables."database_name" = show_columns."database_name" + AND show_tables."schema_name" = show_columns."schema_name" + AND show_tables."name" = show_columns."table_name" + LEFT JOIN show_views + ON show_views."database_name" = show_columns."database_name" + AND show_views."schema_name" = show_columns."schema_name" + AND show_views."name" = show_columns."table_name" + """ + queries.extend([show_tables, show_views, show_columns, select]) + + # Run everything in one transaction + try: + columns = snowflake_conn.query(queries, max_records=SHOW_COMMAND_MAX_ROWS) + table_columns.extend(columns) + + # Catch exception when schema not exists and SHOW COLUMNS throws a ProgrammingError + # Regexp to extract snowflake error code and message from the exception message + # Do nothing if schema not exists + except snowflake.connector.errors.ProgrammingError as exc: + # pylint: disable=anomalous-backslash-in-string + if re.match('.*\(02000\):.*\n.*does not exist.*', str(sys.exc_info()[1])): + pass + else: + raise exc + + return table_columns + + def discover_catalog(snowflake_conn, config): """Returns a Catalog describing the structure of the database.""" filter_dbs_config = config.get('filter_dbs') filter_schemas_config = config.get('filter_schemas') + databases = [] + schemas = [] + # Get databases + sql_columns = [] if filter_dbs_config: - filter_dbs_clause = ','.join(f"LOWER('{db}')" for db in filter_dbs_config.split(',')) - - table_db_clause = f'LOWER(t.table_catalog) IN ({filter_dbs_clause})' + databases = filter_dbs_config.split(',') else: - table_db_clause = '1 = 1' + databases = get_databases(snowflake_conn) + for database in databases: - if filter_schemas_config: - filter_schemas_clause = ','.join([f"LOWER('{schema}')" for schema in filter_schemas_config.split(',')]) + # Get schemas + if filter_schemas_config: + schemas = filter_schemas_config.split(',') + else: + schemas = get_schemas(snowflake_conn, database) - table_schema_clause = f'LOWER(t.table_schema) IN ({filter_schemas_clause})' - else: - table_schema_clause = "LOWER(t.table_schema) NOT IN ('information_schema')" + table_columns = get_table_columns(snowflake_conn, database, schemas) + sql_columns.extend(table_columns) table_info = {} - sql_columns = snowflake_conn.query(""" - SELECT t.table_catalog, - t.table_schema, - t.table_name, - t.table_type, - t.row_count, - c.column_name, - c.data_type, - c.character_maximum_length, - c.numeric_precision, - c.numeric_scale - FROM information_schema.tables t, - information_schema.columns c - WHERE t.table_catalog = c.table_catalog - AND t.table_schema = c.table_schema - AND t.table_name = c.table_name - AND {} - AND {} - """.format(table_db_clause, table_schema_clause)) - columns = [] for sql_col in sql_columns: catalog = sql_col['TABLE_CATALOG'] @@ -217,6 +319,7 @@ def discover_catalog(snowflake_conn, config): def do_discover(snowflake_conn, config): discover_catalog(snowflake_conn, config).dump() + # pylint: disable=fixme # TODO: Maybe put in a singer-db-utils library. def desired_columns(selected, table_schema): diff --git a/tap_snowflake/connection.py b/tap_snowflake/connection.py index 638cef4..eb52e80 100644 --- a/tap_snowflake/connection.py +++ b/tap_snowflake/connection.py @@ -8,6 +8,10 @@ LOGGER = singer.get_logger('tap_snowflake') +class TooManyRecordsException(Exception): + """Exception to raise when query returns more records than max_records""" + + def retry_pattern(): """Retry pattern decorator used when connecting to snowflake """ @@ -79,17 +83,30 @@ def connect_with_backoff(self): return self.open_connection() - def query(self, query, params=None): + def query(self, query, params=None, max_records=0): """Run a query in snowflake""" - LOGGER.info('SNOWFLAKE - Running query: %s', query) + result = [] with self.connect_with_backoff() as connection: with connection.cursor(snowflake.connector.DictCursor) as cur: - cur.execute( - query, - params - ) + queries = [] + + # Run every query in one transaction if query is a list of SQL + if isinstance(query, list): + queries.append('START TRANSACTION') + queries.extend(query) + else: + queries = [query] + + for sql in queries: + LOGGER.debug('SNOWFLAKE - Running query: %s', sql) + cur.execute(sql, params) + + # Raise exception if returned rows greater than max allowed records + if 0 < max_records < cur.rowcount: + raise TooManyRecordsException( + f'Query returned too many records. This query can return max {max_records} records') - if cur.rowcount > 0: - return cur.fetchall() + if cur.rowcount > 0: + result = cur.fetchall() - return [] + return result diff --git a/tests/integration/test_tap_snowflake.py b/tests/integration/test_tap_snowflake.py index 4281db3..4f59a50 100644 --- a/tests/integration/test_tap_snowflake.py +++ b/tests/integration/test_tap_snowflake.py @@ -7,7 +7,6 @@ from singer.schema import Schema - try: import tests.utils as test_utils except ImportError: @@ -15,15 +14,18 @@ LOGGER = singer.get_logger('tap_snowflake_tests') -SCHEMA_NAME='tap_snowflake_test' +SCHEMA_NAME = 'tap_snowflake_test' SINGER_MESSAGES = [] + def accumulate_singer_messages(message): SINGER_MESSAGES.append(message) + singer.write_message = accumulate_singer_messages + class TestTypeMapping(unittest.TestCase): @classmethod @@ -181,10 +183,10 @@ def test_row_to_singer_record(self): # Convert the exported data to singer JSON record_message = common.row_to_singer_record(catalog_entry=catalog_entry, - version=1, - row=row, - columns=columns, - time_extracted=singer.utils.now()) + version=1, + row=row, + columns=columns, + time_extracted=singer.utils.now()) # Convert to formatted JSON formatted_record = singer.messages.format_message(record_message) @@ -193,21 +195,21 @@ def test_row_to_singer_record(self): self.assertEquals(json.loads(formatted_record)['type'], 'RECORD') self.assertEquals(json.loads(formatted_record)['stream'], 'TEST_TYPE_MAPPING') self.assertEquals(json.loads(formatted_record)['record'], - { - 'C_PK': 1, - 'C_DECIMAL': 12345, - 'C_DECIMAL_2': 123456789.12, - 'C_SMALLINT': 123, - 'C_INT': 12345, - 'C_BIGINT': 1234567890, - 'C_FLOAT': 123.123, - 'C_DOUBLE': 123.123, - 'C_DATE': '2019-08-01T00:00:00+00:00', - 'C_DATETIME': '2019-08-01T17:23:59+00:00', - 'C_TIME': '17:23:59', - 'C_BINARY': '62696E617279', - 'C_VARBINARY': '76617262696E617279' - }) + { + 'C_PK': 1, + 'C_DECIMAL': 12345, + 'C_DECIMAL_2': 123456789.12, + 'C_SMALLINT': 123, + 'C_INT': 12345, + 'C_BIGINT': 1234567890, + 'C_FLOAT': 123.123, + 'C_DOUBLE': 123.123, + 'C_DATE': '2019-08-01T00:00:00+00:00', + 'C_DATETIME': '2019-08-01T17:23:59+00:00', + 'C_TIME': '17:23:59', + 'C_BINARY': '62696E617279', + 'C_VARBINARY': '76617262696E617279' + }) class TestSelectsAppropriateColumns(unittest.TestCase): @@ -225,4 +227,3 @@ def runTest(self): self.assertEqual(got_cols, set(['a', 'c']), 'Keep automatic as well as selected, available columns.') -