From 4a88945fd47bff30ee3536099161a00c71f99d7b Mon Sep 17 00:00:00 2001
From: Peter Kosztolanyi
Date: Mon, 20 Apr 2020 15:11:15 +0100
Subject: [PATCH] [AP-655] Handle LOG_BASED replication of multiple databases
 by multiple taps (#50)

---
 tap_postgres/__init__.py                    | 64 +++++++------------
 .../sync_strategies/logical_replication.py  | 22 ++++++-
 tests/test_logical_replication.py           | 23 +++++++
 3 files changed, 66 insertions(+), 43 deletions(-)

diff --git a/tap_postgres/__init__.py b/tap_postgres/__init__.py
index 30a5d902..2b32a950 100644
--- a/tap_postgres/__init__.py
+++ b/tap_postgres/__init__.py
@@ -416,43 +416,19 @@ def attempt_connection_to_db(conn_config, dbname):
 
 
 def dump_catalog(all_streams):
-    json.dump({'streams' : all_streams}, sys.stdout, indent=2)
+    json.dump({'streams': all_streams}, sys.stdout, indent=2)
 
 
 def do_discovery(conn_config):
-    all_streams = []
     with post_db.open_connection(conn_config) as conn:
-        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='pipelinewise') as cur:
-            cur.itersize = post_db.cursor_iter_size
-            sql = """SELECT datname
-            FROM pg_database
-            WHERE datistemplate = false
-            AND datname != 'rdsadmin'"""
-
-            if conn_config.get('filter_dbs'):
-                sql = post_db.filter_dbs_sql_clause(sql, conn_config['filter_dbs'])
-
-            LOGGER.info("Running DB discovery: %s with itersize %s", sql, cur.itersize)
-            cur.execute(sql)
-            found_dbs = (row[0] for row in cur.fetchall())
-
-    filter_dbs = filter(lambda dbname: attempt_connection_to_db(conn_config, dbname), found_dbs)
-
-    for db_row in filter_dbs:
-        dbname = db_row
-        LOGGER.info("Discovering db %s", dbname)
-        conn_config['dbname'] = dbname
-        with post_db.open_connection(conn_config) as conn:
-            db_streams = discover_db(conn, conn_config.get('filter_schemas'))
-            all_streams = all_streams + db_streams
-
-
+        LOGGER.info("Discovering db %s", conn_config['dbname'])
+        streams = discover_db(conn, conn_config.get('filter_schemas'))
 
-    if len(all_streams) == 0:
+    if len(streams) == 0:
         raise RuntimeError('0 tables were discovered across the entire cluster')
 
-    dump_catalog(all_streams)
-    return all_streams
+    dump_catalog(streams)
+    return streams
 
 
 def is_selected_via_metadata(stream):
@@ -781,18 +757,22 @@ def parse_args(required_config_keys):
 
 def main_impl():
     args = parse_args(REQUIRED_CONFIG_KEYS)
-    conn_config = {'host' : args.config['host'],
-                   'user' : args.config['user'],
-                   'password' : args.config['password'],
-                   'port' : args.config['port'],
-                   'dbname' : args.config['dbname'],
-                   'filter_dbs' : args.config.get('filter_dbs'),
-                   'filter_schemas' : args.config.get('filter_schemas'),
-                   'debug_lsn' : args.config.get('debug_lsn') == 'true',
-                   'max_run_seconds' : args.config.get('max_run_seconds', 43200),
-                   'break_at_end_lsn' : args.config.get('break_at_end_lsn', True),
-                   'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0))
-                   }
+    conn_config = {
+        # Required config keys
+        'host': args.config['host'],
+        'user': args.config['user'],
+        'password': args.config['password'],
+        'port': args.config['port'],
+        'dbname': args.config['dbname'],
+
+        # Optional config keys
+        'tap_id': args.config.get('tap_id'),
+        'filter_schemas': args.config.get('filter_schemas'),
+        'debug_lsn': args.config.get('debug_lsn') == 'true',
+        'max_run_seconds': args.config.get('max_run_seconds', 43200),
+        'break_at_end_lsn': args.config.get('break_at_end_lsn', True),
+        'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0))
+    }
 
     if args.config.get('ssl') == 'true':
         conn_config['sslmode'] = 'require'
diff --git a/tap_postgres/sync_strategies/logical_replication.py b/tap_postgres/sync_strategies/logical_replication.py
index 3e1a0ecb..12b550e3 100644
--- a/tap_postgres/sync_strategies/logical_replication.py
+++ b/tap_postgres/sync_strategies/logical_replication.py
@@ -388,10 +388,30 @@ def consume_message(streams, state, msg, time_extracted, conn_info, end_lsn):
     return state
 
 
+def generate_replication_slot_name(dbname, tap_id=None, prefix='pipelinewise'):
+    """Generate a well-formatted, lowercased replication slot name.
+
+    :param str dbname: Database name that will be part of the replication slot name
+    :param str tap_id: Optional. If provided then it will be appended to the end of the slot name
+    :param str prefix: Optional. Defaults to 'pipelinewise'
+    :return: well formatted lowercased replication slot name
+    :rtype: str
+    """
+    # Add tap_id to the end of the slot name if provided
+    if tap_id:
+        tap_id = f'_{tap_id}'
+    # Convert None to empty string
+    else:
+        tap_id = ''
+    return f'{prefix}_{dbname}{tap_id}'.lower()
+
+
 def locate_replication_slot(conn_info):
     with post_db.open_connection(conn_info, False) as conn:
         with conn.cursor() as cur:
-            db_specific_slot = "pipelinewise_{}".format(conn_info['dbname'].lower())
+            db_specific_slot = generate_replication_slot_name(dbname=conn_info['dbname'],
+                                                              tap_id=conn_info['tap_id'])
+
             cur.execute("SELECT * FROM pg_replication_slots WHERE slot_name = %s AND plugin = %s", (db_specific_slot, 'wal2json'))
             if len(cur.fetchall()) == 1:
diff --git a/tests/test_logical_replication.py b/tests/test_logical_replication.py
index dd0300cf..a2163a84 100644
--- a/tests/test_logical_replication.py
+++ b/tests/test_logical_replication.py
@@ -39,3 +39,26 @@ def test_streams_to_wal2json_tables(self):
                          'Case\\ Sensitive\\ Schema\\ With\\ Space.Case\\ Sensitive\\ Table\\ With\\ Space,'
                          'public.table_with_comma_\\,,'
                          "public.table_with_quote_\\'")
+
+    def test_generate_replication_slot_name(self):
+        """Validate if the replication slot name generated correctly"""
+        # Provide only database name
+        self.assertEquals(logical_replication.generate_replication_slot_name('some_db'),
+                          'pipelinewise_some_db')
+
+        # Provide database name and tap_id
+        self.assertEquals(logical_replication.generate_replication_slot_name('some_db',
+                                                                             'some_tap'),
+                          'pipelinewise_some_db_some_tap')
+
+        # Provide database name, tap_id and prefix
+        self.assertEquals(logical_replication.generate_replication_slot_name('some_db',
+                                                                             'some_tap',
+                                                                             prefix='custom_prefix'),
+                          'custom_prefix_some_db_some_tap')
+
+        # Replication slot name should be lowercase
+        self.assertEquals(logical_replication.generate_replication_slot_name('SoMe_DB',
+                                                                             'SoMe_TaP'),
+                          'pipelinewise_some_db_some_tap')