diff --git a/target_snowflake/__init__.py b/target_snowflake/__init__.py index 8c50987f..cbf7abae 100644 --- a/target_snowflake/__init__.py +++ b/target_snowflake/__init__.py @@ -60,8 +60,15 @@ def emit_state(state): sys.stdout.flush() -def load_table_cache(config): - """Load table cache from snowflake metadata""" +def get_snowflake_statics(config): + """Retrieve common Snowflake items that will be used multiple times + + Params: + config: configuration dictionary + + Returns: + tuple of retrieved items: table_cache, file_format_type + """ table_cache = [] if not ('disable_table_cache' in config and config['disable_table_cache']): LOGGER.info('Getting catalog objects from table cache...') @@ -70,12 +77,29 @@ def load_table_cache(config): table_cache = db.get_table_columns( table_schemas=stream_utils.get_schema_names_from_config(config)) - return table_cache + # The file format is detected at DbSync init time + file_format_type = db.file_format.file_format_type + return table_cache, file_format_type # pylint: disable=too-many-locals,too-many-branches,too-many-statements,invalid-name -def persist_lines(config, lines, table_cache=None) -> None: - """Main loop to read and consume singer messages from stdin""" +def persist_lines(config, lines, table_cache=None, file_format_type: FileFormatTypes=None) -> None: + """Main loop to read and consume singer messages from stdin + + Params: + config: configuration dictionary + lines: iterable of singer messages + table_cache: Optional dictionary of Snowflake table structures. This is useful to run as few + INFORMATION_SCHEMA and SHOW queries as possible. + If not provided then an SQL query will be generated at runtime to + get all the required information from Snowflake + file_format_type: Optional FileFormatTypes value that defines which supported file format to use + to load data into Snowflake. 
+ If not provided then it will be detected automatically + + Returns: + None + """ state = None flushed_state = None schemas = {} @@ -210,9 +234,12 @@ def persist_lines(config, lines, table_cache=None) -> None: key_properties[stream] = o['key_properties'] if config.get('add_metadata_columns') or config.get('hard_delete'): - stream_to_sync[stream] = DbSync(config, add_metadata_columns_to_schema(o), table_cache) + stream_to_sync[stream] = DbSync(config, + add_metadata_columns_to_schema(o), + table_cache, + file_format_type) else: - stream_to_sync[stream] = DbSync(config, o, table_cache) + stream_to_sync[stream] = DbSync(config, o, table_cache, file_format_type) stream_to_sync[stream].create_schema_if_not_exists() stream_to_sync[stream].sync_table() @@ -388,11 +415,11 @@ def main(): config = {} # Init columns cache - table_cache = load_table_cache(config) + table_cache, file_format_type = get_snowflake_statics(config) # Consume singer messages singer_messages = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') - persist_lines(config, singer_messages, table_cache) + persist_lines(config, singer_messages, table_cache, file_format_type) LOGGER.debug("Exiting normally") diff --git a/target_snowflake/db_sync.py b/target_snowflake/db_sync.py index e5f2aad9..4689785d 100644 --- a/target_snowflake/db_sync.py +++ b/target_snowflake/db_sync.py @@ -162,7 +162,7 @@ def create_query_tag(query_tag_pattern: str, database: str = None, schema: str = class DbSync: """DbSync class""" - def __init__(self, connection_config, stream_schema_message=None, table_cache=None): + def __init__(self, connection_config, stream_schema_message=None, table_cache=None, file_format_type=None): """ connection_config: Snowflake connection details @@ -205,7 +205,7 @@ def __init__(self, connection_config, stream_schema_message=None, table_cache=No self.schema_name = None self.grantees = None - self.file_format = 
FileFormat(self.connection_config['file_format'], self.query) + self.file_format = FileFormat(self.connection_config['file_format'], self.query, file_format_type) if not self.connection_config.get('stage') and self.file_format.file_format_type == FileFormatTypes.PARQUET: self.logger.error("Table stages with Parquet file format is not suppported. " diff --git a/target_snowflake/file_format.py b/target_snowflake/file_format.py index 8aab5cdb..3231212a 100644 --- a/target_snowflake/file_format.py +++ b/target_snowflake/file_format.py @@ -24,11 +24,15 @@ def list(): class FileFormat: """File Format class""" - def __init__(self, file_format: str, query_fn: Callable): + def __init__(self, file_format: str, query_fn: Callable, file_format_type: FileFormatTypes=None): """Find the file format in Snowflake, detect its type and initialise file format specific functions""" - # Detect file format type by querying it from Snowflake - self.file_format_type = self._detect_file_format_type(file_format, query_fn) + if file_format_type: + self.file_format_type = file_format_type + else: + # Detect file format type by querying it from Snowflake + self.file_format_type = self._detect_file_format_type(file_format, query_fn) + self.formatter = self._get_formatter(self.file_format_type) @classmethod diff --git a/tests/integration/test_target_snowflake.py b/tests/integration/test_target_snowflake.py index b4136b9f..2a122c00 100644 --- a/tests/integration/test_target_snowflake.py +++ b/tests/integration/test_target_snowflake.py @@ -55,8 +55,8 @@ def persist_lines_with_cache(self, lines): Selecting from a real table instead of INFORMATION_SCHEMA and keeping it in memory while the target-snowflake is running results better load performance. 
""" - table_cache = target_snowflake.load_table_cache(self.config) - target_snowflake.persist_lines(self.config, lines, table_cache) + table_cache, file_format_type = target_snowflake.get_snowflake_statics(self.config) + target_snowflake.persist_lines(self.config, lines, table_cache, file_format_type) def remove_metadata_columns_from_rows(self, rows): """Removes metadata columns from a list of rows""" @@ -1072,18 +1072,6 @@ def test_query_tagging(self): 'QUERIES': 6 }, { - 'QUERY_TAG': f'PPW test tap run at {current_time}. Loading into {target_db}..TEST_TABLE_ONE', - 'QUERIES': 2 - }, - { - 'QUERY_TAG': f'PPW test tap run at {current_time}. Loading into {target_db}..TEST_TABLE_THREE', - 'QUERIES': 2 - }, - { - 'QUERY_TAG': f'PPW test tap run at {current_time}. Loading into {target_db}..TEST_TABLE_TWO', - 'QUERIES': 2 - }, - { 'QUERY_TAG': f'PPW test tap run at {current_time}. Loading into {target_db}.{target_schema}.TEST_TABLE_ONE', 'QUERIES': 12 }, @@ -1097,6 +1085,16 @@ def test_query_tagging(self): } ]) + # Detecting file format type should run only once + result = snowflake.query(f"""SELECT count(*) show_file_format_queries + FROM table(information_schema.query_history_by_user('{self.config['user']}')) + WHERE query_tag like '%%PPW test tap run at {current_time}%%' + AND query_text like 'SHOW FILE FORMATS%%'""") + self.assertEqual(result, [{ + 'SHOW_FILE_FORMAT_QUERIES': 1 + } + ]) + def test_table_stage(self): """Test if data can be loaded via table stages""" tap_lines = test_utils.get_test_tap_lines('messages-with-three-streams.json')