From 98eddbed8c18da4a1fedd7e733ea890e65e598a8 Mon Sep 17 00:00:00 2001 From: Nils Mueller Date: Tue, 19 Apr 2022 00:06:22 +0200 Subject: [PATCH 1/2] Locally parse hstore values Saves a lot of network round trips. --- .../sync_strategies/logical_replication.py | 30 +--- tests/test_discovery.py | 11 -- tests/test_logical_replication.py | 155 ++++++------------ 3 files changed, 54 insertions(+), 142 deletions(-) diff --git a/tap_postgres/sync_strategies/logical_replication.py b/tap_postgres/sync_strategies/logical_replication.py index 8056617a..e5de1aa4 100644 --- a/tap_postgres/sync_strategies/logical_replication.py +++ b/tap_postgres/sync_strategies/logical_replication.py @@ -9,10 +9,9 @@ import warnings from select import select -from psycopg2 import sql from singer import metadata, utils, get_bookmark from dateutil.parser import parse, UnknownTimezoneWarning, ParserError -from functools import reduce +from psycopg2.extras import HstoreAdapter import tap_postgres.db as post_db import tap_postgres.sync_strategies.common as sync_common @@ -126,23 +125,8 @@ def get_stream_version(tap_stream_id, state): return stream_version -def tuples_to_map(accum, t): - accum[t[0]] = t[1] - return accum - - -def create_hstore_elem_query(elem): - return sql.SQL("SELECT hstore_to_array({})").format(sql.Literal(elem)) - - -def create_hstore_elem(conn_info, elem): - with post_db.open_connection(conn_info, False, True) as conn: - with conn.cursor() as cur: - query = create_hstore_elem_query(elem) - cur.execute(query) - res = cur.fetchone()[0] - hstore_elem = reduce(tuples_to_map, [res[i:i + 2] for i in range(0, len(res), 2)], {}) - return hstore_elem +def create_hstore_elem(elem): + return HstoreAdapter.parse(elem, None) def create_array_elem(elem, sql_datatype, conn_info): @@ -205,7 +189,7 @@ def create_array_elem(elem, sql_datatype, conn_info): # pylint: disable=too-many-branches,too-many-nested-blocks,too-many-return-statements -def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): +def selected_value_to_singer_value_impl(elem, og_sql_datatype): sql_datatype = og_sql_datatype.replace('[]', '') if elem is None: @@ -321,7 +305,7 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): if sql_datatype == 'boolean': return elem if sql_datatype == 'hstore': - return create_hstore_elem(conn_info, elem) + return create_hstore_elem(elem) if 'numeric' in sql_datatype: return decimal.Decimal(elem) if isinstance(elem, int): @@ -338,7 +322,7 @@ def selected_array_to_singer_value(elem, sql_datatype, conn_info): if isinstance(elem, list): return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), elem)) - return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info) + return selected_value_to_singer_value_impl(elem, sql_datatype) def selected_value_to_singer_value(elem, sql_datatype, conn_info): @@ -348,7 +332,7 @@ def selected_value_to_singer_value(elem, sql_datatype, conn_info): return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), (cleaned_elem or []))) - return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info) + return selected_value_to_singer_value_impl(elem, sql_datatype) def row_to_singer_message(stream, row, version, columns, time_extracted, md_map, conn_info): diff --git a/tests/test_discovery.py b/tests/test_discovery.py index bb9eb35a..95ba2dab 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -365,17 +365,6 @@ def test_catalog(self): 'definitions' : BASE_RECURSIVE_SCHEMAS}, stream_dict.get('schema')) - def test_escaping_values(self): - key = 'nickname' - value = "Dave's Courtyard" - elem = '"{}"=>"{}"'.format(key, value) - - with get_test_connection() as conn: - with conn.cursor() as cur: - query = tap_postgres.sync_strategies.logical_replication.create_hstore_elem_query(elem) - self.assertEqual(query.as_string(cur), "SELECT hstore_to_array('\"nickname\"=>\"Dave''s Courtyard\"')") - - class TestEnumTable(unittest.TestCase): maxDiff = None table_name = 'CHICKEN TIMES' diff --git a/tests/test_logical_replication.py b/tests/test_logical_replication.py index 228d79b4..5471c9a3 100644 --- a/tests/test_logical_replication.py +++ b/tests/test_logical_replication.py @@ -261,15 +261,13 @@ def test_consume_message_with_new_column_in_payload_will_refresh_schema(self, def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_expect_iso_format(self): output = logical_replication.selected_value_to_singer_value_impl('2020-09-01 20:10:56', - 'timestamp without time zone', - None) + 'timestamp without time zone') self.assertEqual('2020-09-01T20:10:56+00:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_datetime_expect_iso_format(self): output = logical_replication.selected_value_to_singer_value_impl(datetime(2020, 9, 1, 20, 10, 59), - 'timestamp without time zone', - None) + 'timestamp without time zone') self.assertEqual('2020-09-01T20:10:59+00:00', output) @@ -279,8 +277,7 @@ def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_ should fallback to max datetime allowed """ output = logical_replication.selected_value_to_singer_value_impl('10000-09-01 20:10:56', - 'timestamp without time zone', - None) + 'timestamp without time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) @@ -290,8 +287,7 @@ def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_ should fallback to max datetime allowed """ output = logical_replication.selected_value_to_singer_value_impl('0000-09-01 20:10:56', - 'timestamp without time zone', - None) + 'timestamp without time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) @@ -301,8 +297,7 @@ def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_ should fallback to max datetime allowed """ output = logical_replication.selected_value_to_singer_value_impl('1000-09-01 20:10:56 BC', - 'timestamp without time zone', - None) + 'timestamp without time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) @@ -312,51 +307,44 @@ def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_ should fallback to max datetime allowed """ output = logical_replication.selected_value_to_singer_value_impl('1000-09-01 20:10:56 AC', - 'timestamp without time zone', - None) + 'timestamp without time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_min(self): output = logical_replication.selected_value_to_singer_value_impl('0001-01-01 00:00:00.000123', - 'timestamp without time zone', - None) + 'timestamp without time zone') self.assertEqual('0001-01-01T00:00:00.000123+00:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_string_max(self): output = logical_replication.selected_value_to_singer_value_impl('9999-12-31 23:59:59.999999', - 'timestamp without time zone', - None) + 'timestamp without time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_datetime_min(self): output = logical_replication.selected_value_to_singer_value_impl(datetime(1, 1, 1, 0, 0, 0, 123), - 'timestamp without time zone', - None) + 'timestamp without time zone') self.assertEqual('0001-01-01T00:00:00.000123+00:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_ntz_value_as_datetime_max(self): output = logical_replication.selected_value_to_singer_value_impl(datetime(9999, 12, 31, 23, 59, 59, 999999), - 'timestamp without time zone', - None) + 'timestamp without time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_expect_iso_format(self): output = logical_replication.selected_value_to_singer_value_impl('2020-09-01 20:10:56+05', - 'timestamp with time zone', - None) + 'timestamp with time zone') self.assertEqual('2020-09-01T20:10:56+05:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_datetime_expect_iso_format(self): output = logical_replication.selected_value_to_singer_value_impl(datetime(2020, 9, 1, 23, 10, 59, tzinfo=tzoffset(None, -3600)), - 'timestamp with time zone', - None) + 'timestamp with time zone') self.assertEqual('2020-09-01T23:10:59-01:00', output) @@ -366,8 +354,7 @@ def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_o should fallback to max datetime allowed """ output = logical_replication.selected_value_to_singer_value_impl('10000-09-01 20:10:56+06', - 'timestamp with time zone', - None) + 'timestamp with time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) @@ -377,8 +364,7 @@ def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_o should fallback to max datetime allowed """ output = logical_replication.selected_value_to_singer_value_impl('0000-09-01 20:10:56+01', - 'timestamp with time zone', - None) + 'timestamp with time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) @@ -388,8 +374,7 @@ def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_B should fallback to max datetime allowed """ output = logical_replication.selected_value_to_singer_value_impl('1000-09-01 20:10:56+05 BC', - 'timestamp with time zone', - None) + 'timestamp with time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) @@ -399,43 +384,38 @@ def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_A should fallback to max datetime allowed """ output = logical_replication.selected_value_to_singer_value_impl('1000-09-01 20:10:56-09 AC', - 'timestamp with time zone', - None) + 'timestamp with time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_min(self): output = logical_replication.selected_value_to_singer_value_impl('0001-01-01 00:00:00.000123+04', - 'timestamp with time zone', - None) + 'timestamp with time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_string_max(self): output = logical_replication.selected_value_to_singer_value_impl('9999-12-31 23:59:59.999999-03', - 'timestamp with time zone', - None) + 'timestamp with time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_datetime_min(self): output = logical_replication.selected_value_to_singer_value_impl(datetime(1, 1, 1, 0, 0, 0, 123, tzinfo=tzoffset(None, 14400)), - 'timestamp with time zone', - None) + 'timestamp with time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) def test_selected_value_to_singer_value_impl_with_timestamp_tz_value_as_datetime_max(self): output = logical_replication.selected_value_to_singer_value_impl(datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=tzoffset(None, -14400)), - 'timestamp with time zone', - None) + 'timestamp with time zone') self.assertEqual('9999-12-31T23:59:59.999+00:00', output) def test_selected_value_to_singer_value_impl_with_date_value_as_string_expect_iso_format(self): - output = logical_replication.selected_value_to_singer_value_impl('2021-09-07', 'date', None) + output = logical_replication.selected_value_to_singer_value_impl('2021-09-07', 'date') self.assertEqual('2021-09-07T00:00:00+00:00', output) @@ -445,7 +425,7 @@ def test_selected_value_to_singer_value_impl_with_date_value_as_string_out_of_ra is > 9999 (which is valid in postgres) should fallback to max date allowed """ - output = logical_replication.selected_value_to_singer_value_impl('10000-09-01', 'date', None) + output = logical_replication.selected_value_to_singer_value_impl('10000-09-01', 'date') self.assertEqual('9999-12-31T00:00:00+00:00', output) @@ -513,22 +493,19 @@ def test_row_to_singer_message(self): def test_selected_value_to_singer_value_impl_with_null_json_returns_None(self): output = logical_replication.selected_value_to_singer_value_impl(None, - 'json', - None) + 'json') self.assertEqual(None, output) def test_selected_value_to_singer_value_impl_with_empty_json_returns_empty_dict(self): output = logical_replication.selected_value_to_singer_value_impl('{}', - 'json', - None) + 'json') self.assertEqual({}, output) def test_selected_value_to_singer_value_impl_with_non_empty_json_returns_equivalent_dict(self): output = logical_replication.selected_value_to_singer_value_impl('{"key1": "A", "key2": [{"kk": "yo"}, {}]}', - 'json', - None) + 'json') self.assertEqual({ 'key1': 'A', @@ -537,22 +514,19 @@ def test_selected_value_to_singer_value_impl_with_non_empty_json_returns_equival def test_selected_value_to_singer_value_impl_with_null_jsonb_returns_None(self): output = logical_replication.selected_value_to_singer_value_impl(None, - 'jsonb', - None) + 'jsonb') self.assertEqual(None, output) def test_selected_value_to_singer_value_impl_with_empty_jsonb_returns_empty_dict(self): output = logical_replication.selected_value_to_singer_value_impl('{}', - 'jsonb', - None) + 'jsonb') self.assertEqual({}, output) def test_selected_value_to_singer_value_impl_with_non_empty_jsonb_returns_equivalent_dict(self): output = logical_replication.selected_value_to_singer_value_impl('{"key1": "A", "key2": [{"kk": "yo"}, {}]}', - 'jsonb', - None) + 'jsonb') self.assertEqual({ 'key1': 'A', @@ -679,25 +653,10 @@ def test_get_stream_version_not_none(self): actual_value = logical_replication.get_stream_version(tap_stream_id, state) self.assertEqual(state['bookmarks']['foo']['version'], actual_value) - def test_tuples_to_map(self): - """Test if the output of tuples_to_map is as expected""" - accum = {'foo_key': 'foo_value'} - t = ['bar_1', 'bar_2'] - expected_output = accum.copy() - expected_output[t[0]] = t[1] - - actual_output = logical_replication.tuples_to_map(accum, t) - self.assertEqual(expected_output, actual_output) - - @patch("psycopg2.connect") - def test_create_hstore_elem(self, mocked_connect): - """Test if the output of create_hstore_elem is as expected""" - mocked_cursor = mocked_connect.return_value.__enter__.return_value.cursor - mocked_fetchone = mocked_cursor.return_value.__enter__.return_value.fetchone - mocked_fetchone.return_value = (['foo', 'bar'],) - elem = 'foo=>bar' + def test_create_hstore_elem(self): + elem = '"foo"=>"bar"' expected_output = {'foo': 'bar'} - actual_output = logical_replication.create_hstore_elem(self.conn_info, elem) + actual_output = logical_replication.create_hstore_elem(elem) self.assertDictEqual(expected_output, actual_output) @patch("psycopg2.connect") @@ -806,9 +765,8 @@ def test_slctv2sngrv_impl_if_sql_datatype_is_money(self): """Test selected_value_to_singer_value_impl if sql_datatype is money""" elem = 'foo' og_sql_datatype = 'money' - conn_info = None expected_output = elem - actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info) + actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype) self.assertEqual(expected_output, actual_output) @@ -818,9 +776,8 @@ def test_slctv2sngrv_impl_if_timestamp_with_time_zone_as_datatype_and_greater_th # maximum value is hardcoded! and is 9999-12-31 23:59:59.999000 elem = datetime(9999, 12, 31, 23, 59, 59, 999001, tzinfo=timezone.utc) og_sql_datatype = 'timestamp with time zone' - conn_info = None expected_output = logical_replication.FALLBACK_DATETIME - actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info) + actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype) self.assertEqual(expected_output, actual_output) @@ -830,9 +787,8 @@ def test_slctv2sngrv_impl_datatype_timestamp_with_time_zone_and_parsed_elm_great # maximum value is hardcoded! and is 9999-12-31 23:59:59.999000 elem = '9999-12-31T23:59:59.9999999+00:00' og_sql_datatype = 'timestamp with time zone' - conn_info = None expected_output = logical_replication.FALLBACK_DATETIME - actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info) + actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype) self.assertEqual(expected_output, actual_output) @@ -840,10 +796,9 @@ def test_slctv2sngrv_impl_with_sql_datatype_is_date_with_elm_is_datetime(self): """Test selected_value_to_singer_value_impl if datatype is date and elm type is datetime""" elem = date(2022, 12, 31) og_sql_datatype = 'date' - conn_info = None expected_output = '2022-12-31T00:00:00+00:00' - actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info) + actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype) self.assertEqual(expected_output, actual_output) def test_slctv2sngrv_impl_with_sql_datatype_is_date_with_invalid_elem_raises_exception(self): @@ -851,20 +806,17 @@ def test_slctv2sngrv_impl_with_sql_datatype_is_date_with_invalid_elem_raises_exc if datatype is date and elem is invalid""" elem = 'foo' og_sql_datatype = 'date' - conn_info = None self.assertRaises(ValueError, logical_replication.selected_value_to_singer_value_impl, elem, - og_sql_datatype, - conn_info) + og_sql_datatype) def test_slctv2sngrv_impl_with_sql_datatype_is_time_with_time_zone_and_elem_starts_with_24(self): """Test selected_value_to_singer_value_impl if datatype is time with time zone and elem starts with 24""" og_sql_datatype = 'time with time zone' expected_output = '01:12:11' elem = '24:12:11-01' - conn_info = None - actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info) + actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype) self.assertEqual(expected_output, actual_output) @@ -873,8 +825,7 @@ def test_slctv2sngrv_impl_with_sql_datatype_is_time_without_time_zone_elem_start og_sql_datatype = 'time without time zone' expected_output = '00:12:11' test_elem = '24:12:11' - conn_info = None - actual_output = logical_replication.selected_value_to_singer_value_impl(test_elem, og_sql_datatype, conn_info) + actual_output = logical_replication.selected_value_to_singer_value_impl(test_elem, og_sql_datatype) self.assertEqual(expected_output, actual_output) @@ -882,8 +833,7 @@ def test_slctv2sngrv_impl_with_sql_datatype_is_bit(self): """Test selected_value_to_singer_value_impl if datatype is bit""" og_sql_datatype = 'bit' elem = True - conn_info = None - actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info) + actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype) self.assertTrue(actual_output) @@ -891,9 +841,8 @@ def test_slctv2sngrv_impl_with_sql_datatype_is_int(self): """Test selected_value_to_singer_value_impl if datatype is int""" og_sql_datatype = 'foo' elem = 23 - conn_info = None expected_output = elem - actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info) + actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype) self.assertEqual(expected_output, actual_output) @@ -901,30 +850,23 @@ def test_slctv2sngrv_impl_with_sql_datatype_is_boolean(self): """Test selected_value_to_singer_value_impl if datatype is boolean""" og_sql_datatype = 'boolean' elem = 'foo' - conn_info = None expected_output = elem actual_output = logical_replication.selected_value_to_singer_value_impl( elem, - og_sql_datatype, - conn_info + og_sql_datatype ) self.assertEqual(expected_output, actual_output) - @patch("psycopg2.connect") - def test_slctv2sngrv_impl_with_sql_datatype_is_hstore(self, mocked_connect): + def test_slctv2sngrv_impl_with_sql_datatype_is_hstore(self): """Test selected_value_to_singer_value_impl if datatype is hstore""" - mocked_cursor = mocked_connect.return_value.__enter__.return_value.cursor - mocked_fetchone = mocked_cursor.return_value.__enter__.return_value.fetchone - mocked_fetchone.return_value = (['1', '0', '2', '1'],) og_sql_datatype = 'hstore' - hstore_elem = '1=>0,2=>1' + hstore_elem = '"1"=>"0","2"=>"1"' expected_output = {'1': '0', '2': '1'} actual_output = logical_replication.selected_value_to_singer_value_impl( hstore_elem, - og_sql_datatype, - self.conn_info + og_sql_datatype ) self.assertEqual(expected_output, actual_output) @@ -933,9 +875,8 @@ def test_slctv2sngrv_impl_with_sql_datatype_contains_numeric(self): """Test selected_value_to_singer_value_impl if datatype contains numeric""" og_sql_datatype = 'foo numeric bar' elem = '2' - conn_info = None expected_output = decimal.Decimal(elem) - actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info) + actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype) self.assertEqual(expected_output, actual_output) @@ -943,9 +884,8 @@ def test_slctv2sngrv_impl_with_float_elem(self): """Test selected_value_to_singer_value_impl if elem is float""" og_sql_datatype = 'foo' elem = 3.14 - conn_info = None expected_output = elem - actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info) + actual_output = logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype) self.assertEqual(expected_output, actual_output) @@ -953,10 +893,9 @@ def test_slctv2sngrv_impl_raises_exception_with_invalid_type_of_elem(self): """Test selected_value_to_singer_value_impl with invalid type of elem raises an exception""" og_sql_datatype = 'foo' elem = {} - conn_info = None expected_message = f'do not know how to marshall value of type {type(elem)}' with self.assertRaises(Exception) as exp: - logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info) + logical_replication.selected_value_to_singer_value_impl(elem, og_sql_datatype) self.assertEqual(expected_message, str(exp.exception)) From 67398e78b7aedd02ea19069b7fc7bcbb3a6d99b9 Mon Sep 17 00:00:00 2001 From: Nils Mueller Date: Tue, 19 Apr 2022 03:07:31 +0200 Subject: [PATCH 2/2] Local parsing for select arrays Saves a lot of network roundtrips. Could also be done for the other array types, but I didn't have a need for them and not enough time for testing. --- .../sync_strategies/logical_replication.py | 22 ++++++++++++++----- tests/test_logical_replication.py | 4 ++-- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tap_postgres/sync_strategies/logical_replication.py b/tap_postgres/sync_strategies/logical_replication.py index e5de1aa4..90d4865b 100644 --- a/tap_postgres/sync_strategies/logical_replication.py +++ b/tap_postgres/sync_strategies/logical_replication.py @@ -3,6 +3,7 @@ import decimal import psycopg2 import copy +import csv import json import re import singer @@ -129,18 +130,31 @@ def create_hstore_elem(elem): return HstoreAdapter.parse(elem, None) +def local_array_parsing(elem, cast=None): + elem = [elem[1:-1]] + reader = csv.reader(elem, delimiter=',', escapechar='\\', quotechar='"') + array = next(reader) + array = [None if element.lower() == 'null' else cast(element) if cast else element for element in array] + return array + + def create_array_elem(elem, sql_datatype, conn_info): if elem is None: return None + if sql_datatype == 'text[]': + return local_array_parsing(elem) + elif sql_datatype == 'integer[]': + return local_array_parsing(elem, int) + elif sql_datatype == 'character varying[]': + return local_array_parsing(elem) + with post_db.open_connection(conn_info, False, True) as conn: with conn.cursor() as cur: if sql_datatype == 'bit[]': cast_datatype = 'boolean[]' elif sql_datatype == 'boolean[]': cast_datatype = 'boolean[]' - elif sql_datatype == 'character varying[]': - cast_datatype = 'character varying[]' elif sql_datatype == 'cidr[]': cast_datatype = 'cidr[]' elif sql_datatype == 'citext[]': @@ -151,8 +165,6 @@ def create_array_elem(elem, sql_datatype, conn_info): cast_datatype = 'double precision[]' elif sql_datatype == 'hstore[]': cast_datatype = 'text[]' - elif sql_datatype == 'integer[]': - cast_datatype = 'integer[]' elif sql_datatype == 'inet[]': cast_datatype = 'inet[]' elif sql_datatype == 'json[]': @@ -169,8 +181,6 @@ def create_array_elem(elem, sql_datatype, conn_info): cast_datatype = 'real[]' elif sql_datatype == 'smallint[]': cast_datatype = 'smallint[]' - elif sql_datatype == 'text[]': - cast_datatype = 'text[]' elif sql_datatype in ('time without time zone[]', 'time with time zone[]'): cast_datatype = 'text[]' elif sql_datatype in ('timestamp with time zone[]', 'timestamp without time zone[]'): diff --git a/tests/test_logical_replication.py b/tests/test_logical_replication.py index 5471c9a3..82cdc63a 100644 --- a/tests/test_logical_replication.py +++ b/tests/test_logical_replication.py @@ -668,13 +668,13 @@ def test_create_array_elem(self, mocked_connect): ('bit[]', {1}, [True]), ('foo', None, None), ('boolean[]', {True}, [True]), - ('character varying[]', {1, 'foo'}, ['1', "'foo'"]), + ('character varying[]', "{1,'foo'}", ['1', "'foo'"]), ('cidr[]', "{127.0.0.1}", ['127.0.0.1/32']), ('citext[]', {1, 'foo'}, ['1', "'foo'"]), ('date[]', '{2022-11-11}', ['2022-11-11']), ('double precision[]', {234.45}, [234.45]), ('hstore[]', {'foo'}, ["'foo'"]), - ('integer[]', {123}, [123]), + ('integer[]', '{123}', [123]), ('inet[]', "{127.0.0.1}", ['127.0.0.1']), ('json[]', {"foo": "bar"}, ["'foo': 'bar'"]), ('jsonb[]', {"foo": "bar"}, ["'foo': 'bar'"]),