Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Commit

Permalink
Actually run the initial LOG_BASED sync in test
Browse files Browse the repository at this point in the history
  • Loading branch information
judahrand committed Jan 18, 2022
1 parent b2857b5 commit d69d686
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
35 changes: 33 additions & 2 deletions tests/test_full_table_interruption.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
import unittest.mock
import tap_postgres
from tap_postgres.sync_strategies import logical_replication
import tap_postgres.sync_strategies.full_table as full_table
import tap_postgres.sync_strategies.common as pg_common
import singer
Expand Down Expand Up @@ -46,10 +47,31 @@ def do_not_dump_catalog(catalog):
tap_postgres.dump_catalog = do_not_dump_catalog
full_table.UPDATE_BOOKMARK_PERIOD = 1

@unittest.mock.patch('tap_postgres.sync_logical_streams')

@unittest.mock.patch('tap_postgres.sync_logical_streams', wraps=tap_postgres.sync_logical_streams)
class LogicalInterruption(unittest.TestCase):
maxDiff = None

@classmethod
def setUpClass(cls):
conn_config = get_test_connection_config()
slot_name = logical_replication.generate_replication_slot_name(
dbname=conn_config['dbname'], tap_id=conn_config['tap_id']
)
with get_test_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT * FROM pg_create_logical_replication_slot('{slot_name}', 'wal2json')")

@classmethod
def tearDownClass(cls):
conn_config = get_test_connection_config()
slot_name = logical_replication.generate_replication_slot_name(
dbname=conn_config['dbname'], tap_id=conn_config['tap_id']
)
with get_test_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT * FROM pg_drop_replication_slot('{slot_name}')")

def setUp(self):
table_spec_1 = {"columns": [{"name": "id", "type" : "serial", "primary_key" : True},
{"name" : 'name', "type": "character varying"},
Expand All @@ -70,6 +92,7 @@ def test_catalog(self, mock_sync_logical_streams):

conn_config = get_test_connection_config()
streams = tap_postgres.do_discovery(conn_config)

cow_stream = [s for s in streams if s['table_name'] == 'COW'][0]
self.assertIsNotNone(cow_stream)
cow_stream = select_all_of_stream(cow_stream)
Expand Down Expand Up @@ -157,7 +180,7 @@ def test_catalog(self, mock_sync_logical_streams):

mock_sync_logical_streams.assert_called_once()

self.assertEqual(8, len(CAUGHT_MESSAGES))
self.assertEqual(10, len(CAUGHT_MESSAGES))

self.assertEqual(CAUGHT_MESSAGES[0]['type'], 'SCHEMA')

Expand Down Expand Up @@ -206,8 +229,16 @@ def test_catalog(self, mock_sync_logical_streams):
self.assertEqual(CAUGHT_MESSAGES[7].value['bookmarks']['public-COW'].get('lsn'), end_lsn)
self.assertEqual(CAUGHT_MESSAGES[7].value['bookmarks']['public-COW'].get('version'), new_version)

assert CAUGHT_MESSAGES[8]['type'] == 'SCHEMA'

assert isinstance(CAUGHT_MESSAGES[9], singer.messages.StateMessage)
assert CAUGHT_MESSAGES[9].value['bookmarks']['public-COW'].get('xmin') is None
assert CAUGHT_MESSAGES[9].value['bookmarks']['public-COW'].get('lsn') == end_lsn
assert CAUGHT_MESSAGES[9].value['bookmarks']['public-COW'].get('version') == new_version

class FullTableInterruption(unittest.TestCase):
maxDiff = None

def setUp(self):
table_spec_1 = {"columns": [{"name": "id", "type" : "serial", "primary_key" : True},
{"name" : 'name', "type": "character varying"},
Expand Down
6 changes: 5 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ def get_test_connection_config(target_db='postgres'):
if len(missing_envs) != 0:
raise Exception("set TAP_POSTGRES_HOST, TAP_POSTGRES_USER, TAP_POSTGRES_PASSWORD, TAP_POSTGRES_PORT")

conn_config = {'host': os.environ.get('TAP_POSTGRES_HOST'),
conn_config = {'tap_id': 'test-postgres',
'max_run_seconds': 5,
'break_at_end_lsn': True,
'logical_poll_total_seconds': 2,
'host': os.environ.get('TAP_POSTGRES_HOST'),
'user': os.environ.get('TAP_POSTGRES_USER'),
'password': os.environ.get('TAP_POSTGRES_PASSWORD'),
'port': os.environ.get('TAP_POSTGRES_PORT'),
Expand Down

0 comments on commit d69d686

Please sign in to comment.