diff --git a/target_snowflake/__init__.py b/target_snowflake/__init__.py
index dc83da2b..3a194466 100644
--- a/target_snowflake/__init__.py
+++ b/target_snowflake/__init__.py
@@ -90,14 +90,6 @@ def persist_lines(config, lines):
     row_count = {}
     stream_to_sync = {}
     batch_size_rows = config.get('batch_size_rows', 100000)
-    table_columns_cache = None
-
-    # Cache the available schemas, tables and columns from snowflake if not disabled in config
-    # The cache will be used later use to avoid lot of small queries hitting snowflake
-    if not ('disable_table_cache' in config and config['disable_table_cache'] == True):
-        logger.info("Caching available catalog objects in snowflake...")
-        filter_schemas = get_schema_names_from_config(config)
-        table_columns_cache = DbSync(config).get_table_columns(filter_schemas=filter_schemas)

     # Loop over lines from stdin
     for line in lines:
@@ -189,8 +181,18 @@ def persist_lines(config, lines):
                 else:
                     stream_to_sync[stream] = DbSync(config, o)

-                stream_to_sync[stream].create_schema_if_not_exists(table_columns_cache)
-                stream_to_sync[stream].sync_table(table_columns_cache)
+                try:
+                    stream_to_sync[stream].create_schema_if_not_exists()
+                    stream_to_sync[stream].sync_table()
+                except Exception as e:
+                    logger.error("""
+                        Cannot sync table structure in Snowflake schema: {}.
+                        Try to delete the {}.COLUMNS table to reset the information_schema cache. It might be outdated.
+                    """.format(
+                        stream_to_sync[stream].schema_name,
+                        stream_to_sync[stream].pipelinewise_schema.upper()))
+                    raise e
+
                 row_count[stream] = 0
                 csv_files_to_load[stream] = NamedTemporaryFile(mode='w+b')
         elif t == 'ACTIVATE_VERSION':
@@ -231,7 +233,15 @@ def flush_records(stream, records_to_load, row_count, db_sync):
         f.write(bytes(csv_line + '\n', 'UTF-8'))

     s3_key = db_sync.put_to_stage(csv_file, stream, row_count)
-    db_sync.load_csv(s3_key, row_count)
+    try:
+        db_sync.load_csv(s3_key, row_count)
+    except Exception as e:
+        logger.error("""
+            Cannot load data from S3 into Snowflake schema: {}.
+            Try to delete the {}.COLUMNS table to reset the information_schema cache. It might be outdated.
+        """.format(db_sync.schema_name, db_sync.pipelinewise_schema.upper()))
+        raise e
+
     os.remove(csv_file)
     db_sync.delete_from_stage(s3_key)

diff --git a/target_snowflake/db_sync.py b/target_snowflake/db_sync.py
index 735f708d..d8c785a6 100644
--- a/target_snowflake/db_sync.py
+++ b/target_snowflake/db_sync.py
@@ -153,22 +153,28 @@ def flatten_record(d, parent_key=[], sep='__', level=0, max_level=0):
 def primary_column_names(stream_schema_message):
     return [safe_column_name(p) for p in stream_schema_message['key_properties']]

-def stream_name_to_dict(stream_name):
+def stream_name_to_dict(stream_name, separator='-'):
+    catalog_name = None
     schema_name = None
     table_name = stream_name

     # Schema and table name can be derived from stream if it's in <schema_name>-<table_name> format
-    s = stream_name.split('-')
-    if len(s) > 1:
+    s = stream_name.split(separator)
+    if len(s) == 2:
         schema_name = s[0]
-        table_name = '_'.join(s[1:])
+        table_name = s[1]
+    if len(s) > 2:
+        catalog_name = s[0]
+        schema_name = s[1]
+        table_name = '_'.join(s[2:])

     return {
+        'catalog_name': catalog_name,
         'schema_name': schema_name,
         'table_name': table_name
     }

-# pylint: disable=too-many-public-methods
+# pylint: disable=too-many-public-methods,too-many-instance-attributes
 class DbSync:
     def __init__(self, connection_config, stream_schema_message=None):
         """
@@ -191,14 +197,23 @@ def __init__(self, connection_config, stream_schema_message=None):
         """
         self.connection_config = connection_config
         config_errors = validate_config(connection_config)
-        if len(config_errors) == 0:
-            self.connection_config = connection_config
-        else:
+
+        # Exit if config has errors
+        if len(config_errors) > 0:
             logger.error("Invalid configuration:\n * {}".format('\n * '.join(config_errors)))
             exit(1)

