From 2ad21984ce89807f3f6defe77800854c83780c49 Mon Sep 17 00:00:00 2001
From: Samira El Aabidi <54845154+Samira-El@users.noreply.github.com>
Date: Fri, 8 Jan 2021 15:06:52 +0200
Subject: [PATCH] Replace usage of LAST_QUERY_ID with query id from cursor
 (#130)

* replace LAST_QUERY_ID with query id from cursor

* update warning

* start transaction outside of loop

* update logging
---
 target_snowflake/db_sync.py                 | 60 +++++++++++++++-------
 tests/integration/test_target_snowflake.py  | 11 ++--
 2 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/target_snowflake/db_sync.py b/target_snowflake/db_sync.py
index b69073e4..0f3562a6 100644
--- a/target_snowflake/db_sync.py
+++ b/target_snowflake/db_sync.py
@@ -1,5 +1,7 @@
 import json
 import sys
+from typing import Tuple, List, Dict, Optional, Union
+
 import snowflake.connector
 import collections
 import inflection
@@ -395,22 +397,38 @@ def open_connection(self):
             }
         )
 
-    def query(self, query, params=None, max_records=0):
+    def query(self, query: Union[str, List[str]], params: Dict = None, max_records=0) -> List[Dict]:
         result = []
+
+        if params is None:
+            params = {}
+        else:
+            if 'LAST_QID' in params:
+                self.logger.warning('LAST_QID is a reserved prepared statement parameter name, '
+                                    'it will be overridden with each executed query!')
+
         with self.open_connection() as connection:
             with connection.cursor(snowflake.connector.DictCursor) as cur:
-                queries = []
 
                 # Run every query in one transaction if query is a list of SQL
                 if type(query) is list:
-                    queries.append("START TRANSACTION")
-                    queries.extend(query)
+                    self.logger.info('Starting Transaction')
+                    cur.execute("START TRANSACTION")
+                    queries = query
                 else:
                     queries = [query]
 
+                qid = None
+
                 for q in queries:
-                    self.logger.debug("Running query: {}".format(q))
+
+                    # update the LAST_QID
+                    params['LAST_QID'] = qid
+
+                    self.logger.info("Running query: '%s' with Params %s", q, params)
+
                     cur.execute(q, params)
+                    qid = cur.sfqid
 
                     # Raise exception if returned rows greater than max allowed records
                     if 0 < max_records < cur.rowcount:
@@ -634,14 +652,14 @@ def get_tables(self, table_schemas=[]):
             for schema in table_schemas:
                 queries = []
 
-                # Get column data types by SHOW COLUMNS
+                # Get tables in schema
                 show_tables = f"SHOW TERSE TABLES IN SCHEMA {self.connection_config['dbname']}.{schema}"
 
                 # Convert output of SHOW TABLES to table
-                select = f"""
+                select = """
                     SELECT "schema_name" AS schema_name
                           ,"name" AS table_name
-                      FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()))
+                      FROM TABLE(RESULT_SCAN(%(LAST_QID)s))
                 """
                 queries.extend([show_tables, select])
 
@@ -672,30 +690,36 @@ def get_table_columns(self, table_schemas=[]):
                 show_columns = f"SHOW COLUMNS IN SCHEMA {self.connection_config['dbname']}.{schema}"
 
                 # Convert output of SHOW COLUMNS to table and insert results into the cache COLUMNS table
-                select = f"""
+                #
+                # ----------------------------------------------------------------------------------------
+                # Character and numeric columns display their generic data type rather than their defined
+                # data type (i.e. TEXT for all character types, FIXED for all fixed-point numeric types,
+                # and REAL for all floating-point numeric types).
+                # Further info at https://docs.snowflake.net/manuals/sql-reference/sql/show-columns.html
+                # ----------------------------------------------------------------------------------------
+                select = """
                     SELECT "schema_name" AS schema_name
                           ,"table_name" AS table_name
                           ,"column_name" AS column_name
-                          -- ----------------------------------------------------------------------------------------
-                          -- Character and numeric columns display their generic data type rather than their defined
-                          -- data type (i.e. TEXT for all character types, FIXED for all fixed-point numeric types,
-                          -- and REAL for all floating-point numeric types).
-                          --
-                          -- Further info at https://docs.snowflake.net/manuals/sql-reference/sql/show-columns.html
-                          -- ----------------------------------------------------------------------------------------
                           ,CASE PARSE_JSON("data_type"):type::varchar
                              WHEN 'FIXED' THEN 'NUMBER'
                              WHEN 'REAL' THEN 'FLOAT'
                              ELSE PARSE_JSON("data_type"):type::varchar
                            END data_type
-                      FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()))
+                      FROM TABLE(RESULT_SCAN(%(LAST_QID)s))
                 """
+
                 queries.extend([show_columns, select])
 
                 # Run everything in one transaction
                 try:
                     columns = self.query(queries, max_records=9999)
-                    table_columns.extend(columns)
+
+                    if not columns:
+                        self.logger.warning('No columns discovered in the schema "%s"',
+                                            f"{self.connection_config['dbname']}.{schema}")
+                    else:
+                        table_columns.extend(columns)
 
                 # Catch exception when schema not exists and SHOW COLUMNS throws a ProgrammingError
                 # Regexp to extract snowflake error code and message from the exception message
diff --git a/tests/integration/test_target_snowflake.py b/tests/integration/test_target_snowflake.py
index 93fb0751..454dd8dd 100644
--- a/tests/integration/test_target_snowflake.py
+++ b/tests/integration/test_target_snowflake.py
@@ -1055,11 +1055,12 @@ def test_query_tagging(self):
         self.persist_lines_with_cache(tap_lines)
 
         # Get query tags from QUERY_HISTORY
-        result = snowflake.query("SELECT query_tag, count(*) queries "
-                                 f"FROM table(information_schema.query_history_by_user('{self.config['user']}')) "
-                                 f"WHERE query_tag like '%PPW test tap run at {current_time}%'"
-                                 "GROUP BY query_tag "
-                                 "ORDER BY 1")
+        result = snowflake.query(f"""SELECT query_tag, count(*) queries
+        FROM table(information_schema.query_history_by_user('{self.config['user']}'))
+        WHERE query_tag like '%%PPW test tap run at {current_time}%%'
+        GROUP BY query_tag
+        ORDER BY 1""")
+
         target_db = self.config['dbname']
         target_schema = self.config['default_target_schema']
         self.assertEqual(result, [{
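For illustration, below is a minimal standalone sketch (not part of the patch) of the pattern this change adopts: execute a SHOW statement, read its query id from the same cursor via cur.sfqid, and bind that id into RESULT_SCAN() instead of relying on LAST_QUERY_ID(). The connection parameters, the MY_DB.MY_SCHEMA schema, and the LAST_QID parameter name below are illustrative placeholders, not values taken from the project's configuration.

# sketch.py -- assumes the snowflake-connector-python package is installed
import snowflake.connector

# Placeholder credentials; replace with real connection settings.
conn = snowflake.connector.connect(account='my_account', user='my_user', password='***')
try:
    with conn.cursor(snowflake.connector.DictCursor) as cur:
        # Run the SHOW command and remember the id of this exact query.
        cur.execute("SHOW TERSE TABLES IN SCHEMA MY_DB.MY_SCHEMA")
        qid = cur.sfqid  # query id of the statement just executed on this cursor

        # RESULT_SCAN(<query_id>) reads the result set of that specific query,
        # so statements issued elsewhere in the meantime cannot change what is
        # scanned here -- the race that LAST_QUERY_ID() left open.
        cur.execute(
            'SELECT "schema_name" AS schema_name, "name" AS table_name '
            "FROM TABLE(RESULT_SCAN(%(LAST_QID)s))",
            {'LAST_QID': qid},
        )
        print(cur.fetchall())
finally:
    conn.close()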