Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Commit

Permalink
Replace usage of LAST_QUERY_ID with query id from cursor (#130)
Browse files Browse the repository at this point in the history
* replace LAST_QUERY_ID with query id from cursor

* update warning

* start transaction outside of loop

* update logging
  • Loading branch information
Samira-El authored Jan 8, 2021
1 parent 6ff2999 commit 2ad2198
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 23 deletions.
60 changes: 42 additions & 18 deletions target_snowflake/db_sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import sys
from typing import Tuple, List, Dict, Optional, Union

import snowflake.connector
import collections
import inflection
Expand Down Expand Up @@ -395,22 +397,38 @@ def open_connection(self):
}
)

def query(self, query, params=None, max_records=0):
def query(self, query: Union[str, List[str]], params: Dict = None, max_records=0) -> List[Dict]:
result = []

if params is None:
params = {}
else:
if 'LAST_QID' in params:
self.logger.warning('LAST_QID is a reserved prepared statement parameter name, '
'it will be overridden with each executed query!')

with self.open_connection() as connection:
with connection.cursor(snowflake.connector.DictCursor) as cur:
queries = []

                # Run every query in one transaction if query is a list of SQL statements
if type(query) is list:
queries.append("START TRANSACTION")
queries.extend(query)
self.logger.info('Starting Transaction')
cur.execute("START TRANSACTION")
queries = query
else:
queries = [query]

qid = None

for q in queries:
self.logger.debug("Running query: {}".format(q))

# update the LAST_QID
params['LAST_QID'] = qid

self.logger.info("Running query: '%s' with Params %s", q, params)

cur.execute(q, params)
qid = cur.sfqid

# Raise exception if returned rows greater than max allowed records
if 0 < max_records < cur.rowcount:
Expand Down Expand Up @@ -634,14 +652,14 @@ def get_tables(self, table_schemas=[]):
for schema in table_schemas:
queries = []

# Get column data types by SHOW COLUMNS
# Get tables in schema
show_tables = f"SHOW TERSE TABLES IN SCHEMA {self.connection_config['dbname']}.{schema}"

# Convert output of SHOW TABLES to table
select = f"""
select = """
SELECT "schema_name" AS schema_name
,"name" AS table_name
FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()))
FROM TABLE(RESULT_SCAN(%(LAST_QID)s))
"""
queries.extend([show_tables, select])

Expand Down Expand Up @@ -672,30 +690,36 @@ def get_table_columns(self, table_schemas=[]):
show_columns = f"SHOW COLUMNS IN SCHEMA {self.connection_config['dbname']}.{schema}"

# Convert output of SHOW COLUMNS to table and insert results into the cache COLUMNS table
select = f"""
#
# ----------------------------------------------------------------------------------------
# Character and numeric columns display their generic data type rather than their defined
# data type (i.e. TEXT for all character types, FIXED for all fixed-point numeric types,
# and REAL for all floating-point numeric types).
# Further info at https://docs.snowflake.net/manuals/sql-reference/sql/show-columns.html
# ----------------------------------------------------------------------------------------
select = """
SELECT "schema_name" AS schema_name
,"table_name" AS table_name
,"column_name" AS column_name
-- ----------------------------------------------------------------------------------------
-- Character and numeric columns display their generic data type rather than their defined
-- data type (i.e. TEXT for all character types, FIXED for all fixed-point numeric types,
-- and REAL for all floating-point numeric types).
--
-- Further info at https://docs.snowflake.net/manuals/sql-reference/sql/show-columns.html
-- ----------------------------------------------------------------------------------------
,CASE PARSE_JSON("data_type"):type::varchar
WHEN 'FIXED' THEN 'NUMBER'
WHEN 'REAL' THEN 'FLOAT'
ELSE PARSE_JSON("data_type"):type::varchar
END data_type
FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()))
FROM TABLE(RESULT_SCAN(%(LAST_QID)s))
"""

queries.extend([show_columns, select])

# Run everything in one transaction
try:
columns = self.query(queries, max_records=9999)
table_columns.extend(columns)

if not columns:
self.logger.warning('No columns discovered in the schema "%s"',
f"{self.connection_config['dbname']}.{schema}")
else:
table_columns.extend(columns)

            # Catch the exception raised when the schema does not exist and SHOW COLUMNS throws a ProgrammingError
# Regexp to extract snowflake error code and message from the exception message
Expand Down
11 changes: 6 additions & 5 deletions tests/integration/test_target_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,11 +1055,12 @@ def test_query_tagging(self):
self.persist_lines_with_cache(tap_lines)

# Get query tags from QUERY_HISTORY
result = snowflake.query("SELECT query_tag, count(*) queries "
f"FROM table(information_schema.query_history_by_user('{self.config['user']}')) "
f"WHERE query_tag like '%PPW test tap run at {current_time}%'"
"GROUP BY query_tag "
"ORDER BY 1")
result = snowflake.query(f"""SELECT query_tag, count(*) queries
FROM table(information_schema.query_history_by_user('{self.config['user']}'))
WHERE query_tag like '%%PPW test tap run at {current_time}%%'
GROUP BY query_tag
ORDER BY 1""")

target_db = self.config['dbname']
target_schema = self.config['default_target_schema']
self.assertEqual(result, [{
Expand Down

0 comments on commit 2ad2198

Please sign in to comment.