+        # Internal pipelinewise schema derived from the stage object in the config
+        stage = stream_name_to_dict(self.connection_config['stage'], separator='.')
+        if stage['schema_name']:
+            self.pipelinewise_schema = stage['schema_name']
+        else:
+            logger.error("The named external stage object in config has to use the <schema>.<stage_name> format.")
+            exit(1)
+
         self.schema_name = None
         self.grantees = None
+        self.information_schema_columns = None
         if stream_schema_message is not None:
             # Define target schema name.
             # --------------------------
@@ -248,6 +263,11 @@ def __init__(self, connection_config, stream_schema_message=None):
             if config_schema_mapping and stream_schema_name in config_schema_mapping:
                 self.grantees = config_schema_mapping[stream_schema_name].get('target_schema_select_permissions', self.grantees)

+            # Caching enabled: get the list of available columns from the auto-maintained cache table
+            if not ('disable_table_cache' in self.connection_config and self.connection_config['disable_table_cache'] == True):
+                logger.info("Getting catalog objects from information_schema cache table...")
+                self.information_schema_columns = self.get_table_columns(table_schema=self.schema_name, from_information_schema_cache_table=True)
+
         self.stream_schema_message = stream_schema_message

         if stream_schema_message is not None:
@@ -370,6 +390,36 @@ def delete_from_stage(self, s3_key):
         self.s3.delete_object(Bucket=bucket, Key=s3_key)

+    def cache_information_schema_columns(self, create_only=False):
+        """The information_schema_columns cache is a copy of the snowflake INFORMATION_SCHEMA.COLUMNS table, kept to avoid the error
+        'Information schema query returned too much data. Please repeat query with more selective predicates.'.
+
+        Snowflake gives the above error message when multiple taps run in parallel (approx. >10 taps) and
+        these taps are selecting from information_schema at the same time. To avoid this problem we maintain a
+        local copy of the INFORMATION_SCHEMA.COLUMNS table and keep it updated automatically whenever needed.
+        """
+
+        # Create empty columns cache table if not exists
+        self.query("""
+            CREATE TABLE IF NOT EXISTS {}.columns (table_schema VARCHAR, table_name VARCHAR, column_name VARCHAR, data_type VARCHAR)
+        """.format(self.pipelinewise_schema))
+
+        if not create_only:
+            # Delete existing data about the current schema
+            self.query("""
+                DELETE FROM {}.columns
+                WHERE LOWER(table_schema) = '{}'
+            """.format(self.pipelinewise_schema, self.schema_name.lower()))
+
+            # Insert the latest data from information_schema into the cache table
+            self.query("""
+                INSERT INTO {}.columns
+                SELECT table_schema, table_name, column_name, data_type
+                FROM information_schema.columns
+                WHERE LOWER(table_schema) = '{}'
+            """.format(self.pipelinewise_schema, self.schema_name.lower()))
+
+
     def load_csv(self, s3_key, count):
         stream_schema_message = self.stream_schema_message
         stream = stream_schema_message['stream']
@@ -482,13 +532,13 @@ def delete_rows(self, stream):
         logger.info("Deleting rows from '{}' table... {}".format(table, query))
         logger.info("DELETE {}".format(len(self.query(query))))

-    def create_schema_if_not_exists(self, table_columns_cache=None):
+    def create_schema_if_not_exists(self):
         schema_name = self.schema_name
         schema_rows = 0

-        # table_columns_cache is an optional pre-collected list of available objects in snowflake
-        if table_columns_cache:
-            schema_rows = list(filter(lambda x: x['TABLE_SCHEMA'] == schema_name, table_columns_cache))
+        # information_schema_columns is an optional pre-collected list of available objects in snowflake
+        if self.information_schema_columns:
+            schema_rows = list(filter(lambda x: x['TABLE_SCHEMA'] == schema_name, self.information_schema_columns))
         # Query realtime if not pre-collected
         else:
             schema_rows = self.query(
@@ -510,25 +560,37 @@ def get_tables(self, table_schema=None):
             "LOWER(table_schema)" if table_schema is None else "'{}'".format(table_schema.lower())
         ))

