diff --git a/aiven_db_migrate/migrate/pgmigrate.py b/aiven_db_migrate/migrate/pgmigrate.py index 1997435..a842b2e 100644 --- a/aiven_db_migrate/migrate/pgmigrate.py +++ b/aiven_db_migrate/migrate/pgmigrate.py @@ -35,6 +35,15 @@ MAX_CLI_LEN = 2097152 # getconf ARG_MAX +class ReplicationObjectType(enum.Enum): + PUBLICATION = "pub" + SUBSCRIPTION = "sub" + REPLICATION_SLOT = "slot" + + def get_display_name(self) -> str: + return self.name.replace("_", " ").lower() + + @dataclass class PGExtension: name: str @@ -97,6 +106,7 @@ class PGRole: class PGCluster: """PGCluster is a collection of databases managed by a single PostgreSQL server instance""" + DB_OBJECT_PREFIX = "managed_db_migrate" conn_info: Dict[str, Any] _databases: Dict[str, PGDatabase] _params: Dict[str, str] @@ -438,6 +448,10 @@ def mangle_db_name(self, db_name: str) -> str: return db_name return hashlib.md5(db_name.encode()).hexdigest() + def get_replication_object_name(self, dbname: str, replication_obj_type: ReplicationObjectType) -> str: + mangled_name = self.mangle_db_name(dbname) + return f"{self.DB_OBJECT_PREFIX}_{mangled_name}_{replication_obj_type.value}" + class PGSource(PGCluster): """Source PostgreSQL cluster""" @@ -454,8 +468,10 @@ def get_size(self, *, dbname, only_tables: Optional[List[str]] = None) -> float: return result[0]["size"] or 0 def create_publication(self, *, dbname: str, only_tables: Optional[List[str]] = None) -> str: - mangled_name = self.mangle_db_name(dbname) - pubname = f"managed_db_migrate_{mangled_name}_pub" + pubname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.PUBLICATION, + ) validate_pg_identifier_length(pubname) pub_options: Union[List[str], str] @@ -498,8 +514,11 @@ def create_publication(self, *, dbname: str, only_tables: Optional[List[str]] = return pubname def create_replication_slot(self, *, dbname: str) -> str: - mangled_name = self.mangle_db_name(dbname) - slotname = f"managed_db_migrate_{mangled_name}_slot" + slotname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) + validate_pg_identifier_length(slotname) self.log.info("Creating replication slot %r in database %r", slotname, dbname) @@ -516,12 +535,17 @@ def create_replication_slot(self, *, dbname: str) -> str: return slotname - def get_publication(self, *, dbname: str, pubname: str) -> Dict[str, Any]: + def get_publication(self, *, dbname: str) -> Dict[str, Any]: + pubname = self.get_replication_object_name(dbname=dbname, replication_obj_type=ReplicationObjectType.PUBLICATION) # publications as per database so connect to given database result = self.c("SELECT * FROM pg_catalog.pg_publication WHERE pubname = %s", args=(pubname, ), dbname=dbname) return result[0] if result else {} - def get_replication_slot(self, *, dbname: str, slotname: str) -> Dict[str, Any]: + def get_replication_slot(self, *, dbname: str) -> Dict[str, Any]: + slotname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) result = self.c( "SELECT * from pg_catalog.pg_replication_slots WHERE database = %s AND slot_name = %s", args=( @@ -532,7 +556,11 @@ def get_replication_slot(self, *, dbname: str, slotname: str) -> Dict[str, Any]: ) return result[0] if result else {} - def replication_in_sync(self, *, dbname: str, slotname: str, max_replication_lag: int) -> Tuple[bool, str]: + def replication_in_sync(self, *, dbname: str, max_replication_lag: int) -> Tuple[bool, str]: + slotname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) exists = self.c( "SELECT 1 FROM pg_catalog.pg_replication_slots WHERE slot_name = %s", args=(slotname, ), dbname=dbname ) @@ -573,28 +601,63 @@ def large_objects_present(self, *, dbname: str) -> bool: self.log.warning("Unable to determine if large objects present in database %r", dbname) return False - def cleanup(self, *, dbname: str, pubname: str, slotname: str): - # publications as per database so connect to correct database - pub = self.get_publication(dbname=dbname, pubname=pubname) - if pub: - self.log.info("Dropping publication %r from database %r", pub, dbname) - self.c("DROP PUBLICATION {}".format(pub["pubname"]), dbname=dbname, return_rows=0) - slot = self.get_replication_slot(dbname=dbname, slotname=slotname) - if slot: - self.log.info("Dropping replication slot %r from database %r", slot, dbname) - self.c( - "SELECT 1 FROM pg_catalog.pg_drop_replication_slot(%s)", - args=(slot["slot_name"], ), - dbname=dbname, - return_rows=0 + def cleanup(self, *, dbname: str): + self._cleanup_replication_object(dbname=dbname, replication_object_type=ReplicationObjectType.PUBLICATION) + self._cleanup_replication_object(dbname=dbname, replication_object_type=ReplicationObjectType.REPLICATION_SLOT) + + def _cleanup_replication_object(self, dbname: str, replication_object_type: ReplicationObjectType): + rep_obj_type_display_name = replication_object_type.get_display_name() + + rep_obj_name = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=replication_object_type, + ) + try: + if ReplicationObjectType.PUBLICATION is replication_object_type: + rep_obj = self.get_publication(dbname=dbname) + delete_query = f"DROP PUBLICATION {rep_obj_name};" + args = () + else: + rep_obj = self.get_replication_slot(dbname=dbname) + delete_query = f"SELECT 1 FROM pg_catalog.pg_drop_replication_slot(%s)" + args = (rep_obj_name, ) + + if not rep_obj: + return + + self.log.info( + "Dropping %r %r from database %r", + rep_obj_type_display_name, + rep_obj_name, + dbname, + ) + self.c(delete_query, args=args, dbname=dbname, return_rows=0) + except Exception as exc: + self.log.error( + "Failed to drop %r %r for database %r: %s", + rep_obj_type_display_name, + rep_obj_name, + dbname, + exc, ) class PGTarget(PGCluster): """Target PostgreSQL cluster""" - def create_subscription(self, *, conn_str: str, pubname: str, slotname: str, dbname: str) -> str: - mangled_name = self.mangle_db_name(dbname) - subname = f"managed_db_migrate_{mangled_name}_sub" + def create_subscription(self, *, conn_str: str, dbname: str) -> str: + pubname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.PUBLICATION, + ) + slotname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) + + subname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.SUBSCRIPTION, + ) validate_pg_identifier_length(subname) has_aiven_extras = self.has_aiven_extras(dbname=dbname) @@ -630,7 +693,11 @@ def create_subscription(self, *, conn_str: str, pubname: str, slotname: str, dbn return subname - def get_subscription(self, *, dbname: str, subname: str) -> Dict[str, Any]: + def get_subscription(self, *, dbname: str) -> Dict[str, Any]: + subname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.SUBSCRIPTION, + ) if self.has_aiven_extras(dbname=dbname): result = self.c( "SELECT * FROM aiven_extras.pg_list_all_subscriptions() WHERE subname = %s", args=(subname, ), dbname=dbname @@ -640,7 +707,11 @@ def get_subscription(self, *, dbname: str, subname: str) -> Dict[str, Any]: result = self.c("SELECT * FROM pg_catalog.pg_subscription WHERE subname = %s", args=(subname, ), dbname=dbname) return result[0] if result else {} - def replication_in_sync(self, *, dbname: str, subname: str, write_lsn: str, max_replication_lag: int) -> bool: + def replication_in_sync(self, *, dbname: str, write_lsn: str, max_replication_lag: int) -> bool: + subname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.SUBSCRIPTION, + ) status = self.c( """ SELECT stat.*, @@ -664,23 +735,32 @@ def replication_in_sync(self, *, dbname: str, subname: str, write_lsn: str, max_ self.log.warning("Replication status not available for %r in database %r", subname, dbname) return False - def cleanup(self, *, dbname: str, subname: str): - sub = self.get_subscription(dbname=dbname, subname=subname) - if sub: - self.log.info("Dropping subscription %r from database %r", sub["subname"], dbname) + def cleanup(self, *, dbname: str): + subname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.SUBSCRIPTION, + ) + try: + if not self.get_subscription(dbname=dbname): + return + + self.log.info("Dropping subscription %r from database %r", subname, dbname) if self.has_aiven_extras(dbname=dbname): # NOTE: this drops also replication slot from source - self.c( - "SELECT * FROM aiven_extras.pg_drop_subscription(%s)", - args=(sub["subname"], ), - dbname=dbname, - return_rows=0 - ) + self.c("SELECT * FROM aiven_extras.pg_drop_subscription(%s)", args=(subname, ), dbname=dbname, return_rows=0) else: # requires superuser or superuser-like privileges, such as "rds_replication" role in AWS RDS - self.c("ALTER SUBSCRIPTION {} DISABLE".format(sub["subname"]), dbname=dbname, return_rows=0) - self.c("ALTER SUBSCRIPTION {} SET (slot_name = NONE)".format(sub["subname"]), dbname=dbname, return_rows=0) - self.c("DROP SUBSCRIPTION {}".format(sub["subname"]), dbname=dbname, return_rows=0) + self.c("ALTER SUBSCRIPTION {} DISABLE".format(subname), dbname=dbname, return_rows=0) + self.c("ALTER SUBSCRIPTION {} SET (slot_name = NONE)".format(subname), dbname=dbname, return_rows=0) + self.c("DROP SUBSCRIPTION {}".format(subname), dbname=dbname, return_rows=0) + + except Exception as exc: + self.log.error( + "Failed to drop subscription %r for database %r: %s", + subname, + dbname, + exc, + ) @enum.unique @@ -1256,42 +1336,39 @@ def _dump_data(self, *, db: PGDatabase) -> PGMigrateStatus: raise PGDataDumpFailedError(f"Failed to dump data: {subtask!r}") return PGMigrateStatus.done - def _wait_for_replication(self, *, dbname: str, slotname: str, subname: str, check_interval: float = 2.0): + def _wait_for_replication(self, *, dbname: str, check_interval: float = 2.0): while True: - in_sync, write_lsn = self.source.replication_in_sync( - dbname=dbname, slotname=slotname, max_replication_lag=self.max_replication_lag - ) + in_sync, write_lsn = self.source.replication_in_sync(dbname=dbname, max_replication_lag=self.max_replication_lag) if in_sync and self.target.replication_in_sync( - dbname=dbname, subname=subname, write_lsn=write_lsn, max_replication_lag=self.max_replication_lag + dbname=dbname, write_lsn=write_lsn, max_replication_lag=self.max_replication_lag ): break time.sleep(check_interval) def _db_replication(self, *, db: PGDatabase) -> PGMigrateStatus: dbname = db.dbname - pubname = slotname = subname = None try: tables = self.filter_tables(db) or [] - pubname = self.source.create_publication(dbname=dbname, only_tables=tables) - slotname = self.source.create_replication_slot(dbname=dbname) - subname = self.target.create_subscription( - conn_str=self.source.conn_str(dbname=dbname), pubname=pubname, slotname=slotname, dbname=dbname - ) + self.source.create_publication(dbname=dbname, only_tables=tables) + self.source.create_replication_slot(dbname=dbname) + self.target.create_subscription(conn_str=self.source.conn_str(dbname=dbname), dbname=dbname) + except psycopg2.ProgrammingError as e: self.log.error("Encountered error: %r, cleaning up", e) - if subname: - self.target.cleanup(dbname=dbname, subname=subname) - if pubname and slotname: - self.source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname) + + # clean-up replication objects, avoid leaving traces specially in source + self.target.cleanup(dbname=dbname) + self.source.cleanup(dbname=dbname) raise self.log.info("Logical replication setup successful for database %r", dbname) if self.max_replication_lag > -1: - self._wait_for_replication(dbname=dbname, slotname=slotname, subname=subname) + self._wait_for_replication(dbname=dbname) if self.stop_replication: - self.target.cleanup(dbname=dbname, subname=subname) - self.source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname) + self.target.cleanup(dbname=dbname) + self.source.cleanup(dbname=dbname) return PGMigrateStatus.done + # leaving replication running return PGMigrateStatus.running diff --git a/aiven_db_migrate/migrate/version.py b/aiven_db_migrate/migrate/version.py index c8811fb..147c1e4 100644 --- a/aiven_db_migrate/migrate/version.py +++ b/aiven_db_migrate/migrate/version.py @@ -1 +1 @@ -__version__ = "0.1.5-2-gfd83d7c" +__version__ = "0.1.5-2-ga13c553" diff --git a/test/conftest.py b/test/conftest.py index 6c3815a..cf7ef88 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,7 +4,7 @@ from _pytest.fixtures import FixtureRequest from _pytest.tmpdir import TempPathFactory -from aiven_db_migrate.migrate.pgmigrate import PGTarget +from aiven_db_migrate.migrate.pgmigrate import PGTarget, ReplicationObjectType from contextlib import contextmanager from copy import copy from functools import partial, wraps @@ -154,8 +154,12 @@ def _drop_replication_slot(pg_runner_: PGRunner, slot_name: str) -> None: break # Found it, no need to try other databases. @wraps(function) - def wrapper(self: PGTarget, *args, slotname: str, **kwargs) -> R: - subname = function(self, *args, slotname=slotname, **kwargs) + def wrapper(self: PGTarget, *args, dbname: str, **kwargs) -> R: + subname = function(self, *args, dbname=dbname, **kwargs) + slotname = self.get_replication_object_name( + dbname=dbname, + replication_obj_type=ReplicationObjectType.REPLICATION_SLOT, + ) pg_runner.cleanups.append(partial(_drop_replication_slot, pg_runner_=pg_runner, slot_name=slotname)) diff --git a/test/test_pg_migrate.py b/test/test_pg_migrate.py index e0bc865..18e84bd 100644 --- a/test/test_pg_migrate.py +++ b/test/test_pg_migrate.py @@ -338,7 +338,6 @@ def test_migrate_source_aiven_extras(self, createdb: bool): result: PGMigrateResult = pg_mig.migrate() - assert len(result.pg_databases) == 2 self.assert_result( result=result.pg_databases[dbname], dbname=dbname, diff --git a/test/test_pg_replication.py b/test/test_pg_replication.py index 5b9d09b..4a2aeb3 100644 --- a/test/test_pg_replication.py +++ b/test/test_pg_replication.py @@ -45,19 +45,18 @@ def test_replication(pg_source_and_target: Tuple[PGRunner, PGRunner], aiven_extr pubname = pg_source.create_publication(dbname=dbname) slotname = pg_source.create_replication_slot(dbname=dbname) # verify that pub and replication slot exixts - pub = pg_source.get_publication(dbname=dbname, pubname=pubname) + pub = pg_source.get_publication(dbname=dbname) assert pub assert pub["pubname"] == pubname - slot = pg_source.get_replication_slot(dbname=dbname, slotname=slotname) + slot = pg_source.get_replication_slot(dbname=dbname) assert slot assert slot["slot_name"] == slotname assert slot["slot_type"] == "logical" - subname = pg_target.create_subscription( - conn_str=pg_source.conn_str(dbname=dbname), pubname=pubname, slotname=slotname, dbname=dbname - ) + conn_str = pg_source.conn_str(dbname=dbname) + subname = pg_target.create_subscription(conn_str=conn_str, dbname=dbname) # verify that sub exists - sub = pg_target.get_subscription(dbname=dbname, subname=subname) + sub = pg_target.get_subscription(dbname=dbname) assert sub assert sub["subname"] == subname assert sub["subenabled"] @@ -69,10 +68,8 @@ def test_replication(pg_source_and_target: Tuple[PGRunner, PGRunner], aiven_extr # wait until replication is in sync timer = Timer(timeout=10, what="replication in sync") while timer.loop(): - in_sync, write_lsn = pg_source.replication_in_sync(dbname=dbname, slotname=slotname, max_replication_lag=0) - if in_sync and pg_target.replication_in_sync( - dbname=dbname, subname=subname, write_lsn=write_lsn, max_replication_lag=0 - ): + in_sync, write_lsn = pg_source.replication_in_sync(dbname=dbname, max_replication_lag=0) + if in_sync and pg_target.replication_in_sync(dbname=dbname, write_lsn=write_lsn, max_replication_lag=0): break # verify that all data has been replicated @@ -82,8 +79,8 @@ def test_replication(pg_source_and_target: Tuple[PGRunner, PGRunner], aiven_extr if int(count["count"]) == 5: break - pg_target.cleanup(dbname=dbname, subname=subname) - pg_source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname) + pg_target.cleanup(dbname=dbname) + pg_source.cleanup(dbname=dbname) # verify that pub, replication slot and sub are dropped assert not source.list_pubs(dbname=dbname) @@ -113,7 +110,7 @@ def test_replication_no_aiven_extras_no_superuser(pg_source_and_target: Tuple[PG # creating subscription should fail with insufficient privilege with pytest.raises(psycopg2.ProgrammingError) as err: - pg_target.create_subscription(conn_str=pg_source.conn_str(), pubname="dummy", slotname="dummy", dbname=dbname) + pg_target.create_subscription(conn_str=pg_source.conn_str(), dbname=dbname) assert err.value.pgcode == psycopg2.errorcodes.INSUFFICIENT_PRIVILEGE privilege_error_message = "must be superuser to create subscriptions" diff --git a/test/test_table_filtering.py b/test/test_table_filtering.py index 9ef61e8..72fec4a 100644 --- a/test/test_table_filtering.py +++ b/test/test_table_filtering.py @@ -213,8 +213,6 @@ def test_replicate_filter_with(pg_source_and_target: Tuple[PGRunner, PGRunner], except psycopg2.Error: pass try: - pg_mig.source.cleanup( - dbname=db, pubname=f"managed_db_migrate_{db}_pub", slotname=f"managed_db_migrate_{db}_slot" - ) + pg_mig.source.cleanup(dbname=db) except: # pylint: disable=bare-except pass