From d69d68653216692281ecf9ad4df762235a3ccec6 Mon Sep 17 00:00:00 2001 From: Judah Rand <17158624+judahrand@users.noreply.github.com> Date: Tue, 18 Jan 2022 12:29:03 +0000 Subject: [PATCH] Actually run the initial `LOG_BASED` sync in test --- tests/test_full_table_interruption.py | 35 +++++++++++++++++++++++++-- tests/utils.py | 6 ++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/tests/test_full_table_interruption.py b/tests/test_full_table_interruption.py index 1edabaed..567a2c63 100644 --- a/tests/test_full_table_interruption.py +++ b/tests/test_full_table_interruption.py @@ -1,6 +1,7 @@ import unittest import unittest.mock import tap_postgres +from tap_postgres.sync_strategies import logical_replication import tap_postgres.sync_strategies.full_table as full_table import tap_postgres.sync_strategies.common as pg_common import singer @@ -46,10 +47,31 @@ def do_not_dump_catalog(catalog): tap_postgres.dump_catalog = do_not_dump_catalog full_table.UPDATE_BOOKMARK_PERIOD = 1 -@unittest.mock.patch('tap_postgres.sync_logical_streams') + +@unittest.mock.patch('tap_postgres.sync_logical_streams', wraps=tap_postgres.sync_logical_streams) class LogicalInterruption(unittest.TestCase): maxDiff = None + @classmethod + def setUpClass(cls): + conn_config = get_test_connection_config() + slot_name = logical_replication.generate_replication_slot_name( + dbname=conn_config['dbname'], tap_id=conn_config['tap_id'] + ) + with get_test_connection() as conn: + with conn.cursor() as cur: + cur.execute(f"SELECT * FROM pg_create_logical_replication_slot('{slot_name}', 'wal2json')") + + @classmethod + def tearDownClass(cls): + conn_config = get_test_connection_config() + slot_name = logical_replication.generate_replication_slot_name( + dbname=conn_config['dbname'], tap_id=conn_config['tap_id'] + ) + with get_test_connection() as conn: + with conn.cursor() as cur: + cur.execute(f"SELECT * FROM pg_drop_replication_slot('{slot_name}')") + def setUp(self): table_spec_1 = {"columns": [{"name": "id", "type" : "serial", "primary_key" : True}, {"name" : 'name', "type": "character varying"}, @@ -70,6 +92,7 @@ def test_catalog(self, mock_sync_logical_streams): conn_config = get_test_connection_config() streams = tap_postgres.do_discovery(conn_config) + cow_stream = [s for s in streams if s['table_name'] == 'COW'][0] self.assertIsNotNone(cow_stream) cow_stream = select_all_of_stream(cow_stream) @@ -157,7 +180,7 @@ def test_catalog(self, mock_sync_logical_streams): mock_sync_logical_streams.assert_called_once() - self.assertEqual(8, len(CAUGHT_MESSAGES)) + self.assertEqual(10, len(CAUGHT_MESSAGES)) self.assertEqual(CAUGHT_MESSAGES[0]['type'], 'SCHEMA') @@ -206,8 +229,16 @@ def test_catalog(self, mock_sync_logical_streams): self.assertEqual(CAUGHT_MESSAGES[7].value['bookmarks']['public-COW'].get('lsn'), end_lsn) self.assertEqual(CAUGHT_MESSAGES[7].value['bookmarks']['public-COW'].get('version'), new_version) + assert CAUGHT_MESSAGES[8]['type'] == 'SCHEMA' + + assert isinstance(CAUGHT_MESSAGES[9], singer.messages.StateMessage) + assert CAUGHT_MESSAGES[9].value['bookmarks']['public-COW'].get('xmin') is None + assert CAUGHT_MESSAGES[9].value['bookmarks']['public-COW'].get('lsn') == end_lsn + assert CAUGHT_MESSAGES[9].value['bookmarks']['public-COW'].get('version') == new_version + class FullTableInterruption(unittest.TestCase): maxDiff = None + def setUp(self): table_spec_1 = {"columns": [{"name": "id", "type" : "serial", "primary_key" : True}, {"name" : 'name', "type": "character varying"}, diff --git a/tests/utils.py b/tests/utils.py index 881cfd3c..6253d0f6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -18,7 +18,11 @@ def get_test_connection_config(target_db='postgres'): if len(missing_envs) != 0: raise Exception("set TAP_POSTGRES_HOST, TAP_POSTGRES_USER, TAP_POSTGRES_PASSWORD, TAP_POSTGRES_PORT") - conn_config = {'host': os.environ.get('TAP_POSTGRES_HOST'), + conn_config = {'tap_id': 'test-postgres', + 'max_run_seconds': 5, + 'break_at_end_lsn': True, + 'logical_poll_total_seconds': 2, + 'host': os.environ.get('TAP_POSTGRES_HOST'), 'user': os.environ.get('TAP_POSTGRES_USER'), 'password': os.environ.get('TAP_POSTGRES_PASSWORD'), 'port': os.environ.get('TAP_POSTGRES_PORT'),