-    def get_table_columns(self, table_schema=None, table_name=None, filter_schemas=None):
+    def get_table_columns(self, table_schema=None, table_name=None, filter_schemas=None, from_information_schema_cache_table=False):
+        if from_information_schema_cache_table:
+            self.cache_information_schema_columns(create_only=True)
+
+        # Select columns
         sql = """SELECT LOWER(c.table_schema) table_schema, LOWER(c.table_name) table_name, c.column_name, c.data_type
-              FROM information_schema.columns c
-              WHERE 1=1"""
+              FROM {}.columns c
+              WHERE 1 = 1""".format("information_schema" if not from_information_schema_cache_table else self.pipelinewise_schema)
         if table_schema is not None:
             sql = sql + " AND LOWER(c.table_schema) = '" + table_schema.lower() + "'"
         if table_name is not None:
             sql = sql + " AND LOWER(c.table_name) = '" + table_name.lower() + "'"
         if filter_schemas is not None:
             sql = sql + " AND LOWER(c.table_schema) IN (" + ', '.join("'{}'".format(s).lower() for s in filter_schemas) + ")"
-        return self.query(sql)
+        table_columns = self.query(sql)

-    def update_columns(self, table_columns_cache=None):
+        # Refresh cached information_schema if no results
+        if from_information_schema_cache_table and len(table_columns) == 0:
+            self.cache_information_schema_columns()
+            table_columns = self.query(sql)
+
+        # Get columns from cache or information_schema and return results
+        return table_columns
+
+    def update_columns(self):
         stream_schema_message = self.stream_schema_message
         stream = stream_schema_message['stream']
         table_name = self.table_name(stream, False, True)
         schema_name = self.schema_name
         columns = []
-        if table_columns_cache:
-            columns = list(filter(lambda x: x['TABLE_SCHEMA'] == self.schema_name.lower() and x['TABLE_NAME'].lower() == table_name, table_columns_cache))
+        if self.information_schema_columns:
+            columns = list(filter(lambda x: x['TABLE_SCHEMA'] == self.schema_name.lower() and x['TABLE_NAME'].lower() == table_name, self.information_schema_columns))
         else:
-            columns = self.get_table_columns(schema_name, table_name)
+            columns = self.get_table_columns(table_schema=schema_name, table_name=table_name)
         columns_dict = {column['COLUMN_NAME'].lower(): column for column in columns}
         columns_to_add = [
@@ -570,6 +632,10 @@ def update_columns(self, table_columns_cache=None):
             self.version_column(column_name, stream)
             self.add_column(column, stream)

+        # Refresh columns cache if required
+        if self.information_schema_columns is not None and (len(columns_to_add) > 0 or len(columns_to_replace)):
+            self.cache_information_schema_columns()
+
     def drop_column(self, column_name, stream):
         drop_column = "ALTER TABLE {} DROP COLUMN {}".format(self.table_name(stream, False), column_name)
         logger.info('Dropping column: {}'.format(drop_column))
@@ -585,15 +651,15 @@ def add_column(self, column, stream):
         logger.info('Adding column: {}'.format(add_column))
         self.query(add_column)

-    def sync_table(self, table_columns_cache=None):
+    def sync_table(self):
         stream_schema_message = self.stream_schema_message
         stream = stream_schema_message['stream']
         table_name = self.table_name(stream, False, True)
         table_name_with_schema = self.table_name(stream, False)
         found_tables = []
-        if table_columns_cache:
-            found_tables = list(filter(lambda x: x['TABLE_SCHEMA'] == self.schema_name.lower() and x['TABLE_NAME'].lower() == table_name, table_columns_cache))
+        if self.information_schema_columns:
+            found_tables = list(filter(lambda x: x['TABLE_SCHEMA'] == self.schema_name.lower() and x['TABLE_NAME'].lower() == table_name, self.information_schema_columns))
         else:
             found_tables = [table for table in (self.get_tables(self.schema_name.lower())) if table['TABLE_NAME'].lower() == table_name]
@@ -603,7 +669,11 @@ def sync_table(self, table_columns_cache=None):
             self.query(query)
             self.grant_privilege(self.schema_name, self.grantees, self.grant_select_on_all_tables_in_schema)
+
+            # Refresh columns cache if required
+            if self.information_schema_columns is not None:
+                self.cache_information_schema_columns()
         else:
             logger.info("Table '{}' exists".format(table_name_with_schema))
-            self.update_columns(table_columns_cache)
+            self.update_columns()
diff --git a/tests/integration/test_target_snowflake.py b/tests/integration/test_target_snowflake.py
index 8f84e4d7..3e3a087a 100644
--- a/tests/integration/test_target_snowflake.py
+++ b/tests/integration/test_target_snowflake.py
@@ -31,9 +31,15 @@ class TestIntegration(unittest.TestCase):
     def setUp(self):
         self.config = test_utils.get_test_config()
         snowflake = DbSync(self.config)
+
+        # Drop target schema
         if self.config['default_target_schema']:
             snowflake.query("DROP SCHEMA IF EXISTS {}".format(self.config['default_target_schema']))

+        # Drop pipelinewise information_schema cache table
+        if self.config['stage']:
+            snowflake.query("DROP TABLE IF EXISTS {}.columns".format(snowflake.pipelinewise_schema))
+
     def remove_metadata_columns_from_rows(self, rows):
         """Removes metadata columns from a list of rows"""
@@ -395,3 +401,87 @@ def test_column_name_change(self):
                 {'C_INT': 3, 'C_PK': 3, 'C_TIME': datetime.time(23, 0, 3), 'C_VARCHAR': '3', 'C_TIME_RENAMED': datetime.time(8, 15)},
                 {'C_INT': 4, 'C_PK': 4, 'C_TIME': None, 'C_VARCHAR': '4', 'C_TIME_RENAMED': datetime.time(23, 0, 3)}
             ])
+
+
+    def test_information_schema_cache_create_and_update(self):
+        """Newly created and altered tables must be cached automatically for later use.
+
+        The information_schema_columns cache is a copy of the snowflake INFORMATION_SCHEMA.COLUMNS table, kept to avoid the error
+        'Information schema query returned too much data. Please repeat query with more selective predicates.'.
+        """
+        tap_lines_before_column_name_change = test_utils.get_test_tap_lines('messages-with-three-streams.json')
+        tap_lines_after_column_name_change = test_utils.get_test_tap_lines('messages-with-three-streams-modified-column.json')
+
+        # Load with default settings
+        target_snowflake.persist_lines(self.config, tap_lines_before_column_name_change)
+        target_snowflake.persist_lines(self.config, tap_lines_after_column_name_change)
+
+        # Get data from the information_schema cache table
+        snowflake = DbSync(self.config)
+        target_schema = self.config.get('default_target_schema', '')
+        information_schema_cache = snowflake.query("SELECT * FROM {}.columns ORDER BY table_schema, table_name, column_name".format(snowflake.pipelinewise_schema))
+
+        # Get the previous column name from information schema in test_table_two
+        previous_column_name = snowflake.query("""
+            SELECT column_name
+            FROM information_schema.columns
+            WHERE table_catalog = '{}'
+                AND table_schema = '{}'
+                AND table_name = 'TEST_TABLE_TWO'
+                AND ordinal_position = 1
+            """.format(
+                self.config.get('dbname', '').upper(),
+                target_schema.upper()))[0]["COLUMN_NAME"]
+
+        # Every column has to be in the cached information_schema with the latest versions
+        self.assertEqual(
+            information_schema_cache,
+            [
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_ONE', 'COLUMN_NAME': 'C_INT', 'DATA_TYPE': 'NUMBER'},
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_ONE', 'COLUMN_NAME': 'C_PK', 'DATA_TYPE': 'NUMBER'},
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_ONE', 'COLUMN_NAME': 'C_VARCHAR', 'DATA_TYPE': 'TEXT'},
+
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_THREE', 'COLUMN_NAME': 'C_INT', 'DATA_TYPE': 'NUMBER'},
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_THREE', 'COLUMN_NAME': 'C_PK', 'DATA_TYPE': 'NUMBER'},
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_THREE', 'COLUMN_NAME': 'C_TIME', 'DATA_TYPE': 'TIME'},
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_THREE', 'COLUMN_NAME': 'C_TIME_RENAMED', 'DATA_TYPE': 'TIME'},
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_THREE', 'COLUMN_NAME': 'C_VARCHAR', 'DATA_TYPE': 'TEXT'},
+
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_TWO', 'COLUMN_NAME': 'C_DATE', 'DATA_TYPE': 'TEXT'},
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_TWO', 'COLUMN_NAME': previous_column_name, 'DATA_TYPE': 'TIMESTAMP_NTZ'},
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_TWO', 'COLUMN_NAME': 'C_INT', 'DATA_TYPE': 'NUMBER'},
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_TWO', 'COLUMN_NAME': 'C_PK', 'DATA_TYPE': 'NUMBER'},
+                {'TABLE_SCHEMA': 'LOCAL_DEV1', 'TABLE_NAME': 'TEST_TABLE_TWO', 'COLUMN_NAME': 'C_VARCHAR', 'DATA_TYPE': 'TEXT'}
+            ])
+
+
+    def test_information_schema_cache_outdate(self):
+        """If the information_schema cache is not up to date then loading should fail"""
+        tap_lines_with_multi_streams = test_utils.get_test_tap_lines('messages-with-three-streams.json')
+
+        # 1) Simulate an out of date cache:
+        #    Table is in the cache but does not exist in the database
+        snowflake = DbSync(self.config)
+        snowflake.query("""
+            CREATE TABLE IF NOT EXISTS {}.columns (table_schema VARCHAR, table_name VARCHAR, column_name VARCHAR, data_type VARCHAR)
+        """.format(snowflake.pipelinewise_schema))
+        snowflake.query("""
+            INSERT INTO {}.columns (table_schema, table_name, column_name, data_type)
+            SELECT 'LOCAL_DEV1', 'TEST_TABLE_ONE', 'DUMMY_COLUMN_1', 'TEXT' UNION
+            SELECT 'LOCAL_DEV1', 'TEST_TABLE_ONE', 'DUMMY_COLUMN_2', 'TEXT' UNION
+            SELECT 'LOCAL_DEV1', 'TEST_TABLE_TWO', 'DUMMY_COLUMN_3', 'TEXT'
+        """.format(snowflake.pipelinewise_schema))
+
+        # Loading with an outdated information_schema cache should fail because the table does not exist
+        with self.assertRaises(Exception):
+            target_snowflake.persist_lines(self.config, tap_lines_with_multi_streams)
+
+        # 2) Simulate an out of date cache:
+        #    Table is in the cache but its structure is not in sync with the actual table in the database
+        snowflake.query("CREATE SCHEMA IF NOT EXISTS local_dev1")
+        snowflake.query("CREATE OR REPLACE TABLE local_dev1.test_table_one (C_PK NUMBER, C_INT NUMBER, C_VARCHAR TEXT)")
+
+        # Loading with an outdated information_schema cache should fail because the column already exists
+        # It should try to add the new column based on the values in the cache, but the column already exists
+        with self.assertRaises(Exception):
+            target_snowflake.persist_lines(self.config, tap_lines_with_multi_streams)
\ No newline at end of file
diff --git a/tests/unit/test_unit.py b/tests/unit/test_unit.py
index 2838984f..b2c248db 100644
--- a/tests/unit/test_unit.py
+++ b/tests/unit/test_unit.py
@@ -89,6 +89,39 @@ def test_column_type_mapping(self):
         self.assertEquals(mapper(json_arr) , 'variant')


+    def test_stream_name_to_dict(self):
+        """Test identifying catalog, schema and table names from fully qualified stream and table names"""
+        # Singer stream name format (Default '-' separator)
+        self.assertEquals(
+            target_snowflake.db_sync.stream_name_to_dict('my_table'),
+            {"catalog_name": None, "schema_name": None, "table_name": "my_table"})
+
+        # Singer stream name format (Default '-' separator)
+        self.assertEquals(
+            target_snowflake.db_sync.stream_name_to_dict('my_schema-my_table'),
+            {"catalog_name": None, "schema_name": "my_schema", "table_name": "my_table"})
+
+        # Singer stream name format (Default '-' separator)
+        self.assertEquals(
+            target_snowflake.db_sync.stream_name_to_dict('my_catalog-my_schema-my_table'),
+            {"catalog_name": "my_catalog", "schema_name": "my_schema", "table_name": "my_table"})
+
+        # Snowflake table format (Custom '.' separator)
+        self.assertEquals(
+            target_snowflake.db_sync.stream_name_to_dict('my_table', separator='.'),
+            {"catalog_name": None, "schema_name": None, "table_name": "my_table"})
+
+        # Snowflake table format (Custom '.' separator)
+        self.assertEquals(
+            target_snowflake.db_sync.stream_name_to_dict('my_schema.my_table', separator='.'),
+            {"catalog_name": None, "schema_name": "my_schema", "table_name": "my_table"})
+
+        # Snowflake table format (Custom '.' separator)
+        self.assertEquals(
+            target_snowflake.db_sync.stream_name_to_dict('my_catalog.my_schema.my_table', separator='.'),
+            {"catalog_name": "my_catalog", "schema_name": "my_schema", "table_name": "my_table"})
+
+
     def test_flatten_schema(self):
         """Test flattening of SCHEMA messages"""
         flatten_schema = target_snowflake.db_sync.flatten_schema