Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Add replica option #131

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion tap_postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,9 +415,22 @@ def main_impl():
'debug_lsn': args.config.get('debug_lsn') == 'true',
'max_run_seconds': args.config.get('max_run_seconds', 43200),
'break_at_end_lsn': args.config.get('break_at_end_lsn', True),
'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0))
'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0)),
'use_replica': args.config.get('use_replica', False),
}

if conn_config['use_replica']:
replica_config = {
# Required replica config keys
'replica_host': args.config['replica_host'],
'replica_user': args.config['replica_user'],
'replica_password': args.config['replica_password'],
Comment on lines +425 to +426
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think that would be a great default!

'replica_port': args.config['replica_port'],
'replica_dbname': args.config['replica_dbname'],
}

conn_config = { **conn_config, **replica_config }

if args.config.get('ssl') == 'true':
conn_config['sslmode'] = 'require'

Expand Down
25 changes: 19 additions & 6 deletions tap_postgres/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,27 @@ def fully_qualified_table_name(schema, table):
return '"{}"."{}"'.format(canonicalize_identifier(schema), canonicalize_identifier(table))


def open_connection(conn_config, logical_replication=False):
def open_connection(conn_config, logical_replication=False, primary_connection=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think that's a good idea!

if not primary_connection and conn_config['use_replica']:
host_key = "replica_host"
dbname_key = "replica_dbname"
user_key = "replica_user"
password_key = "replica_password"
port_key = "replica_port"
else:
host_key = "host"
dbname_key = "dbname"
user_key = "user"
password_key = "password"
port_key = "port"

cfg = {
'application_name': 'pipelinewise',
'host': conn_config['host'],
'dbname': conn_config['dbname'],
'user': conn_config['user'],
'password': conn_config['password'],
'port': conn_config['port'],
'host': conn_config[host_key],
'dbname': conn_config[dbname_key],
'user': conn_config[user_key],
'password': conn_config[password_key],
'port': conn_config[port_key],
'connect_timeout': 30
}

Expand Down
12 changes: 6 additions & 6 deletions tap_postgres/sync_strategies/logical_replication.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class UnsupportedPayloadKindError(Exception):

# pylint: disable=invalid-name,missing-function-docstring,too-many-branches,too-many-statements,too-many-arguments
def get_pg_version(conn_info):
with post_db.open_connection(conn_info, False) as conn:
with post_db.open_connection(conn_info, False, True) as conn:
with conn.cursor() as cur:
cur.execute("SELECT setting::int AS version FROM pg_settings WHERE name='server_version_num'")
version = cur.fetchone()[0]
Expand Down Expand Up @@ -92,7 +92,7 @@ def fetch_current_lsn(conn_config):
if version < 90400:
raise Exception('Logical replication not supported before PostgreSQL 9.4')

with post_db.open_connection(conn_config, False) as conn:
with post_db.open_connection(conn_config, False, True) as conn:
with conn.cursor() as cur:
# Use version specific lsn command
if version >= 100000:
Expand Down Expand Up @@ -137,7 +137,7 @@ def create_hstore_elem_query(elem):


def create_hstore_elem(conn_info, elem):
with post_db.open_connection(conn_info) as conn:
with post_db.open_connection(conn_info, False, True) as conn:
with conn.cursor() as cur:
query = create_hstore_elem_query(elem)
cur.execute(query)
Expand All @@ -150,7 +150,7 @@ def create_array_elem(elem, sql_datatype, conn_info):
if elem is None:
return None

with post_db.open_connection(conn_info) as conn:
with post_db.open_connection(conn_info, False, True) as conn:
with conn.cursor() as cur:
if sql_datatype == 'bit[]':
cast_datatype = 'boolean[]'
Expand Down Expand Up @@ -516,7 +516,7 @@ def locate_replication_slot_by_cur(cursor, dbname, tap_id=None):


def locate_replication_slot(conn_info):
with post_db.open_connection(conn_info, False) as conn:
with post_db.open_connection(conn_info, False, True) as conn:
with conn.cursor() as cur:
return locate_replication_slot_by_cur(cur, conn_info['dbname'], conn_info['tap_id'])

Expand Down Expand Up @@ -575,7 +575,7 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file):
version = get_pg_version(conn_info)

# Create replication connection and cursor
conn = post_db.open_connection(conn_info, True)
conn = post_db.open_connection(conn_info, True, True)
cur = conn.cursor()

# Set session wal_sender_timeout for PG12 and above
Expand Down