Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Local type casts #173

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 21 additions & 27 deletions tap_postgres/sync_strategies/logical_replication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
import decimal
import psycopg2
import copy
import csv
import json
import re
import singer
import warnings

from select import select
from psycopg2 import sql
from singer import metadata, utils, get_bookmark
from dateutil.parser import parse, UnknownTimezoneWarning, ParserError
from functools import reduce
from psycopg2.extras import HstoreAdapter

import tap_postgres.db as post_db
import tap_postgres.sync_strategies.common as sync_common
Expand Down Expand Up @@ -126,37 +126,35 @@ def get_stream_version(tap_stream_id, state):
return stream_version


def tuples_to_map(accum, t):
accum[t[0]] = t[1]
return accum
def create_hstore_elem(elem):
return HstoreAdapter.parse(elem, None)


def create_hstore_elem_query(elem):
return sql.SQL("SELECT hstore_to_array({})").format(sql.Literal(elem))


def create_hstore_elem(conn_info, elem):
with post_db.open_connection(conn_info, False, True) as conn:
with conn.cursor() as cur:
query = create_hstore_elem_query(elem)
cur.execute(query)
res = cur.fetchone()[0]
hstore_elem = reduce(tuples_to_map, [res[i:i + 2] for i in range(0, len(res), 2)], {})
return hstore_elem
def local_array_parsing(elem, cast=None):
elem = [elem[1:-1]]
reader = csv.reader(elem, delimiter=',', escapechar='\\', quotechar='"')
array = next(reader)
array = [None if element.lower() == 'null' else cast(element) if cast else element for element in array]
return array


def create_array_elem(elem, sql_datatype, conn_info):
if elem is None:
return None

if sql_datatype == 'text[]':
return local_array_parsing(elem)
elif sql_datatype == 'integer[]':
return local_array_parsing(elem, int)
elif sql_datatype == 'character varying[]':
return local_array_parsing(elem)

with post_db.open_connection(conn_info, False, True) as conn:
with conn.cursor() as cur:
if sql_datatype == 'bit[]':
cast_datatype = 'boolean[]'
elif sql_datatype == 'boolean[]':
cast_datatype = 'boolean[]'
elif sql_datatype == 'character varying[]':
cast_datatype = 'character varying[]'
elif sql_datatype == 'cidr[]':
cast_datatype = 'cidr[]'
elif sql_datatype == 'citext[]':
Expand All @@ -167,8 +165,6 @@ def create_array_elem(elem, sql_datatype, conn_info):
cast_datatype = 'double precision[]'
elif sql_datatype == 'hstore[]':
cast_datatype = 'text[]'
elif sql_datatype == 'integer[]':
cast_datatype = 'integer[]'
elif sql_datatype == 'inet[]':
cast_datatype = 'inet[]'
elif sql_datatype == 'json[]':
Expand All @@ -185,8 +181,6 @@ def create_array_elem(elem, sql_datatype, conn_info):
cast_datatype = 'real[]'
elif sql_datatype == 'smallint[]':
cast_datatype = 'smallint[]'
elif sql_datatype == 'text[]':
cast_datatype = 'text[]'
elif sql_datatype in ('time without time zone[]', 'time with time zone[]'):
cast_datatype = 'text[]'
elif sql_datatype in ('timestamp with time zone[]', 'timestamp without time zone[]'):
Expand All @@ -205,7 +199,7 @@ def create_array_elem(elem, sql_datatype, conn_info):


# pylint: disable=too-many-branches,too-many-nested-blocks,too-many-return-statements
def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info):
def selected_value_to_singer_value_impl(elem, og_sql_datatype):
sql_datatype = og_sql_datatype.replace('[]', '')

if elem is None:
Expand Down Expand Up @@ -321,7 +315,7 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info):
if sql_datatype == 'boolean':
return elem
if sql_datatype == 'hstore':
return create_hstore_elem(conn_info, elem)
return create_hstore_elem(elem)
if 'numeric' in sql_datatype:
return decimal.Decimal(elem)
if isinstance(elem, int):
Expand All @@ -338,7 +332,7 @@ def selected_array_to_singer_value(elem, sql_datatype, conn_info):
if isinstance(elem, list):
return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), elem))

return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info)
return selected_value_to_singer_value_impl(elem, sql_datatype)


def selected_value_to_singer_value(elem, sql_datatype, conn_info):
Expand All @@ -348,7 +342,7 @@ def selected_value_to_singer_value(elem, sql_datatype, conn_info):
return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info),
(cleaned_elem or [])))

return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info)
return selected_value_to_singer_value_impl(elem, sql_datatype)


def row_to_singer_message(stream, row, version, columns, time_extracted, md_map, conn_info):
Expand Down
11 changes: 0 additions & 11 deletions tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,17 +365,6 @@ def test_catalog(self):
'definitions' : BASE_RECURSIVE_SCHEMAS},
stream_dict.get('schema'))

def test_escaping_values(self):
key = 'nickname'
value = "Dave's Courtyard"
elem = '"{}"=>"{}"'.format(key, value)

with get_test_connection() as conn:
with conn.cursor() as cur:
query = tap_postgres.sync_strategies.logical_replication.create_hstore_elem_query(elem)
self.assertEqual(query.as_string(cur), "SELECT hstore_to_array('\"nickname\"=>\"Dave''s Courtyard\"')")


class TestEnumTable(unittest.TestCase):
maxDiff = None
table_name = 'CHICKEN TIMES'
Expand Down
Loading