This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Add replica option #131

Closed
wants to merge 4 commits
Changes from 3 commits
34 changes: 28 additions & 6 deletions tap_postgres/__init__.py
@@ -116,8 +116,7 @@ def sync_method_for_streams(streams, state, default_replication_method):
             continue

         if replication_method == 'LOG_BASED' and stream_metadata.get((), {}).get('is-view'):
-            raise Exception(f'Logical Replication is NOT supported for views. ' \
-                            f'Please change the replication method for {stream["tap_stream_id"]}')
+            continue
Contributor

Why is the behaviour here changed? This exception is useful, isn't it? Rather than failing silently.

Author

Yeah, sorry, this is a specific hack for HackerOne; it shouldn't be part of this PR! Same for the changes related to TOAST'ed Postgres values.


         if replication_method == 'FULL_TABLE':
             lookup[stream['tap_stream_id']] = 'full'
@@ -194,7 +193,7 @@ def sync_traditional_stream(conn_config, stream, state, sync_method, end_lsn):
     return state


-def sync_logical_streams(conn_config, logical_streams, state, end_lsn, state_file):
+def sync_logical_streams(conn_config, logical_streams, traditional_streams, state, end_lsn, state_file):
     """
     Sync streams that use LOG_BASED method
     """
@@ -212,10 +211,20 @@ def sync_logical_streams(conn_config, logical_streams, state, end_lsn, state_fil
             selected_streams.add("{}".format(stream['tap_stream_id']))

         new_state = dict(currently_syncing=state['currently_syncing'], bookmarks={})
+        traditional_stream_ids = [s['tap_stream_id'] for s in traditional_streams]

         for stream, bookmark in state['bookmarks'].items():
-            if bookmark == {} or bookmark['last_replication_method'] != 'LOG_BASED' or stream in selected_streams:
+            if (
+                bookmark == {}
+                or bookmark['last_replication_method'] != 'LOG_BASED'
+                or stream in selected_streams
+                # The first time a LOG_BASED stream runs it needs to do an
+                # initial full table sync, and so will be treated as a
+                # traditional stream.
+                or (stream in traditional_stream_ids and bookmark['last_replication_method'] == 'LOG_BASED')
+            ):
                 new_state['bookmarks'][stream] = bookmark

         state = new_state

         state = logical_replication.sync_tables(conn_config, logical_streams, state, end_lsn, state_file)
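To make the new bookmark filter concrete, here is an illustrative example (not from the PR; stream names and bookmark values are invented) of which bookmarks would be carried into new_state:

    # Illustrative only: how the condition above treats a few bookmarks.
    state['bookmarks'] = {
        'db-public-orders': {'last_replication_method': 'FULL_TABLE'},  # kept: not LOG_BASED
        'db-public-users': {'last_replication_method': 'LOG_BASED'},    # kept while still selected for logical
                                                                        # replication, or while it sits in
                                                                        # traditional_streams awaiting its
                                                                        # initial full-table sync
        'db-public-old': {'last_replication_method': 'LOG_BASED'},      # dropped once de-selected
        'db-public-new': {},                                            # kept: empty bookmark
    }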
@@ -319,7 +328,7 @@ def do_sync(conn_config, catalog, default_replication_method, state, state_file=
         for dbname, streams in itertools.groupby(logical_streams,
                                                   lambda s: metadata.to_map(s['metadata']).get(()).get('database-name')):
             conn_config['dbname'] = dbname
-            state = sync_logical_streams(conn_config, list(streams), state, end_lsn, state_file)
+            state = sync_logical_streams(conn_config, list(streams), traditional_streams, state, end_lsn, state_file)
     return state


@@ -405,9 +414,22 @@ def main_impl():
         'debug_lsn': args.config.get('debug_lsn') == 'true',
         'max_run_seconds': args.config.get('max_run_seconds', 43200),
         'break_at_end_lsn': args.config.get('break_at_end_lsn', True),
-        'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0))
+        'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0)),
+        'use_replica': args.config.get('use_replica', False),
     }

+    if conn_config['use_replica']:
+        replica_config = {
+            # Required replica config keys
+            'replica_host': args.config['replica_host'],
+            'replica_user': args.config['replica_user'],
+            'replica_password': args.config['replica_password'],
Comment on lines +425 to +426
Author

Yeah I think that would be a great default!

+            'replica_port': args.config['replica_port'],
+            'replica_dbname': args.config['replica_dbname'],
+        }
+
+        conn_config = { **conn_config, **replica_config }
+
     if args.config.get('ssl') == 'true':
         conn_config['sslmode'] = 'require'

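For context, a hypothetical excerpt of the tap configuration that exercises the new option is shown below; only use_replica and the five replica_* keys come from this PR, and all values are placeholders:

    # Hypothetical config excerpt, expressed as the dict main_impl() would read from args.config.
    config = {
        'host': 'primary.example.com',          # primary; still used for logical replication
        'port': 5432,
        'user': 'tap_user',
        'password': '******',
        'dbname': 'app_db',
        'use_replica': True,                    # route regular queries to the replica
        'replica_host': 'replica.example.com',
        'replica_port': 5432,
        'replica_user': 'tap_user',
        'replica_password': '******',
        'replica_dbname': 'app_db',
    }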
25 changes: 19 additions & 6 deletions tap_postgres/db.py
@@ -38,14 +38,27 @@ def fully_qualified_table_name(schema, table):
     return '"{}"."{}"'.format(canonicalize_identifier(schema), canonicalize_identifier(table))


-def open_connection(conn_config, logical_replication=False):
+def open_connection(conn_config, logical_replication=False, primary_connection=False):
Author

Yeah I think that's a good idea!

+    if not primary_connection and conn_config['use_replica']:
+        host_key = "replica_host"
+        dbname_key = "replica_dbname"
+        user_key = "replica_user"
+        password_key = "replica_password"
+        port_key = "replica_port"
+    else:
+        host_key = "host"
+        dbname_key = "dbname"
+        user_key = "user"
+        password_key = "password"
+        port_key = "port"
+
     cfg = {
         'application_name': 'pipelinewise',
-        'host': conn_config['host'],
-        'dbname': conn_config['dbname'],
-        'user': conn_config['user'],
-        'password': conn_config['password'],
-        'port': conn_config['port'],
+        'host': conn_config[host_key],
+        'dbname': conn_config[dbname_key],
+        'user': conn_config[user_key],
+        'password': conn_config[password_key],
+        'port': conn_config[port_key],
         'connect_timeout': 30
     }

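A brief usage sketch (not part of the diff) of how the new primary_connection flag interacts with use_replica, assuming conn_config already carries the merged replica keys from __init__.py:

    # Sketch: with use_replica set, ordinary connections land on the replica,
    # while callers that pass primary_connection=True stay on the primary.
    read_conn = post_db.open_connection(conn_config)               # replica host and credentials
    meta_conn = post_db.open_connection(conn_config, False, True)  # primary, e.g. for replication metadata
    wal_conn = post_db.open_connection(conn_config, True, True)    # logical replication connection, primary only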
12 changes: 6 additions & 6 deletions tap_postgres/sync_strategies/logical_replication.py
@@ -33,7 +33,7 @@ class UnsupportedPayloadKindError(Exception):

 # pylint: disable=invalid-name,missing-function-docstring,too-many-branches,too-many-statements,too-many-arguments
 def get_pg_version(conn_info):
-    with post_db.open_connection(conn_info, False) as conn:
+    with post_db.open_connection(conn_info, False, True) as conn:
         with conn.cursor() as cur:
             cur.execute("SELECT setting::int AS version FROM pg_settings WHERE name='server_version_num'")
             version = cur.fetchone()[0]
@@ -92,7 +92,7 @@ def fetch_current_lsn(conn_config):
     if version < 90400:
         raise Exception('Logical replication not supported before PostgreSQL 9.4')

-    with post_db.open_connection(conn_config, False) as conn:
+    with post_db.open_connection(conn_config, False, True) as conn:
         with conn.cursor() as cur:
             # Use version specific lsn command
             if version >= 100000:
@@ -137,7 +137,7 @@ def create_hstore_elem_query(elem):


 def create_hstore_elem(conn_info, elem):
-    with post_db.open_connection(conn_info) as conn:
+    with post_db.open_connection(conn_info, False, True) as conn:
         with conn.cursor() as cur:
             query = create_hstore_elem_query(elem)
             cur.execute(query)
@@ -150,7 +150,7 @@ def create_array_elem(elem, sql_datatype, conn_info):
     if elem is None:
         return None

-    with post_db.open_connection(conn_info) as conn:
+    with post_db.open_connection(conn_info, False, True) as conn:
         with conn.cursor() as cur:
             if sql_datatype == 'bit[]':
                 cast_datatype = 'boolean[]'
@@ -516,7 +516,7 @@ def locate_replication_slot_by_cur(cursor, dbname, tap_id=None):


 def locate_replication_slot(conn_info):
-    with post_db.open_connection(conn_info, False) as conn:
+    with post_db.open_connection(conn_info, False, True) as conn:
         with conn.cursor() as cur:
             return locate_replication_slot_by_cur(cur, conn_info['dbname'], conn_info['tap_id'])

@@ -575,7 +575,7 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file):
     version = get_pg_version(conn_info)

     # Create replication connection and cursor
-    conn = post_db.open_connection(conn_info, True)
+    conn = post_db.open_connection(conn_info, True, True)
     cur = conn.cursor()

     # Set session wal_sender_timeout for PG12 and above
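For readers skimming the repeated positional arguments above, the calls map onto the new open_connection signature like this (an equivalent keyword spelling, assuming no other signature changes):

    # Regular helpers: ordinary connection, pinned to the primary.
    post_db.open_connection(conn_info, logical_replication=False, primary_connection=True)

    # sync_tables: logical replication connection, also pinned to the primary.
    post_db.open_connection(conn_info, logical_replication=True, primary_connection=True)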