From 83761af19d57b09786114abab4b3afa6c15ef735 Mon Sep 17 00:00:00 2001 From: Jeet Parekh <94441288+jeet-parekh-wise@users.noreply.github.com> Date: Fri, 18 Feb 2022 11:25:36 +0000 Subject: [PATCH] [AP-1122] improve logging for failed MERGE and COPY queries (#250) * [AP-1122] improve logging for failed MERGE and COPY queries * [AP-1122] fix pylint complaints * [AP-1122] fix DeprecationWarning * [AP-1122] add tests * [AP-1122] restructure code * [AP-1122] modify target_snowflake/flattening.py to increase test coverage --- target_snowflake/db_sync.py | 129 +++++++++++++++++++++------------ target_snowflake/flattening.py | 16 ++-- tests/unit/test_db_sync.py | 78 +++++++++++++++++++- 3 files changed, 166 insertions(+), 57 deletions(-) diff --git a/target_snowflake/db_sync.py b/target_snowflake/db_sync.py index 9c963e82..c253be89 100644 --- a/target_snowflake/db_sync.py +++ b/target_snowflake/db_sync.py @@ -1,6 +1,6 @@ import json import sys -from typing import List, Dict, Union +from typing import List, Dict, Union, Tuple import snowflake.connector import re @@ -434,8 +434,7 @@ def get_stage_name(self, stream): def load_file(self, s3_key, count, size_bytes): """Load a supported file type from snowflake stage into target table""" - stream_schema_message = self.stream_schema_message - stream = stream_schema_message['stream'] + stream = self.stream_schema_message['stream'] self.logger.info("Loading %d rows into '%s'", count, self.table_name(stream, False)) # Get list if columns with types @@ -448,50 +447,90 @@ def load_file(self, s3_key, count, size_bytes): for (name, schema) in self.flatten_schema.items() ] + inserts = 0 + updates = 0 + + # Insert or Update with MERGE command if primary key defined + if len(self.stream_schema_message['key_properties']) > 0: + try: + inserts, updates = self._load_file_merge( + s3_key=s3_key, + stream=stream, + columns_with_trans=columns_with_trans + ) + except Exception as ex: + self.logger.error( + 'Error while executing MERGE query for table "%s" in stream "%s"', + self.table_name(stream, False), stream + ) + raise ex + + # Insert only with COPY command if no primary key + else: + try: + inserts, updates = ( + self._load_file_copy( + s3_key=s3_key, + stream=stream, + columns_with_trans=columns_with_trans + ), + 0, + ) + except Exception as ex: + self.logger.error( + 'Error while executing COPY query for table "%s" in stream "%s"', + self.table_name(stream, False), stream + ) + raise ex + + self.logger.info( + 'Loading into %s: %s', + self.table_name(stream, False), + json.dumps({'inserts': inserts, 'updates': updates, 'size_bytes': size_bytes}) + ) + + def _load_file_merge(self, s3_key, stream, columns_with_trans) -> Tuple[int, int]: + # MERGE does insert and update + inserts = 0 + updates = 0 with self.open_connection() as connection: with connection.cursor(snowflake.connector.DictCursor) as cur: - inserts = 0 - updates = 0 - - # Insert or Update with MERGE command if primary key defined - if len(self.stream_schema_message['key_properties']) > 0: - merge_sql = self.file_format.formatter.create_merge_sql(table_name=self.table_name(stream, False), - stage_name=self.get_stage_name(stream), - s3_key=s3_key, - file_format_name= - self.connection_config['file_format'], - columns=columns_with_trans, - pk_merge_condition= - self.primary_key_merge_condition()) - self.logger.debug('Running query: %s', merge_sql) - cur.execute(merge_sql) - - # Get number of inserted and updated records - MERGE does insert and update - results = cur.fetchall() - if len(results) > 0: - 
inserts = results[0].get('number of rows inserted', 0) - updates = results[0].get('number of rows updated', 0) - - # Insert only with COPY command if no primary key - else: - copy_sql = self.file_format.formatter.create_copy_sql(table_name=self.table_name(stream, False), - stage_name=self.get_stage_name(stream), - s3_key=s3_key, - file_format_name= - self.connection_config['file_format'], - columns=columns_with_trans) - self.logger.debug('Running query: %s', copy_sql) - cur.execute(copy_sql) - - # Get number of inserted records - COPY does insert only - results = cur.fetchall() - if len(results) > 0: - inserts = results[0].get('rows_loaded', 0) - - self.logger.info('Loading into %s: %s', - self.table_name(stream, False), - json.dumps({'inserts': inserts, 'updates': updates, 'size_bytes': size_bytes})) - + merge_sql = self.file_format.formatter.create_merge_sql( + table_name=self.table_name(stream, False), + stage_name=self.get_stage_name(stream), + s3_key=s3_key, + file_format_name=self.connection_config['file_format'], + columns=columns_with_trans, + pk_merge_condition=self.primary_key_merge_condition() + ) + self.logger.debug('Running query: %s', merge_sql) + cur.execute(merge_sql) + # Get number of inserted and updated records + results = cur.fetchall() + if len(results) > 0: + inserts = results[0].get('number of rows inserted', 0) + updates = results[0].get('number of rows updated', 0) + return inserts, updates + + def _load_file_copy(self, s3_key, stream, columns_with_trans) -> int: + # COPY does insert only + inserts = 0 + with self.open_connection() as connection: + with connection.cursor(snowflake.connector.DictCursor) as cur: + copy_sql = self.file_format.formatter.create_copy_sql( + table_name=self.table_name(stream, False), + stage_name=self.get_stage_name(stream), + s3_key=s3_key, + file_format_name=self.connection_config['file_format'], + columns=columns_with_trans + ) + self.logger.debug('Running query: %s', copy_sql) + cur.execute(copy_sql) + # Get number of inserted records - COPY does insert only + results = cur.fetchall() + if len(results) > 0: + inserts = results[0].get('rows_loaded', 0) + return inserts def primary_key_merge_condition(self): """Generate SQL join condition on primary keys for merge SQL statements""" stream_schema_message = self.stream_schema_message diff --git a/target_snowflake/flattening.py b/target_snowflake/flattening.py index 40ab48ec..a6536ecc 100644 --- a/target_snowflake/flattening.py +++ b/target_snowflake/flattening.py @@ -52,17 +52,11 @@ def flatten_schema(d, parent_key=None, sep='__', level=0, max_level=0): items.extend(flatten_schema(v, parent_key + [k], sep=sep, level=level + 1, max_level=max_level).items()) else: items.append((new_key, v)) - else: - if len(v.values()) > 0: - if list(v.values())[0][0]['type'] == 'string': - list(v.values())[0][0]['type'] = ['null', 'string'] - items.append((new_key, list(v.values())[0][0])) - elif list(v.values())[0][0]['type'] == 'array': - list(v.values())[0][0]['type'] = ['null', 'array'] - items.append((new_key, list(v.values())[0][0])) - elif list(v.values())[0][0]['type'] == 'object': - list(v.values())[0][0]['type'] = ['null', 'object'] - items.append((new_key, list(v.values())[0][0])) + elif len(v.values()) > 0: + value_type = list(v.values())[0][0]['type'] + if value_type in ['string', 'array', 'object']: + list(v.values())[0][0]['type'] = ['null', value_type] + items.append((new_key, list(v.values())[0][0])) key_func = lambda item: item[0] sorted_items = sorted(items, key=key_func) diff --git 
a/tests/unit/test_db_sync.py b/tests/unit/test_db_sync.py index 296d5531..f7cfdb7f 100644 --- a/tests/unit/test_db_sync.py +++ b/tests/unit/test_db_sync.py @@ -310,5 +310,81 @@ def test_record_primary_key_string(self, query_patch): stream_schema_message['key_properties'] = ['invalid_col'] dbsync = db_sync.DbSync(minimal_config, stream_schema_message) with self.assertRaisesRegex(PrimaryKeyNotFoundException, - "Cannot find \['invalid_col'\] primary key\(s\) in record\. Available fields: \['id', 'c_str'\]"): + r"Cannot find \['invalid_col'\] primary key\(s\) in record\. Available fields: \['id', 'c_str'\]"): dbsync.record_primary_key_string({'id': 123, 'c_str': 'xyz'}) + + @patch('target_snowflake.db_sync.DbSync.query') + @patch('target_snowflake.db_sync.DbSync._load_file_merge') + def test_merge_failure_message(self, load_file_merge_patch, query_patch): + LOGGER_NAME = "target_snowflake" + query_patch.return_value = [{'type': 'CSV'}] + minimal_config = { + 'account': "dummy_account", + 'dbname': "dummy_dbname", + 'user': "dummy_user", + 'password': "dummy_password", + 'warehouse': "dummy_warehouse", + 'default_target_schema': "dummy_default_target_schema", + 'file_format': "dummy_file_format", + } + + stream_schema_message = { + "stream": "dummy_stream", + "schema": { + "properties": { + "id": {"type": ["integer"]}, + "c_str": {"type": ["null", "string"]} + } + }, + "key_properties": ["id"] + } + + # Single primary key string + dbsync = db_sync.DbSync(minimal_config, stream_schema_message) + load_file_merge_patch.side_effect = Exception() + expected_msg = ( + f'ERROR:{LOGGER_NAME}:Error while executing MERGE query ' + f'for table "{minimal_config["default_target_schema"]}."{stream_schema_message["stream"].upper()}"" ' + f'in stream "{stream_schema_message["stream"]}"' + ) + with self.assertRaises(Exception), self.assertLogs(logger=LOGGER_NAME, level="ERROR") as captured_logs: + dbsync.load_file(s3_key="dummy-key", count=256, size_bytes=256) + self.assertIn(expected_msg, captured_logs.output) + + @patch('target_snowflake.db_sync.DbSync.query') + @patch('target_snowflake.db_sync.DbSync._load_file_copy') + def test_copy_failure_message(self, load_file_copy_patch, query_patch): + LOGGER_NAME = "target_snowflake" + query_patch.return_value = [{'type': 'CSV'}] + minimal_config = { + 'account': "dummy_account", + 'dbname': "dummy_dbname", + 'user': "dummy_user", + 'password': "dummy_password", + 'warehouse': "dummy_warehouse", + 'default_target_schema': "dummy_default_target_schema", + 'file_format': "dummy_file_format", + } + + stream_schema_message = { + "stream": "dummy_stream", + "schema": { + "properties": { + "id": {"type": ["integer"]}, + "c_str": {"type": ["null", "string"]} + } + }, + "key_properties": [] + } + + # Single primary key string + dbsync = db_sync.DbSync(minimal_config, stream_schema_message) + load_file_copy_patch.side_effect = Exception() + expected_msg = ( + f'ERROR:{LOGGER_NAME}:Error while executing COPY query ' + f'for table "{minimal_config["default_target_schema"]}."{stream_schema_message["stream"].upper()}"" ' + f'in stream "{stream_schema_message["stream"]}"' + ) + with self.assertRaises(Exception), self.assertLogs(logger=LOGGER_NAME, level="ERROR") as captured_logs: + dbsync.load_file(s3_key="dummy-key", count=256, size_bytes=256) + self.assertIn(expected_msg, captured_logs.output)
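
For context, the error handling added to load_file() above is a plain log-and-re-raise: the failure is logged with the table and stream names, and the original exception still propagates to the caller. Below is a minimal standalone sketch of that pattern using only the standard library; run_load_query and failing_merge are hypothetical names for illustration, not part of target_snowflake.

    import logging

    # Same logger name the new unit tests assert against with assertLogs()
    logger = logging.getLogger("target_snowflake")


    def run_load_query(query_kind, table_name, stream, execute):
        """Log-and-re-raise sketch of the pattern used in load_file():
        record which query kind, table and stream failed, then let the
        original exception propagate unchanged to the caller."""
        try:
            return execute()
        except Exception:
            logger.error(
                'Error while executing %s query for table "%s" in stream "%s"',
                query_kind, table_name, stream,
            )
            raise


    if __name__ == "__main__":
        logging.basicConfig(level=logging.ERROR)

        def failing_merge():
            # Stand-in for a MERGE statement that Snowflake rejects
            raise RuntimeError("simulated Snowflake failure")

        try:
            run_load_query("MERGE", 'my_schema."MY_TABLE"', "my_stream", failing_merge)
        except RuntimeError:
            pass  # the caller still receives the original exception and traceback

In Python 3, both a bare raise and the patch's raise ex re-raise with the original traceback attached, so the extra ERROR log line adds table/stream context without hiding the underlying Snowflake error from the caller.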