diff --git a/tap_postgres/__init__.py b/tap_postgres/__init__.py index 75ed479f..c21a7fa1 100644 --- a/tap_postgres/__init__.py +++ b/tap_postgres/__init__.py @@ -39,9 +39,9 @@ def do_discovery(conn_config): Returns: list of discovered streams """ - with post_db.open_connection(conn_config) as conn: - LOGGER.info("Discovering db %s", conn_config['dbname']) - streams = discover_db(conn, conn_config.get('filter_schemas')) + conn = post_db.open_connection() + LOGGER.info("Discovering db %s", conn_config['dbname']) + streams = discover_db(conn, conn_config.get('filter_schemas')) if len(streams) == 0: raise RuntimeError('0 tables were discovered across the entire cluster') @@ -50,21 +50,21 @@ def do_discovery(conn_config): return streams -def do_sync_full_table(conn_config, stream, state, desired_columns, md_map): +def do_sync_full_table(stream, state, desired_columns, md_map): """ Runs full table sync """ LOGGER.info("Stream %s is using full_table replication", stream['tap_stream_id']) sync_common.send_schema_message(stream, []) if md_map.get((), {}).get('is-view'): - state = full_table.sync_view(conn_config, stream, state, desired_columns, md_map) + state = full_table.sync_view(stream, state, desired_columns, md_map) else: - state = full_table.sync_table(conn_config, stream, state, desired_columns, md_map) + state = full_table.sync_table(stream, state, desired_columns, md_map) return state # Possible state keys: replication_key, replication_key_value, version -def do_sync_incremental(conn_config, stream, state, desired_columns, md_map): +def do_sync_incremental(stream, state, desired_columns, md_map): """ Runs Incremental sync """ @@ -82,7 +82,7 @@ def do_sync_incremental(conn_config, stream, state, desired_columns, md_map): state = singer.write_bookmark(state, stream['tap_stream_id'], 'replication_key', replication_key) sync_common.send_schema_message(stream, [replication_key]) - state = incremental.sync_table(conn_config, stream, state, desired_columns, md_map) + state = incremental.sync_table(stream, state, desired_columns, md_map) return state @@ -164,27 +164,27 @@ def sync_traditional_stream(conn_config, stream, state, sync_method, end_lsn): LOGGER.warning('There are no columns selected for stream %s, skipping it', stream['tap_stream_id']) return state - register_type_adapters(conn_config) + register_type_adapters() if sync_method == 'full': state = singer.set_currently_syncing(state, stream['tap_stream_id']) - state = do_sync_full_table(conn_config, stream, state, desired_columns, md_map) + state = do_sync_full_table(stream, state, desired_columns, md_map) elif sync_method == 'incremental': state = singer.set_currently_syncing(state, stream['tap_stream_id']) - state = do_sync_incremental(conn_config, stream, state, desired_columns, md_map) + state = do_sync_incremental(stream, state, desired_columns, md_map) elif sync_method == 'logical_initial': state = singer.set_currently_syncing(state, stream['tap_stream_id']) LOGGER.info("Performing initial full table sync") state = singer.write_bookmark(state, stream['tap_stream_id'], 'lsn', end_lsn) sync_common.send_schema_message(stream, []) - state = full_table.sync_table(conn_config, stream, state, desired_columns, md_map) + state = full_table.sync_table(stream, state, desired_columns, md_map) state = singer.write_bookmark(state, stream['tap_stream_id'], 'xmin', None) elif sync_method == 'logical_initial_interrupted': state = singer.set_currently_syncing(state, stream['tap_stream_id']) LOGGER.info("Initial stage of full table sync was 
interrupted. resuming...") sync_common.send_schema_message(stream, []) - state = full_table.sync_table(conn_config, stream, state, desired_columns, md_map) + state = full_table.sync_table(stream, state, desired_columns, md_map) else: raise Exception("unknown sync method {} for stream {}".format(sync_method, stream['tap_stream_id'])) @@ -222,53 +222,53 @@ def sync_logical_streams(conn_config, logical_streams, state, end_lsn, state_fil return state -def register_type_adapters(conn_config): +def register_type_adapters(): """ //todo doc needed """ - with post_db.open_connection(conn_config) as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - # citext[] - cur.execute("SELECT typarray FROM pg_type where typname = 'citext'") - citext_array_oid = cur.fetchone() - if citext_array_oid: - psycopg2.extensions.register_type( - psycopg2.extensions.new_array_type( - (citext_array_oid[0],), 'CITEXT[]', psycopg2.STRING)) - - # bit[] - cur.execute("SELECT typarray FROM pg_type where typname = 'bit'") - bit_array_oid = cur.fetchone()[0] + conn = post_db.open_connection() + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + # citext[] + cur.execute("SELECT typarray FROM pg_type where typname = 'citext'") + citext_array_oid = cur.fetchone() + if citext_array_oid: psycopg2.extensions.register_type( psycopg2.extensions.new_array_type( - (bit_array_oid,), 'BIT[]', psycopg2.STRING)) - - # UUID[] - cur.execute("SELECT typarray FROM pg_type where typname = 'uuid'") - uuid_array_oid = cur.fetchone()[0] - psycopg2.extensions.register_type( - psycopg2.extensions.new_array_type( - (uuid_array_oid,), 'UUID[]', psycopg2.STRING)) - - # money[] - cur.execute("SELECT typarray FROM pg_type where typname = 'money'") - money_array_oid = cur.fetchone()[0] + (citext_array_oid[0],), 'CITEXT[]', psycopg2.STRING)) + + # bit[] + cur.execute("SELECT typarray FROM pg_type where typname = 'bit'") + bit_array_oid = cur.fetchone()[0] + psycopg2.extensions.register_type( + psycopg2.extensions.new_array_type( + (bit_array_oid,), 'BIT[]', psycopg2.STRING)) + + # UUID[] + cur.execute("SELECT typarray FROM pg_type where typname = 'uuid'") + uuid_array_oid = cur.fetchone()[0] + psycopg2.extensions.register_type( + psycopg2.extensions.new_array_type( + (uuid_array_oid,), 'UUID[]', psycopg2.STRING)) + + # money[] + cur.execute("SELECT typarray FROM pg_type where typname = 'money'") + money_array_oid = cur.fetchone()[0] + psycopg2.extensions.register_type( + psycopg2.extensions.new_array_type( + (money_array_oid,), 'MONEY[]', psycopg2.STRING)) + + # json and jsonb + # pylint: disable=unnecessary-lambda + psycopg2.extras.register_default_json(loads=lambda x: str(x)) + psycopg2.extras.register_default_jsonb(loads=lambda x: str(x)) + + # enum[]'s + cur.execute("SELECT distinct(t.typarray) FROM pg_type t JOIN pg_enum e ON t.oid = e.enumtypid") + for oid in cur.fetchall(): + enum_oid = oid[0] psycopg2.extensions.register_type( psycopg2.extensions.new_array_type( - (money_array_oid,), 'MONEY[]', psycopg2.STRING)) - - # json and jsonb - # pylint: disable=unnecessary-lambda - psycopg2.extras.register_default_json(loads=lambda x: str(x)) - psycopg2.extras.register_default_jsonb(loads=lambda x: str(x)) - - # enum[]'s - cur.execute("SELECT distinct(t.typarray) FROM pg_type t JOIN pg_enum e ON t.oid = e.enumtypid") - for oid in cur.fetchall(): - enum_oid = oid[0] - psycopg2.extensions.register_type( - psycopg2.extensions.new_array_type( - (enum_oid,), 'ENUM_{}[]'.format(enum_oid), psycopg2.STRING)) + 
(enum_oid,), 'ENUM_{}[]'.format(enum_oid), psycopg2.STRING)) def do_sync(conn_config, catalog, default_replication_method, state, state_file=None): @@ -281,7 +281,7 @@ def do_sync(conn_config, catalog, default_replication_method, state, state_file= LOGGER.info("Selected streams: %s ", [s['tap_stream_id'] for s in streams]) if any_logical_streams(streams, default_replication_method): # Use of logical replication requires fetching an lsn - end_lsn = logical_replication.fetch_current_lsn(conn_config) + end_lsn = logical_replication.fetch_current_lsn() LOGGER.debug("end_lsn = %s ", end_lsn) else: end_lsn = None diff --git a/tap_postgres/db.py b/tap_postgres/db.py index c7711c42..9ee80998 100644 --- a/tap_postgres/db.py +++ b/tap_postgres/db.py @@ -11,6 +11,8 @@ from typing import List from dateutil.parser import parse +from tap_postgres.postgres import Postgres + LOGGER = singer.get_logger('tap_postgres') CURSOR_ITER_SIZE = 20000 @@ -38,26 +40,9 @@ def fully_qualified_table_name(schema, table): return '"{}"."{}"'.format(canonicalize_identifier(schema), canonicalize_identifier(table)) -def open_connection(conn_config, logical_replication=False): - cfg = { - 'application_name': 'pipelinewise', - 'host': conn_config['host'], - 'dbname': conn_config['dbname'], - 'user': conn_config['user'], - 'password': conn_config['password'], - 'port': conn_config['port'], - 'connect_timeout': 30 - } - - if conn_config.get('sslmode'): - cfg['sslmode'] = conn_config['sslmode'] - - if logical_replication: - cfg['connection_factory'] = psycopg2.extras.LogicalReplicationConnection +def open_connection(logical_replication=False): + return Postgres.get_instance().connect(logical_replication) - conn = psycopg2.connect(**cfg) - - return conn def prepare_columns_for_select_sql(c, md_map): column_name = ' "{}" '.format(canonicalize_identifier(c)) @@ -191,14 +176,14 @@ def selected_row_to_singer_message(stream, row, version, columns, time_extracted time_extracted=time_extracted) -def hstore_available(conn_info): - with open_connection(conn_info) as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur: - cur.execute(""" SELECT installed_version FROM pg_available_extensions WHERE name = 'hstore' """) - res = cur.fetchone() - if res and res[0]: - return True - return False +def hstore_available(): + conn = open_connection() + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur: + cur.execute(""" SELECT installed_version FROM pg_available_extensions WHERE name = 'hstore' """) + res = cur.fetchone() + if res and res[0]: + return True + return False def compute_tap_stream_id(schema_name, table_name): diff --git a/tap_postgres/postgres.py b/tap_postgres/postgres.py new file mode 100644 index 00000000..b859cdef --- /dev/null +++ b/tap_postgres/postgres.py @@ -0,0 +1,56 @@ +from singer import utils + +import psycopg2 +import psycopg2.extras + + +# pylint: disable=missing-class-docstring,missing-function-docstring +class Postgres: + __instance = None + + @staticmethod + def get_instance(): + if Postgres.__instance is None: + Postgres() + + return Postgres.__instance + + @staticmethod + def get_configuration(logical_replication): + args = utils.parse_args({}) + conn_config = args.config + + cfg = { + 'application_name': 'pipelinewise', + 'host': conn_config['host'], + 'dbname': conn_config['dbname'], + 'user': conn_config['user'], + 'password': conn_config['password'], + 'port': conn_config['port'], + 'connect_timeout': 30 + } + + if 
conn_config.get('sslmode'): + cfg['sslmode'] = conn_config['sslmode'] + + if logical_replication: + cfg['connection_factory'] = psycopg2.extras.LogicalReplicationConnection + + return cfg + + + def __init__(self): + if Postgres.__instance is not None: + raise Exception("This class is a singleton!") + + Postgres.__instance = self + self.connections = {"logical": None, "transactional": None} + + def connect(self, logical_replication=False): + connection_type = "logical" if logical_replication else "transactional" + + if not self.connections[connection_type] or self.connections[connection_type].closed: + config = Postgres.get_configuration(logical_replication) + self.connections[connection_type] = psycopg2.connect(**config) + + return self.connections[connection_type] diff --git a/tap_postgres/stream_utils.py b/tap_postgres/stream_utils.py index 84a4f46e..b6434c4a 100644 --- a/tap_postgres/stream_utils.py +++ b/tap_postgres/stream_utils.py @@ -66,33 +66,33 @@ def refresh_streams_schema(conn_config: Dict, streams: List[Dict]): LOGGER.debug('Current streams schemas %s', streams) # Run discovery to get the streams most up to date json schemas - with open_connection(conn_config) as conn: - new_discovery = { - stream['tap_stream_id']: stream - for stream in discover_db(conn, conn_config.get('filter_schemas'), [st['table_name'] for st in streams]) - } - - LOGGER.debug('New discovery schemas %s', new_discovery) - - # For every stream dictionary, update the schema and metadata from the new discovery - for idx, stream in enumerate(streams): - # update schema - streams[idx]['schema'] = copy.deepcopy(new_discovery[stream['tap_stream_id']]['schema']) - - # Update metadata - # - # 1st step: new discovery doesn't contain non-discoverable metadata: e.g replication method & key, selected - # so let's copy those from the original stream object - md_map = metadata.to_map(stream['metadata']) - meta = md_map.get(()) - - for idx_met, metadatum in enumerate(new_discovery[stream['tap_stream_id']]['metadata']): - if not metadatum['breadcrumb']: - meta.update(new_discovery[stream['tap_stream_id']]['metadata'][idx_met]['metadata']) - new_discovery[stream['tap_stream_id']]['metadata'][idx_met]['metadata'] = meta - - # 2nd step: now copy all the metadata from the updated new discovery to the original stream - streams[idx]['metadata'] = copy.deepcopy(new_discovery[stream['tap_stream_id']]['metadata']) + conn = open_connection() + new_discovery = { + stream['tap_stream_id']: stream + for stream in discover_db(conn, conn_config.get('filter_schemas'), [st['table_name'] for st in streams]) + } + + LOGGER.debug('New discovery schemas %s', new_discovery) + + # For every stream dictionary, update the schema and metadata from the new discovery + for idx, stream in enumerate(streams): + # update schema + streams[idx]['schema'] = copy.deepcopy(new_discovery[stream['tap_stream_id']]['schema']) + + # Update metadata + # + # 1st step: new discovery doesn't contain non-discoverable metadata: e.g replication method & key, selected + # so let's copy those from the original stream object + md_map = metadata.to_map(stream['metadata']) + meta = md_map.get(()) + + for idx_met, metadatum in enumerate(new_discovery[stream['tap_stream_id']]['metadata']): + if not metadatum['breadcrumb']: + meta.update(new_discovery[stream['tap_stream_id']]['metadata'][idx_met]['metadata']) + new_discovery[stream['tap_stream_id']]['metadata'][idx_met]['metadata'] = meta + + # 2nd step: now copy all the metadata from the updated new discovery to the original 
stream + streams[idx]['metadata'] = copy.deepcopy(new_discovery[stream['tap_stream_id']]['metadata']) LOGGER.debug('Updated streams schemas %s', streams) diff --git a/tap_postgres/sync_strategies/full_table.py b/tap_postgres/sync_strategies/full_table.py index 0708092f..a7ea7736 100644 --- a/tap_postgres/sync_strategies/full_table.py +++ b/tap_postgres/sync_strategies/full_table.py @@ -16,7 +16,7 @@ # pylint: disable=invalid-name,missing-function-docstring,too-many-locals,duplicate-code -def sync_view(conn_info, stream, state, desired_columns, md_map): +def sync_view(stream, state, desired_columns, md_map): time_extracted = utils.now() # before writing the table version to state, check if we had one to begin with @@ -41,30 +41,30 @@ def sync_view(conn_info, stream, state, desired_columns, md_map): singer.write_message(activate_version_message) with metrics.record_counter(None) as counter: - with post_db.open_connection(conn_info) as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur: - cur.itersize = post_db.CURSOR_ITER_SIZE - select_sql = 'SELECT {} FROM {}'.format(','.join(escaped_columns), - post_db.fully_qualified_table_name(schema_name, - stream['table_name'])) - - LOGGER.info("select %s with itersize %s", select_sql, cur.itersize) - cur.execute(select_sql) - - rows_saved = 0 - for rec in cur: - record_message = post_db.selected_row_to_singer_message(stream, - rec, - nascent_stream_version, - desired_columns, - time_extracted, - md_map) - singer.write_message(record_message) - rows_saved = rows_saved + 1 - if rows_saved % UPDATE_BOOKMARK_PERIOD == 0: - singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) - - counter.increment() + conn = post_db.open_connection() + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur: + cur.itersize = post_db.CURSOR_ITER_SIZE + select_sql = 'SELECT {} FROM {}'.format(','.join(escaped_columns), + post_db.fully_qualified_table_name(schema_name, + stream['table_name'])) + + LOGGER.info("select %s with itersize %s", select_sql, cur.itersize) + cur.execute(select_sql) + + rows_saved = 0 + for rec in cur: + record_message = post_db.selected_row_to_singer_message(stream, + rec, + nascent_stream_version, + desired_columns, + time_extracted, + md_map) + singer.write_message(record_message) + rows_saved = rows_saved + 1 + if rows_saved % UPDATE_BOOKMARK_PERIOD == 0: + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + counter.increment() # always send the activate version whether first run or subsequent singer.write_message(activate_version_message) @@ -73,7 +73,7 @@ def sync_view(conn_info, stream, state, desired_columns, md_map): # pylint: disable=too-many-statements,duplicate-code -def sync_table(conn_info, stream, state, desired_columns, md_map): +def sync_table(stream, state, desired_columns, md_map): time_extracted = utils.now() # before writing the table version to state, check if we had one to begin with @@ -103,63 +103,63 @@ def sync_table(conn_info, stream, state, desired_columns, md_map): if first_run: singer.write_message(activate_version_message) - hstore_available = post_db.hstore_available(conn_info) + hstore_available = post_db.hstore_available() with metrics.record_counter(None) as counter: - with post_db.open_connection(conn_info) as conn: - - # Client side character encoding defaults to the value in postgresql.conf under client_encoding. - # The server / db can also have its own configred encoding. 
- with conn.cursor() as cur: - cur.execute("show server_encoding") - LOGGER.info("Current Server Encoding: %s", cur.fetchone()[0]) - cur.execute("show client_encoding") - LOGGER.info("Current Client Encoding: %s", cur.fetchone()[0]) - - if hstore_available: - LOGGER.info("hstore is available") - psycopg2.extras.register_hstore(conn) + conn = post_db.open_connection() + + # Client side character encoding defaults to the value in postgresql.conf under client_encoding. + # The server / db can also have its own configred encoding. + with conn.cursor() as cur: + cur.execute("show server_encoding") + LOGGER.info("Current Server Encoding: %s", cur.fetchone()[0]) + cur.execute("show client_encoding") + LOGGER.info("Current Client Encoding: %s", cur.fetchone()[0]) + + if hstore_available: + LOGGER.info("hstore is available") + psycopg2.extras.register_hstore(conn) + else: + LOGGER.info("hstore is UNavailable") + + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur: + cur.itersize = post_db.CURSOR_ITER_SIZE + + fq_table_name = post_db.fully_qualified_table_name(schema_name, stream['table_name']) + xmin = singer.get_bookmark(state, stream['tap_stream_id'], 'xmin') + if xmin: + LOGGER.info("Resuming Full Table replication %s from xmin %s", nascent_stream_version, xmin) + select_sql = """SELECT {}, xmin::text::bigint + FROM {} where age(xmin::xid) <= age('{}'::xid) + ORDER BY xmin::text ASC""".format(','.join(escaped_columns), + fq_table_name, + xmin) else: - LOGGER.info("hstore is UNavailable") - - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='stitch_cursor') as cur: - cur.itersize = post_db.CURSOR_ITER_SIZE - - fq_table_name = post_db.fully_qualified_table_name(schema_name, stream['table_name']) - xmin = singer.get_bookmark(state, stream['tap_stream_id'], 'xmin') - if xmin: - LOGGER.info("Resuming Full Table replication %s from xmin %s", nascent_stream_version, xmin) - select_sql = """SELECT {}, xmin::text::bigint - FROM {} where age(xmin::xid) <= age('{}'::xid) - ORDER BY xmin::text ASC""".format(','.join(escaped_columns), - fq_table_name, - xmin) - else: - LOGGER.info("Beginning new Full Table replication %s", nascent_stream_version) - select_sql = """SELECT {}, xmin::text::bigint - FROM {} - ORDER BY xmin::text ASC""".format(','.join(escaped_columns), - fq_table_name) - - LOGGER.info("select %s with itersize %s", select_sql, cur.itersize) - cur.execute(select_sql) - - rows_saved = 0 - for rec in cur: - xmin = rec['xmin'] - rec = rec[:-1] - record_message = post_db.selected_row_to_singer_message(stream, - rec, - nascent_stream_version, - desired_columns, - time_extracted, - md_map) - singer.write_message(record_message) - state = singer.write_bookmark(state, stream['tap_stream_id'], 'xmin', xmin) - rows_saved = rows_saved + 1 - if rows_saved % UPDATE_BOOKMARK_PERIOD == 0: - singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) - - counter.increment() + LOGGER.info("Beginning new Full Table replication %s", nascent_stream_version) + select_sql = """SELECT {}, xmin::text::bigint + FROM {} + ORDER BY xmin::text ASC""".format(','.join(escaped_columns), + fq_table_name) + + LOGGER.info("select %s with itersize %s", select_sql, cur.itersize) + cur.execute(select_sql) + + rows_saved = 0 + for rec in cur: + xmin = rec['xmin'] + rec = rec[:-1] + record_message = post_db.selected_row_to_singer_message(stream, + rec, + nascent_stream_version, + desired_columns, + time_extracted, + md_map) + singer.write_message(record_message) + state = 
singer.write_bookmark(state, stream['tap_stream_id'], 'xmin', xmin) + rows_saved = rows_saved + 1 + if rows_saved % UPDATE_BOOKMARK_PERIOD == 0: + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + counter.increment() # once we have completed the full table replication, discard the xmin bookmark. # the xmin bookmark only comes into play when a full table replication is interrupted diff --git a/tap_postgres/sync_strategies/incremental.py b/tap_postgres/sync_strategies/incremental.py index 25c3dbc4..d5cd86f7 100644 --- a/tap_postgres/sync_strategies/incremental.py +++ b/tap_postgres/sync_strategies/incremental.py @@ -17,21 +17,21 @@ # pylint: disable=invalid-name,missing-function-docstring -def fetch_max_replication_key(conn_config, replication_key, schema_name, table_name): - with post_db.open_connection(conn_config, False) as conn: - with conn.cursor() as cur: - max_key_sql = """SELECT max({}) - FROM {}""".format(post_db.prepare_columns_sql(replication_key), - post_db.fully_qualified_table_name(schema_name, table_name)) - LOGGER.info("determine max replication key value: %s", max_key_sql) - cur.execute(max_key_sql) - max_key = cur.fetchone()[0] - LOGGER.info("max replication key value: %s", max_key) - return max_key +def fetch_max_replication_key(replication_key, schema_name, table_name): + conn = post_db.open_connection() + with conn.cursor() as cur: + max_key_sql = """SELECT max({}) + FROM {}""".format(post_db.prepare_columns_sql(replication_key), + post_db.fully_qualified_table_name(schema_name, table_name)) + LOGGER.info("determine max replication key value: %s", max_key_sql) + cur.execute(max_key_sql) + max_key = cur.fetchone()[0] + LOGGER.info("max replication key value: %s", max_key) + return max_key # pylint: disable=too-many-locals -def sync_table(conn_info, stream, state, desired_columns, md_map): +def sync_table(stream, state, desired_columns, md_map): time_extracted = utils.now() stream_version = singer.get_bookmark(state, stream['tap_stream_id'], 'version') @@ -59,75 +59,75 @@ def sync_table(conn_info, stream, state, desired_columns, md_map): replication_key_value = singer.get_bookmark(state, stream['tap_stream_id'], 'replication_key_value') replication_key_sql_datatype = md_map.get(('properties', replication_key)).get('sql-datatype') - hstore_available = post_db.hstore_available(conn_info) + hstore_available = post_db.hstore_available() with metrics.record_counter(None) as counter: - with post_db.open_connection(conn_info) as conn: - - # Client side character encoding defaults to the value in postgresql.conf under client_encoding. - # The server / db can also have its own configred encoding. - with conn.cursor() as cur: - cur.execute("show server_encoding") - LOGGER.info("Current Server Encoding: %s", cur.fetchone()[0]) - cur.execute("show client_encoding") - LOGGER.info("Current Client Encoding: %s", cur.fetchone()[0]) - - if hstore_available: - LOGGER.info("hstore is available") - psycopg2.extras.register_hstore(conn) + conn = post_db.open_connection() + + # Client side character encoding defaults to the value in postgresql.conf under client_encoding. + # The server / db can also have its own configred encoding. 
+ with conn.cursor() as cur: + cur.execute("show server_encoding") + LOGGER.info("Current Server Encoding: %s", cur.fetchone()[0]) + cur.execute("show client_encoding") + LOGGER.info("Current Client Encoding: %s", cur.fetchone()[0]) + + if hstore_available: + LOGGER.info("hstore is available") + psycopg2.extras.register_hstore(conn) + else: + LOGGER.info("hstore is UNavailable") + + with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='pipelinewise') as cur: + cur.itersize = post_db.CURSOR_ITER_SIZE + LOGGER.info("Beginning new incremental replication sync %s", stream_version) + if replication_key_value: + select_sql = """SELECT {} + FROM {} + WHERE {} >= '{}'::{} + ORDER BY {} ASC""".format(','.join(escaped_columns), + post_db.fully_qualified_table_name(schema_name, + stream['table_name']), + post_db.prepare_columns_sql(replication_key), + replication_key_value, + replication_key_sql_datatype, + post_db.prepare_columns_sql(replication_key)) else: - LOGGER.info("hstore is UNavailable") - - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='pipelinewise') as cur: - cur.itersize = post_db.CURSOR_ITER_SIZE - LOGGER.info("Beginning new incremental replication sync %s", stream_version) - if replication_key_value: - select_sql = """SELECT {} - FROM {} - WHERE {} >= '{}'::{} - ORDER BY {} ASC""".format(','.join(escaped_columns), - post_db.fully_qualified_table_name(schema_name, - stream['table_name']), - post_db.prepare_columns_sql(replication_key), - replication_key_value, - replication_key_sql_datatype, - post_db.prepare_columns_sql(replication_key)) - else: - #if not replication_key_value - select_sql = """SELECT {} - FROM {} - ORDER BY {} ASC""".format(','.join(escaped_columns), - post_db.fully_qualified_table_name(schema_name, - stream['table_name']), - post_db.prepare_columns_sql(replication_key)) - - LOGGER.info('select statement: %s with itersize %s', select_sql, cur.itersize) - cur.execute(select_sql) - - rows_saved = 0 - - for rec in cur: - record_message = post_db.selected_row_to_singer_message(stream, - rec, - stream_version, - desired_columns, - time_extracted, - md_map) - - singer.write_message(record_message) - rows_saved = rows_saved + 1 - - #Picking a replication_key with NULL values will result in it ALWAYS been synced which is not great - #event worse would be allowing the NULL value to enter into the state - if record_message.record[replication_key] is not None: - state = singer.write_bookmark(state, - stream['tap_stream_id'], - 'replication_key_value', - record_message.record[replication_key]) - - - if rows_saved % UPDATE_BOOKMARK_PERIOD == 0: - singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) - - counter.increment() + #if not replication_key_value + select_sql = """SELECT {} + FROM {} + ORDER BY {} ASC""".format(','.join(escaped_columns), + post_db.fully_qualified_table_name(schema_name, + stream['table_name']), + post_db.prepare_columns_sql(replication_key)) + + LOGGER.info('select statement: %s with itersize %s', select_sql, cur.itersize) + cur.execute(select_sql) + + rows_saved = 0 + + for rec in cur: + record_message = post_db.selected_row_to_singer_message(stream, + rec, + stream_version, + desired_columns, + time_extracted, + md_map) + + singer.write_message(record_message) + rows_saved = rows_saved + 1 + + #Picking a replication_key with NULL values will result in it ALWAYS been synced which is not great + #event worse would be allowing the NULL value to enter into the state + if record_message.record[replication_key] is 
not None: + state = singer.write_bookmark(state, + stream['tap_stream_id'], + 'replication_key_value', + record_message.record[replication_key]) + + + if rows_saved % UPDATE_BOOKMARK_PERIOD == 0: + singer.write_message(singer.StateMessage(value=copy.deepcopy(state))) + + counter.increment() return state diff --git a/tap_postgres/sync_strategies/logical_replication.py b/tap_postgres/sync_strategies/logical_replication.py index a1c70fe0..f17e3b7c 100644 --- a/tap_postgres/sync_strategies/logical_replication.py +++ b/tap_postgres/sync_strategies/logical_replication.py @@ -32,11 +32,11 @@ class UnsupportedPayloadKindError(Exception): # pylint: disable=invalid-name,missing-function-docstring,too-many-branches,too-many-statements,too-many-arguments -def get_pg_version(conn_info): - with post_db.open_connection(conn_info, False) as conn: - with conn.cursor() as cur: - cur.execute("SELECT setting::int AS version FROM pg_settings WHERE name='server_version_num'") - version = cur.fetchone()[0] +def get_pg_version(): + conn = post_db.open_connection() + with conn.cursor() as cur: + cur.execute("SELECT setting::int AS version FROM pg_settings WHERE name='server_version_num'") + version = cur.fetchone()[0] LOGGER.debug('Detected PostgreSQL version: %s', version) return version @@ -75,8 +75,8 @@ def int_to_lsn(lsni): # pylint: disable=chained-comparison -def fetch_current_lsn(conn_config): - version = get_pg_version(conn_config) +def fetch_current_lsn(): + version = get_pg_version() # Make sure PostgreSQL version is 9.4 or higher # Do not allow minor versions with PostgreSQL BUG #15114 if (version >= 110000) and (version < 110002): @@ -92,18 +92,18 @@ def fetch_current_lsn(conn_config): if version < 90400: raise Exception('Logical replication not supported before PostgreSQL 9.4') - with post_db.open_connection(conn_config, False) as conn: - with conn.cursor() as cur: - # Use version specific lsn command - if version >= 100000: - cur.execute("SELECT pg_current_wal_lsn() AS current_lsn") - elif version >= 90400: - cur.execute("SELECT pg_current_xlog_location() AS current_lsn") - else: - raise Exception('Logical replication not supported before PostgreSQL 9.4') + conn = post_db.open_connection() + with conn.cursor() as cur: + # Use version specific lsn command + if version >= 100000: + cur.execute("SELECT pg_current_wal_lsn() AS current_lsn") + elif version >= 90400: + cur.execute("SELECT pg_current_xlog_location() AS current_lsn") + else: + raise Exception('Logical replication not supported before PostgreSQL 9.4') - current_lsn = cur.fetchone()[0] - return lsn_to_int(current_lsn) + current_lsn = cur.fetchone()[0] + return lsn_to_int(current_lsn) def add_automatic_properties(stream, debug_lsn: bool = False): @@ -136,77 +136,77 @@ def create_hstore_elem_query(elem): return sql.SQL("SELECT hstore_to_array({})").format(sql.Literal(elem)) -def create_hstore_elem(conn_info, elem): - with post_db.open_connection(conn_info) as conn: - with conn.cursor() as cur: - query = create_hstore_elem_query(elem) - cur.execute(query) - res = cur.fetchone()[0] - hstore_elem = reduce(tuples_to_map, [res[i:i + 2] for i in range(0, len(res), 2)], {}) - return hstore_elem +def create_hstore_elem(elem): + conn = post_db.open_connection() + with conn.cursor() as cur: + query = create_hstore_elem_query(elem) + cur.execute(query) + res = cur.fetchone()[0] + hstore_elem = reduce(tuples_to_map, [res[i:i + 2] for i in range(0, len(res), 2)], {}) + return hstore_elem -def create_array_elem(elem, sql_datatype, conn_info): +def 
create_array_elem(elem, sql_datatype): if elem is None: return None - with post_db.open_connection(conn_info) as conn: - with conn.cursor() as cur: - if sql_datatype == 'bit[]': - cast_datatype = 'boolean[]' - elif sql_datatype == 'boolean[]': - cast_datatype = 'boolean[]' - elif sql_datatype == 'character varying[]': - cast_datatype = 'character varying[]' - elif sql_datatype == 'cidr[]': - cast_datatype = 'cidr[]' - elif sql_datatype == 'citext[]': - cast_datatype = 'text[]' - elif sql_datatype == 'date[]': - cast_datatype = 'text[]' - elif sql_datatype == 'double precision[]': - cast_datatype = 'double precision[]' - elif sql_datatype == 'hstore[]': - cast_datatype = 'text[]' - elif sql_datatype == 'integer[]': - cast_datatype = 'integer[]' - elif sql_datatype == 'inet[]': - cast_datatype = 'inet[]' - elif sql_datatype == 'json[]': - cast_datatype = 'text[]' - elif sql_datatype == 'jsonb[]': - cast_datatype = 'text[]' - elif sql_datatype == 'macaddr[]': - cast_datatype = 'macaddr[]' - elif sql_datatype == 'money[]': - cast_datatype = 'text[]' - elif sql_datatype == 'numeric[]': - cast_datatype = 'text[]' - elif sql_datatype == 'real[]': - cast_datatype = 'real[]' - elif sql_datatype == 'smallint[]': - cast_datatype = 'smallint[]' - elif sql_datatype == 'text[]': - cast_datatype = 'text[]' - elif sql_datatype in ('time without time zone[]', 'time with time zone[]'): - cast_datatype = 'text[]' - elif sql_datatype in ('timestamp with time zone[]', 'timestamp without time zone[]'): - cast_datatype = 'text[]' - elif sql_datatype == 'uuid[]': - cast_datatype = 'text[]' - - else: - # custom datatypes like enums - cast_datatype = 'text[]' - - sql_stmt = """SELECT $stitch_quote${}$stitch_quote$::{}""".format(elem, cast_datatype) - cur.execute(sql_stmt) - res = cur.fetchone()[0] - return res + conn = post_db.open_connection() + with conn.cursor() as cur: + if sql_datatype == 'bit[]': + cast_datatype = 'boolean[]' + elif sql_datatype == 'boolean[]': + cast_datatype = 'boolean[]' + elif sql_datatype == 'character varying[]': + cast_datatype = 'character varying[]' + elif sql_datatype == 'cidr[]': + cast_datatype = 'cidr[]' + elif sql_datatype == 'citext[]': + cast_datatype = 'text[]' + elif sql_datatype == 'date[]': + cast_datatype = 'text[]' + elif sql_datatype == 'double precision[]': + cast_datatype = 'double precision[]' + elif sql_datatype == 'hstore[]': + cast_datatype = 'text[]' + elif sql_datatype == 'integer[]': + cast_datatype = 'integer[]' + elif sql_datatype == 'inet[]': + cast_datatype = 'inet[]' + elif sql_datatype == 'json[]': + cast_datatype = 'text[]' + elif sql_datatype == 'jsonb[]': + cast_datatype = 'text[]' + elif sql_datatype == 'macaddr[]': + cast_datatype = 'macaddr[]' + elif sql_datatype == 'money[]': + cast_datatype = 'text[]' + elif sql_datatype == 'numeric[]': + cast_datatype = 'text[]' + elif sql_datatype == 'real[]': + cast_datatype = 'real[]' + elif sql_datatype == 'smallint[]': + cast_datatype = 'smallint[]' + elif sql_datatype == 'text[]': + cast_datatype = 'text[]' + elif sql_datatype in ('time without time zone[]', 'time with time zone[]'): + cast_datatype = 'text[]' + elif sql_datatype in ('timestamp with time zone[]', 'timestamp without time zone[]'): + cast_datatype = 'text[]' + elif sql_datatype == 'uuid[]': + cast_datatype = 'text[]' + + else: + # custom datatypes like enums + cast_datatype = 'text[]' + + sql_stmt = """SELECT $stitch_quote${}$stitch_quote$::{}""".format(elem, cast_datatype) + cur.execute(sql_stmt) + res = cur.fetchone()[0] + return res # 
pylint: disable=too-many-branches,too-many-nested-blocks,too-many-return-statements -def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): +def selected_value_to_singer_value_impl(elem, og_sql_datatype): sql_datatype = og_sql_datatype.replace('[]', '') if elem is None: @@ -314,7 +314,7 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info): if sql_datatype == 'boolean': return elem if sql_datatype == 'hstore': - return create_hstore_elem(conn_info, elem) + return create_hstore_elem(elem) if 'numeric' in sql_datatype: return decimal.Decimal(elem) if isinstance(elem, int): @@ -331,17 +331,17 @@ def selected_array_to_singer_value(elem, sql_datatype, conn_info): if isinstance(elem, list): return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), elem)) - return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info) + return selected_value_to_singer_value_impl(elem, sql_datatype) def selected_value_to_singer_value(elem, sql_datatype, conn_info): # are we dealing with an array? if sql_datatype.find('[]') > 0: - cleaned_elem = create_array_elem(elem, sql_datatype, conn_info) + cleaned_elem = create_array_elem(elem, sql_datatype) return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), (cleaned_elem or []))) - return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info) + return selected_value_to_singer_value_impl(elem, sql_datatype) def row_to_singer_message(stream, row, version, columns, time_extracted, md_map, conn_info): @@ -508,9 +508,9 @@ def locate_replication_slot_by_cur(cursor, dbname, tap_id=None): def locate_replication_slot(conn_info): - with post_db.open_connection(conn_info, False) as conn: - with conn.cursor() as cur: - return locate_replication_slot_by_cur(cur, conn_info['dbname'], conn_info['tap_id']) + conn = post_db.open_connection() + with conn.cursor() as cur: + return locate_replication_slot_by_cur(cur, conn_info['dbname'], conn_info['tap_id']) # pylint: disable=anomalous-backslash-in-string @@ -564,10 +564,10 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file): for s in logical_streams: sync_common.send_schema_message(s, ['lsn']) - version = get_pg_version(conn_info) + version = get_pg_version() # Create replication connection and cursor - conn = post_db.open_connection(conn_info, True) + conn = post_db.open_connection(True) cur = conn.cursor() # Set session wal_sender_timeout for PG12 and above
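Taken together, these changes replace the per-call open_connection(conn_config) pattern with a process-wide Postgres singleton that caches one transactional connection and one logical-replication connection, rebuilding the psycopg2 connection settings from the tap's --config file via singer.utils.parse_args. Below is a minimal, illustrative sketch (not part of the patch) of the behaviour callers get after this change; it assumes the tap was started with a valid --config file so Postgres.get_configuration() can resolve host, dbname, user, password and port, and that the database is reachable.

    from tap_postgres.postgres import Postgres

    pg = Postgres.get_instance()

    conn_a = pg.connect()                          # lazily opens the cached "transactional" connection
    conn_b = pg.connect()                          # same object is returned: conn_a is conn_b
    wal = pg.connect(logical_replication=True)     # separate "logical" connection created with the
                                                   # psycopg2.extras.LogicalReplicationConnection factory

    conn_a.close()
    conn_c = pg.connect()                          # closed connection is detected and a new one is opened

Because connect() reads the configuration from the command-line arguments instead of a conn_config dict, the conn_config parameter can be dropped from helpers such as open_connection(), hstore_available(), fetch_current_lsn() and the full-table/incremental sync_table functions, as the hunks above show.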