Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Perform logical replication after initial sync #144

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
4 changes: 4 additions & 0 deletions tap_postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def sync_method_for_streams(streams, state, default_replication_method):
# finishing previously interrupted full-table (first stage of logical replication)
lookup[stream['tap_stream_id']] = 'logical_initial_interrupted'
traditional_steams.append(stream)
# do any required logical replication after inital sync is complete
logical_streams.append(stream)

# inconsistent state
elif get_bookmark(state, stream['tap_stream_id'], 'xmin') and \
Expand All @@ -142,6 +144,8 @@ def sync_method_for_streams(streams, state, default_replication_method):
# initial full-table phase of logical replication
lookup[stream['tap_stream_id']] = 'logical_initial'
traditional_steams.append(stream)
# do any required logical replication after inital sync is complete
logical_streams.append(stream)

else: # no xmin but we have an lsn
# initial stage of logical replication(full-table) has been completed. moving onto pure logical replication
Expand Down
47 changes: 42 additions & 5 deletions tests/test_full_table_interruption.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest
import unittest.mock
import tap_postgres
from tap_postgres.sync_strategies import logical_replication
import tap_postgres.sync_strategies.full_table as full_table
import tap_postgres.sync_strategies.common as pg_common
import singer
Expand Down Expand Up @@ -45,9 +47,31 @@ def do_not_dump_catalog(catalog):
tap_postgres.dump_catalog = do_not_dump_catalog
full_table.UPDATE_BOOKMARK_PERIOD = 1


@unittest.mock.patch('tap_postgres.sync_logical_streams', wraps=tap_postgres.sync_logical_streams)
class LogicalInterruption(unittest.TestCase):
maxDiff = None

@classmethod
def setUpClass(cls):
conn_config = get_test_connection_config()
slot_name = logical_replication.generate_replication_slot_name(
dbname=conn_config['dbname'], tap_id=conn_config['tap_id']
)
with get_test_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT * FROM pg_create_logical_replication_slot('{slot_name}', 'wal2json')")

@classmethod
def tearDownClass(cls):
conn_config = get_test_connection_config()
slot_name = logical_replication.generate_replication_slot_name(
dbname=conn_config['dbname'], tap_id=conn_config['tap_id']
)
with get_test_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT * FROM pg_drop_replication_slot('{slot_name}')")

def setUp(self):
table_spec_1 = {"columns": [{"name": "id", "type" : "serial", "primary_key" : True},
{"name" : 'name', "type": "character varying"},
Expand All @@ -62,12 +86,13 @@ def setUp(self):
global CAUGHT_MESSAGES
CAUGHT_MESSAGES.clear()

def test_catalog(self):
def test_catalog(self, mock_sync_logical_streams):
singer.write_message = singer_write_message_no_cow
pg_common.write_schema_message = singer_write_message_ok

conn_config = get_test_connection_config()
streams = tap_postgres.do_discovery(conn_config)

cow_stream = [s for s in streams if s['table_name'] == 'COW'][0]
self.assertIsNotNone(cow_stream)
cow_stream = select_all_of_stream(cow_stream)
Expand All @@ -90,7 +115,7 @@ def test_catalog(self):
insert_record(cur, 'COW', cow_rec)

conn.close()

blew_up_on_cow = False
state = {}
#the initial phase of cows logical replication will be a full table.
Expand All @@ -102,6 +127,8 @@ def test_catalog(self):

self.assertTrue(blew_up_on_cow)

mock_sync_logical_streams.assert_not_called()

self.assertEqual(7, len(CAUGHT_MESSAGES))

self.assertEqual(CAUGHT_MESSAGES[0]['type'], 'SCHEMA')
Expand Down Expand Up @@ -151,7 +178,9 @@ def test_catalog(self):
CAUGHT_MESSAGES.clear()
tap_postgres.do_sync(get_test_connection_config(), {'streams' : streams}, None, old_state)

self.assertEqual(8, len(CAUGHT_MESSAGES))
mock_sync_logical_streams.assert_called_once()

self.assertEqual(10, len(CAUGHT_MESSAGES))

self.assertEqual(CAUGHT_MESSAGES[0]['type'], 'SCHEMA')

Expand Down Expand Up @@ -200,8 +229,16 @@ def test_catalog(self):
self.assertEqual(CAUGHT_MESSAGES[7].value['bookmarks']['public-COW'].get('lsn'), end_lsn)
self.assertEqual(CAUGHT_MESSAGES[7].value['bookmarks']['public-COW'].get('version'), new_version)

assert CAUGHT_MESSAGES[8]['type'] == 'SCHEMA'

assert isinstance(CAUGHT_MESSAGES[9], singer.messages.StateMessage)
assert CAUGHT_MESSAGES[9].value['bookmarks']['public-COW'].get('xmin') is None
assert CAUGHT_MESSAGES[9].value['bookmarks']['public-COW'].get('lsn') == end_lsn
assert CAUGHT_MESSAGES[9].value['bookmarks']['public-COW'].get('version') == new_version

Comment on lines +266 to +272
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use TestCase assert* methods instead?

Copy link

@josescuderoh josescuderoh Apr 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Samira-El I'm facing this issue and planning on contributing the fix. Is your feedback related to line 268 only? That is change it to assertIsInstance(CAUGHT_MESSAGES[9], singer.messages.StateMessage). Asking this since I see all other assert statements don't use TestCase either.

class FullTableInterruption(unittest.TestCase):
maxDiff = None

def setUp(self):
table_spec_1 = {"columns": [{"name": "id", "type" : "serial", "primary_key" : True},
{"name" : 'name', "type": "character varying"},
Expand Down Expand Up @@ -238,7 +275,7 @@ def test_catalog(self):

conn = get_test_connection()
conn.autocommit = True

with conn.cursor() as cur:
cow_rec = {'name': 'betty', 'colour': 'blue'}
insert_record(cur, 'COW', {'name': 'betty', 'colour': 'blue'})
Expand All @@ -256,7 +293,7 @@ def test_catalog(self):

state = {}
blew_up_on_cow = False

#this will sync the CHICKEN but then blow up on the COW
try:
tap_postgres.do_sync(get_test_connection_config(), {'streams' : streams}, None, state)
Expand Down
6 changes: 5 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ def get_test_connection_config(target_db='postgres'):
if len(missing_envs) != 0:
raise Exception("set TAP_POSTGRES_HOST, TAP_POSTGRES_USER, TAP_POSTGRES_PASSWORD, TAP_POSTGRES_PORT")

conn_config = {'host': os.environ.get('TAP_POSTGRES_HOST'),
conn_config = {'tap_id': 'test-postgres',
'max_run_seconds': 5,
'break_at_end_lsn': True,
'logical_poll_total_seconds': 2,
'host': os.environ.get('TAP_POSTGRES_HOST'),
'user': os.environ.get('TAP_POSTGRES_USER'),
'password': os.environ.get('TAP_POSTGRES_PASSWORD'),
'port': os.environ.get('TAP_POSTGRES_PORT'),
Expand Down