diff --git a/setup.py b/setup.py
index 31af992a..27bea0c7 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
     long_description = f.read()
 
 setup(name="pipelinewise-target-snowflake",
-      version="1.0.5",
+      version="1.0.6",
       description="Singer.io target for loading data to Snowflake - PipelineWise compatible",
       long_description=long_description,
       long_description_content_type='text/markdown',
diff --git a/target_snowflake/__init__.py b/target_snowflake/__init__.py
index 3a194466..2c09a69a 100644
--- a/target_snowflake/__init__.py
+++ b/target_snowflake/__init__.py
@@ -79,8 +79,20 @@ def get_schema_names_from_config(config):
 
     return schema_names
 
+
+def load_information_schema_cache(config):
+    information_schema_cache = []
+    if not ('disable_table_cache' in config and config['disable_table_cache'] == True):
+        logger.info("Getting catalog objects from information_schema cache table...")
+
+        db = DbSync(config)
+        information_schema_cache = db.get_table_columns(
+            table_schemas=get_schema_names_from_config(config),
+            from_information_schema_cache_table=True)
+
+    return information_schema_cache
+
 # pylint: disable=too-many-locals,too-many-branches,too-many-statements
-def persist_lines(config, lines):
+def persist_lines(config, lines, information_schema_cache=None):
     state = None
     schemas = {}
     key_properties = {}
@@ -177,9 +189,9 @@ def persist_lines(config, lines):
             key_properties[stream] = o['key_properties']
 
             if config.get('add_metadata_columns') or config.get('hard_delete'):
-                stream_to_sync[stream] = DbSync(config, add_metadata_columns_to_schema(o))
+                stream_to_sync[stream] = DbSync(config, add_metadata_columns_to_schema(o), information_schema_cache)
             else:
-                stream_to_sync[stream] = DbSync(config, o)
+                stream_to_sync[stream] = DbSync(config, o, information_schema_cache)
 
             try:
                 stream_to_sync[stream].create_schema_if_not_exists()
@@ -257,8 +269,12 @@ def main():
     else:
        config = {}
 
+    # Init information schema cache
+    information_schema_cache = load_information_schema_cache(config)
+
+    # Consume singer messages
     input = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
-    state = persist_lines(config, input)
+    state = persist_lines(config, input, information_schema_cache)
 
     emit_state(state)
     logger.debug("Exiting normally")
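The net effect of the `__init__.py` changes is that the cache is primed once in `main()` and then shared by every stream's `DbSync`, instead of each constructor querying the cache table on its own. A minimal sketch of that flow, with a hypothetical `FakeDbSync` stub standing in for a real Snowflake connection (the stub and its canned rows are illustration only, not part of the patch):

```python
class FakeDbSync:
    """Made-up stub for db_sync.DbSync; counts cache lookups."""
    calls = 0

    def get_table_columns(self, table_schemas, from_information_schema_cache_table=False):
        FakeDbSync.calls += 1
        return [{'TABLE_SCHEMA': s, 'TABLE_NAME': 'orders',
                 'COLUMN_NAME': 'ID', 'DATA_TYPE': 'NUMBER'}
                for s in table_schemas]

def load_cache(config):
    # Mirrors load_information_schema_cache: disabled cache -> empty list
    if config.get('disable_table_cache'):
        return []
    return FakeDbSync().get_table_columns(
        ['tap_mysql'], from_information_schema_cache_table=True)

cache = load_cache({})
# Every stream's DbSync receives the same list; no per-stream re-query.
streams = {s: cache for s in ('orders', 'items', 'customers')}
assert FakeDbSync.calls == 1
```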
diff --git a/target_snowflake/db_sync.py b/target_snowflake/db_sync.py
index d8c785a6..7aba1fdf 100644
--- a/target_snowflake/db_sync.py
+++ b/target_snowflake/db_sync.py
@@ -176,7 +176,7 @@ def stream_name_to_dict(stream_name, separator='-'):
 
 # pylint: disable=too-many-public-methods,too-many-instance-attributes
 class DbSync:
-    def __init__(self, connection_config, stream_schema_message=None):
+    def __init__(self, connection_config, stream_schema_message=None, information_schema_cache=None):
         """
             connection_config:      Snowflake connection details
 
@@ -196,6 +196,10 @@ def __init__(self, connection_config, stream_schema_message=None):
                                     purposes.
         """
         self.connection_config = connection_config
+        self.stream_schema_message = stream_schema_message
+        self.information_schema_columns = information_schema_cache
+
+        # Validate connection configuration
         config_errors = validate_config(connection_config)
 
         # Exit if config has errors
@@ -213,8 +217,9 @@ def __init__(self, connection_config, stream_schema_message=None):
         self.schema_name = None
         self.grantees = None
-        self.information_schema_columns = None
 
-        if stream_schema_message is not None:
+
+        # Init stream schema
+        if self.stream_schema_message is not None:
             # Define target schema name.
             # --------------------------
             # Target schema name can be defined in multiple ways:
@@ -263,14 +268,6 @@ def __init__(self, connection_config, stream_schema_message=None):
             if config_schema_mapping and stream_schema_name in config_schema_mapping:
                 self.grantees = config_schema_mapping[stream_schema_name].get('target_schema_select_permissions', self.grantees)
 
-            # Caching enabled: get the list of available columns from auto maintained cache table
-            if not ('disable_table_cache' in self.connection_config and self.connection_config['disable_table_cache'] == True):
-                logger.info("Getting catalog objects from information_schema cache table...")
-                self.information_schema_columns = self.get_table_columns(table_schema=self.schema_name, from_information_schema_cache_table=True)
-
-        self.stream_schema_message = stream_schema_message
-
-        if stream_schema_message is not None:
+
             self.data_flattening_max_level = self.connection_config.get('data_flattening_max_level', 0)
             self.flatten_schema = flatten_schema(stream_schema_message['schema'], max_level=self.data_flattening_max_level)
 
@@ -390,7 +387,7 @@ def delete_from_stage(self, s3_key):
         self.s3.delete_object(Bucket=bucket, Key=s3_key)
 
-    def cache_information_schema_columns(self, create_only=False):
+    def cache_information_schema_columns(self, table_schemas=[], create_only=False):
         """Information_schema_columns cache is a copy of snowflake INFORMATION_SCHEMA.COLUMNS
         table to avoid the error of 'Information schema query returned too much data. Please repeat query with more selective predicates.'.
@@ -400,24 +397,29 @@ def cache_information_schema_columns(self, create_only=False):
         """
 
         # Create empty columns cache table if not exists
+        self.query("""
+            CREATE SCHEMA IF NOT EXISTS {}
+        """.format(self.pipelinewise_schema))
         self.query("""
             CREATE TABLE IF NOT EXISTS {}.columns (table_schema VARCHAR, table_name VARCHAR, column_name VARCHAR, data_type VARCHAR)
         """.format(self.pipelinewise_schema))
 
-        if not create_only:
+        if not create_only and table_schemas:
             # Delete existing data about the current schema
-            self.query("""
+            sql = """
                 DELETE FROM {}.columns
-                WHERE LOWER(table_schema) = '{}'
-            """.format(self.pipelinewise_schema, self.schema_name.lower()))
+            """.format(self.pipelinewise_schema)
+            sql = sql + " WHERE LOWER(table_schema) IN ({})".format(', '.join("'{}'".format(s).lower() for s in table_schemas))
+            self.query(sql)
 
             # Insert the latest data from information_schema into the cache table
-            self.query("""
+            sql = """
                 INSERT INTO {}.columns
                 SELECT table_schema, table_name, column_name, data_type
                 FROM information_schema.columns
-                WHERE LOWER(table_schema) = '{}'
-            """.format(self.pipelinewise_schema, self.schema_name.lower()))
+            """.format(self.pipelinewise_schema)
+            sql = sql + " WHERE LOWER(table_schema) IN ({})".format(', '.join("'{}'".format(s).lower() for s in table_schemas))
+            self.query(sql)
 
     def load_csv(self, s3_key, count):
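The `WHERE` clause above is now built from a list of schemas rather than a single schema name. The snippet below shows, standalone, what that `join` expression produces; note that `.lower()` is applied to the already-quoted token, so it lowercases the schema name inside the quotes:

```python
table_schemas = ['TAP_MYSQL', 'tap_postgres']
where = " WHERE LOWER(table_schema) IN ({})".format(
    ', '.join("'{}'".format(s).lower() for s in table_schemas))
print(where)
# -> WHERE LOWER(table_schema) IN ('tap_mysql', 'tap_postgres')
```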
@@ -560,22 +562,27 @@ def get_tables(self, table_schema=None):
             "LOWER(table_schema)" if table_schema is None else "'{}'".format(table_schema.lower())
         ))
 
-    def get_table_columns(self, table_schema=None, table_name=None, filter_schemas=None, from_information_schema_cache_table=False):
+    def get_table_columns(self, table_schemas=[], table_name=None, from_information_schema_cache_table=False):
         if from_information_schema_cache_table:
             self.cache_information_schema_columns(create_only=True)
 
         # Select columns
-        sql = """SELECT LOWER(c.table_schema) table_schema, LOWER(c.table_name) table_name, c.column_name, c.data_type
-            FROM {}.columns c
-            WHERE 1 = 1""".format("information_schema" if not from_information_schema_cache_table else self.pipelinewise_schema)
-        if table_schema is not None: sql = sql + " AND LOWER(c.table_schema) = '" + table_schema.lower() + "'"
-        if table_name is not None: sql = sql + " AND LOWER(c.table_name) = '" + table_name.lower() + "'"
-        if filter_schemas is not None: sql = sql + " AND LOWER(c.table_schema) IN (" + ', '.join("'{}'".format(s).lower() for s in filter_schemas) + ")"
-        table_columns = self.query(sql)
+        table_columns = []
+        if table_schemas or table_name:
+            sql = """SELECT LOWER(c.table_schema) table_schema, LOWER(c.table_name) table_name, c.column_name, c.data_type
+                FROM {}.columns c
+            """.format("information_schema" if not from_information_schema_cache_table else self.pipelinewise_schema)
+            if table_schemas:
+                sql = sql + " WHERE LOWER(c.table_schema) IN ({})".format(', '.join("'{}'".format(s).lower() for s in table_schemas))
+            elif table_name:
+                sql = sql + " WHERE LOWER(c.table_name) = '{}'".format(table_name.lower())
+            table_columns = self.query(sql)
+        else:
+            raise Exception("Cannot get table columns. List of table schemas is empty")
 
         # Refresh cached information_schema if no results
         if from_information_schema_cache_table and len(table_columns) == 0:
-            self.cache_information_schema_columns()
+            self.cache_information_schema_columns(table_schemas=table_schemas)
             table_columns = self.query(sql)
 
         # Get columns from cache or information_schema and return results
@@ -585,12 +592,11 @@ def update_columns(self):
         stream_schema_message = self.stream_schema_message
         stream = stream_schema_message['stream']
         table_name = self.table_name(stream, False, True)
-        schema_name = self.schema_name
         columns = []
         if self.information_schema_columns:
             columns = list(filter(lambda x: x['TABLE_SCHEMA'] == self.schema_name.lower() and x['TABLE_NAME'].lower() == table_name, self.information_schema_columns))
         else:
-            columns = self.get_table_columns(table_schema=schema_name, table_name=table_name)
+            columns = self.get_table_columns(table_schemas=[self.schema_name], table_name=table_name)
         columns_dict = {column['COLUMN_NAME'].lower(): column for column in columns}
 
         columns_to_add = [
@@ -633,8 +639,8 @@ def update_columns(self):
             self.add_column(column, stream)
 
         # Refresh columns cache if required
-        if self.information_schema_columns is not None and (len(columns_to_add) > 0 or len(columns_to_replace)):
-            self.cache_information_schema_columns()
+        if self.information_schema_columns and (len(columns_to_add) > 0 or len(columns_to_replace)):
+            self.cache_information_schema_columns(table_schemas=[self.schema_name])
 
     def drop_column(self, column_name, stream):
         drop_column = "ALTER TABLE {} DROP COLUMN {}".format(self.table_name(stream, False), column_name)
@@ -671,8 +677,8 @@ def sync_table(self):
                 self.grant_privilege(self.schema_name, self.grantees, self.grant_select_on_all_tables_in_schema)
 
             # Refresh columns cache if required
-            if self.information_schema_columns is not None:
-                self.cache_information_schema_columns()
+            if self.information_schema_columns:
+                self.cache_information_schema_columns(table_schemas=[self.schema_name])
         else:
             logger.info("Table '{}' exists".format(table_name_with_schema))
             self.update_columns()
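Two patterns are worth calling out in the `db_sync.py` changes: `get_table_columns` now does a read-through refresh (an empty result from the cache table triggers one rebuild from `INFORMATION_SCHEMA`, then a single retry), and `update_columns` filters the shared in-memory cache instead of issuing a query per stream. A hedged sketch of that in-memory filter, with made-up rows in the lower-cased shape the cache query selects:

```python
# Illustrative rows only; real entries come from pipelinewise.columns.
cache = [
    {'TABLE_SCHEMA': 'tap_mysql', 'TABLE_NAME': 'orders', 'COLUMN_NAME': 'ID',  'DATA_TYPE': 'NUMBER'},
    {'TABLE_SCHEMA': 'tap_mysql', 'TABLE_NAME': 'items',  'COLUMN_NAME': 'SKU', 'DATA_TYPE': 'TEXT'},
]

schema_name, table_name = 'tap_mysql', 'orders'
columns = [c for c in cache
           if c['TABLE_SCHEMA'] == schema_name.lower()
           and c['TABLE_NAME'].lower() == table_name]
assert [c['COLUMN_NAME'] for c in columns] == ['ID']
```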
diff --git a/tests/integration/test_target_snowflake.py b/tests/integration/test_target_snowflake.py
index 3e3a087a..20b6891a 100644
--- a/tests/integration/test_target_snowflake.py
+++ b/tests/integration/test_target_snowflake.py
@@ -41,6 +41,21 @@ def setUp(self):
             snowflake.query("DROP TABLE IF EXISTS {}.columns".format(snowflake.pipelinewise_schema))
 
+    def persist_lines_with_cache(self, lines):
+        """Enables table caching option and loads singer messages into snowflake.
+
+        The table caching mechanism creates and maintains an extra table in snowflake about
+        the table structures. It's very similar to the INFORMATION_SCHEMA.COLUMNS system view,
+        but querying INFORMATION_SCHEMA is slow, especially when lots of taps are running
+        in parallel.
+
+        Selecting from a real table instead of INFORMATION_SCHEMA and keeping it
+        in memory while target-snowflake is running results in better load performance.
+        """
+        information_schema_cache = target_snowflake.load_information_schema_cache(self.config)
+        target_snowflake.persist_lines(self.config, lines, information_schema_cache)
+
     def remove_metadata_columns_from_rows(self, rows):
         """Removes metadata columns from a list of rows"""
         d_rows = []
@@ -158,14 +173,14 @@ def test_invalid_json(self):
         """Receiving invalid JSONs should raise an exception"""
         tap_lines = test_utils.get_test_tap_lines('invalid-json.json')
         with assert_raises(json.decoder.JSONDecodeError):
-            target_snowflake.persist_lines(self.config, tap_lines)
+            self.persist_lines_with_cache(tap_lines)
 
     def test_message_order(self):
         """RECORD message without a previously received SCHEMA message should raise an exception"""
         tap_lines = test_utils.get_test_tap_lines('invalid-message-order.json')
         with assert_raises(Exception):
-            target_snowflake.persist_lines(self.config, tap_lines)
+            self.persist_lines_with_cache(tap_lines)
 
     def test_loading_tables_with_no_encryption(self):
         """Loading multiple tables from the same input tap with various column types"""
         tap_lines = test_utils.get_test_tap_lines('messages-with-three-streams.json')
 
         # Turning off client-side encryption and load
         self.config['client_side_encryption_master_key'] = ''
-        target_snowflake.persist_lines(self.config, tap_lines)
+        self.persist_lines_with_cache(tap_lines)
 
         self.assert_three_streams_are_into_snowflake()
@@ -185,7 +200,7 @@ def test_loading_tables_with_client_side_encryption(self):
         tap_lines = test_utils.get_test_tap_lines('messages-with-three-streams.json')
 
         # Turning on client-side encryption and load
         self.config['client_side_encryption_master_key'] = os.environ.get('CLIENT_SIDE_ENCRYPTION_MASTER_KEY')
-        target_snowflake.persist_lines(self.config, tap_lines)
+        self.persist_lines_with_cache(tap_lines)
 
         self.assert_three_streams_are_into_snowflake()
@@ -197,7 +212,7 @@ def test_loading_tables_with_client_side_encryption_and_wrong_master_key(self):
         # Turning on client-side encryption and load but using a well formatted but wrong master key
         self.config['client_side_encryption_master_key'] = "Wr0n6m45t3rKeY0123456789a0123456789a0123456="
         with assert_raises(snowflake.connector.errors.ProgrammingError):
-            target_snowflake.persist_lines(self.config, tap_lines)
+            self.persist_lines_with_cache(tap_lines)
@@ -206,7 +221,7 @@ def test_loading_tables_with_metadata_columns(self):
         tap_lines = test_utils.get_test_tap_lines('messages-with-three-streams.json')
 
         # Turning on adding metadata columns
         self.config['add_metadata_columns'] = True
-        target_snowflake.persist_lines(self.config, tap_lines)
+        self.persist_lines_with_cache(tap_lines)
 
         # Check if data loaded correctly and metadata columns exist
         self.assert_three_streams_are_into_snowflake(should_metadata_columns_exist=True)
@@ -218,7 +233,7 @@ def test_loading_tables_with_hard_delete(self):
         tap_lines = test_utils.get_test_tap_lines('messages-with-three-streams.json')
 
         # Turning on hard delete mode
         self.config['hard_delete'] = True
-        target_snowflake.persist_lines(self.config, tap_lines)
+        self.persist_lines_with_cache(tap_lines)
 
         # Check if data loaded correctly and metadata columns exist
         self.assert_three_streams_are_into_snowflake(
@@ -232,7 +247,7 @@ def test_loading_with_multiple_schema(self):
         tap_lines = test_utils.get_test_tap_lines('messages-with-multi-schemas.json')
 
         # Load with default settings
-        target_snowflake.persist_lines(self.config, tap_lines)
+        self.persist_lines_with_cache(tap_lines)
 
         # Check if data loaded correctly
         self.assert_three_streams_are_into_snowflake(
@@ -246,7 +261,7 @@ def test_loading_unicode_characters(self):
         tap_lines = test_utils.get_test_tap_lines('messages-with-unicode-characters.json')
 
         # Load with default settings
-        target_snowflake.persist_lines(self.config, tap_lines)
+        self.persist_lines_with_cache(tap_lines)
 
         # Get loaded rows from tables
         snowflake = DbSync(self.config)
@@ -270,7 +285,7 @@ def test_non_db_friendly_columns(self):
         tap_lines = test_utils.get_test_tap_lines('messages-with-non-db-friendly-columns.json')
 
         # Load with default settings
-        target_snowflake.persist_lines(self.config, tap_lines)
+        self.persist_lines_with_cache(tap_lines)
 
         # Get loaded rows from tables
         snowflake = DbSync(self.config)
@@ -293,7 +308,7 @@ def test_nested_schema_unflattening(self):
         tap_lines = test_utils.get_test_tap_lines('messages-with-nested-schema.json')
 
         # Load with default settings - Flattening disabled
-        target_snowflake.persist_lines(self.config, tap_lines)
+        self.persist_lines_with_cache(tap_lines)
 
         # Get loaded rows from tables - Transform JSON to string at query time
         snowflake = DbSync(self.config)
@@ -327,7 +342,7 @@ def test_nested_schema_flattening(self):
         self.config['data_flattening_max_level'] = 10
 
         # Load with default settings - Flattening disabled
-        target_snowflake.persist_lines(self.config, tap_lines)
+        self.persist_lines_with_cache(tap_lines)
 
         # Get loaded rows from tables
         snowflake = DbSync(self.config)
@@ -355,8 +370,8 @@ def test_column_name_change(self):
         tap_lines_after_column_name_change = test_utils.get_test_tap_lines('messages-with-three-streams-modified-column.json')
 
         # Load with default settings
-        target_snowflake.persist_lines(self.config, tap_lines_before_column_name_change)
-        target_snowflake.persist_lines(self.config, tap_lines_after_column_name_change)
+        self.persist_lines_with_cache(tap_lines_before_column_name_change)
+        self.persist_lines_with_cache(tap_lines_after_column_name_change)
 
         # Get loaded rows from tables
         snowflake = DbSync(self.config)
@@ -413,8 +428,8 @@ def test_information_schema_cache_create_and_update(self):
         tap_lines_after_column_name_change = test_utils.get_test_tap_lines('messages-with-three-streams-modified-column.json')
 
         # Load with default settings
-        target_snowflake.persist_lines(self.config, tap_lines_before_column_name_change)
-        target_snowflake.persist_lines(self.config, tap_lines_after_column_name_change)
+        self.persist_lines_with_cache(tap_lines_before_column_name_change)
+        self.persist_lines_with_cache(tap_lines_after_column_name_change)
 
         # Get data from information_schema cache table
         snowflake = DbSync(self.config)
@@ -474,7 +489,7 @@ def test_information_schema_cache_outdate(self):
 
         # Loading into an outdated information_schema cache should fail with table not exists
         with self.assertRaises(Exception):
-            target_snowflake.persist_lines(self.config, tap_lines_with_multi_streams)
+            self.persist_lines_with_cache(tap_lines_with_multi_streams)
 
         # 2) Simulate an out of date cache:
         # Table is in the cache but its structure is not in sync with the actual table in the database
@@ -484,4 +499,4 @@
         # Loading into an outdated information_schema cache should fail with column exists
         # It should try adding the new column based on the values in cache but the column already exists
         with self.assertRaises(Exception):
-            target_snowflake.persist_lines(self.config, tap_lines_with_multi_streams)
\ No newline at end of file
+            self.persist_lines_with_cache(tap_lines_with_multi_streams)
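For reference, the pattern the updated tests follow, lifted into a standalone helper; `config` and `lines` here are whatever the integration test utilities provide, and this sketch simply mirrors the two-step call order the patch introduces:

```python
import target_snowflake

def run_with_cache(config, lines):
    """Prime the information_schema cache once, then load the singer messages."""
    cache = target_snowflake.load_information_schema_cache(config)
    return target_snowflake.persist_lines(config, lines, cache)

# To exercise the uncached path instead, set:
# config['disable_table_cache'] = True
# load_information_schema_cache then returns [], and every column lookup
# falls back to querying INFORMATION_SCHEMA directly.
```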