Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Commit

Permalink
[AP-655] Handle LOG_BASED replication of multiple databases by multiple …
Browse files Browse the repository at this point in the history
…taps (#50)
  • Loading branch information
koszti authored Apr 20, 2020
1 parent 5810719 commit 4a88945
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 43 deletions.
64 changes: 22 additions & 42 deletions tap_postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,43 +416,19 @@ def attempt_connection_to_db(conn_config, dbname):


def dump_catalog(all_streams):
    """Write the discovered streams catalog to STDOUT as indented JSON.

    :param list all_streams: discovered stream dictionaries to serialise
    :return: None; the catalog is emitted on sys.stdout
    """
    # Singer taps emit the catalog on STDOUT so downstream tooling can consume it
    json.dump({'streams': all_streams}, sys.stdout, indent=2)


def do_discovery(conn_config):
all_streams = []

with post_db.open_connection(conn_config) as conn:
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor, name='pipelinewise') as cur:
cur.itersize = post_db.cursor_iter_size
sql = """SELECT datname
FROM pg_database
WHERE datistemplate = false
AND datname != 'rdsadmin'"""

if conn_config.get('filter_dbs'):
sql = post_db.filter_dbs_sql_clause(sql, conn_config['filter_dbs'])

LOGGER.info("Running DB discovery: %s with itersize %s", sql, cur.itersize)
cur.execute(sql)
found_dbs = (row[0] for row in cur.fetchall())

filter_dbs = filter(lambda dbname: attempt_connection_to_db(conn_config, dbname), found_dbs)

for db_row in filter_dbs:
dbname = db_row
LOGGER.info("Discovering db %s", dbname)
conn_config['dbname'] = dbname
with post_db.open_connection(conn_config) as conn:
db_streams = discover_db(conn, conn_config.get('filter_schemas'))
all_streams = all_streams + db_streams

LOGGER.info("Discovering db %s", conn_config['dbname'])
streams = discover_db(conn, conn_config.get('filter_schemas'))

if len(all_streams) == 0:
if len(streams) == 0:
raise RuntimeError('0 tables were discovered across the entire cluster')

dump_catalog(all_streams)
return all_streams
dump_catalog(streams)
return streams


def is_selected_via_metadata(stream):
Expand Down Expand Up @@ -781,18 +757,22 @@ def parse_args(required_config_keys):

def main_impl():
args = parse_args(REQUIRED_CONFIG_KEYS)
conn_config = {'host' : args.config['host'],
'user' : args.config['user'],
'password' : args.config['password'],
'port' : args.config['port'],
'dbname' : args.config['dbname'],
'filter_dbs' : args.config.get('filter_dbs'),
'filter_schemas' : args.config.get('filter_schemas'),
'debug_lsn' : args.config.get('debug_lsn') == 'true',
'max_run_seconds' : args.config.get('max_run_seconds', 43200),
'break_at_end_lsn' : args.config.get('break_at_end_lsn', True),
'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0))
}
conn_config = {
# Required config keys
'host': args.config['host'],
'user': args.config['user'],
'password': args.config['password'],
'port': args.config['port'],
'dbname': args.config['dbname'],

# Optional config keys
'tap_id': args.config.get('tap_id'),
'filter_schemas': args.config.get('filter_schemas'),
'debug_lsn': args.config.get('debug_lsn') == 'true',
'max_run_seconds': args.config.get('max_run_seconds', 43200),
'break_at_end_lsn': args.config.get('break_at_end_lsn', True),
'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0))
}

if args.config.get('ssl') == 'true':
conn_config['sslmode'] = 'require'
Expand Down
22 changes: 21 additions & 1 deletion tap_postgres/sync_strategies/logical_replication.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,30 @@ def consume_message(streams, state, msg, time_extracted, conn_info, end_lsn):
return state


def generate_replication_slot_name(dbname, tap_id=None, prefix='pipelinewise'):
    """Build a well formatted, lowercased replication slot name.

    :param str dbname: Database name that will be part of the replication slot name
    :param str tap_id: Optional. If provided then it will be appended to the end of the slot name
    :param str prefix: Optional. Defaults to 'pipelinewise'
    :return: well formatted lowercased replication slot name
    :rtype: str
    """
    # A falsy tap_id (None or empty) contributes nothing to the slot name
    suffix = f'_{tap_id}' if tap_id else ''
    return f'{prefix}_{dbname}{suffix}'.lower()


def locate_replication_slot(conn_info):
with post_db.open_connection(conn_info, False) as conn:
with conn.cursor() as cur:
db_specific_slot = "pipelinewise_{}".format(conn_info['dbname'].lower())
db_specific_slot = generate_replication_slot_name(dbname=conn_info['dbname'],
tap_id=conn_info['tap_id'])

cur.execute("SELECT * FROM pg_replication_slots WHERE slot_name = %s AND plugin = %s",
(db_specific_slot, 'wal2json'))
if len(cur.fetchall()) == 1:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_logical_replication.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,26 @@ def test_streams_to_wal2json_tables(self):
'Case\\ Sensitive\\ Schema\\ With\\ Space.Case\\ Sensitive\\ Table\\ With\\ Space,'
'public.table_with_comma_\\,,'
"public.table_with_quote_\\'")

def test_generate_replication_slot_name(self):
    """Validate that the replication slot name is generated correctly"""
    # NOTE: assertEquals is a deprecated alias removed in Python 3.12;
    # use the canonical assertEqual instead.
    # Provide only database name
    self.assertEqual(logical_replication.generate_replication_slot_name('some_db'),
                     'pipelinewise_some_db')

    # Provide database name and tap_id
    self.assertEqual(logical_replication.generate_replication_slot_name('some_db',
                                                                        'some_tap'),
                     'pipelinewise_some_db_some_tap')

    # Provide database name, tap_id and prefix
    self.assertEqual(logical_replication.generate_replication_slot_name('some_db',
                                                                        'some_tap',
                                                                        prefix='custom_prefix'),
                     'custom_prefix_some_db_some_tap')

    # Replication slot name should be lowercase
    self.assertEqual(logical_replication.generate_replication_slot_name('SoMe_DB',
                                                                        'SoMe_TaP'),
                     'pipelinewise_some_db_some_tap')

0 comments on commit 4a88945

Please sign in to comment.