Skip to content

Commit

Permalink
Improve replication object clean-up on failure.
Browse files Browse the repository at this point in the history
Improve error handling when pg_migrate cleans up replication objects (publication, subscription, replication slot). The current implementation does not drop objects when create_<repl_object_type> fails (raises an exception): clean-up depends on the object name returned by those functions, so if nothing is returned, nothing is cleaned. This change makes the clean-up independent of those return values: it now derives the replication object name itself and checks whether the object exists, instead of relying on return values.
  • Loading branch information
kathia-barahona committed Jan 20, 2025
1 parent 136dcf8 commit 5dc933d
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 59 deletions.
180 changes: 129 additions & 51 deletions aiven_db_migrate/migrate/pgmigrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@
MAX_CLI_LEN = 2097152 # getconf ARG_MAX


class ReplicationObjectType(enum.StrEnum):
PUBLICATION = "pub"
SUBSCRIPTION = "sub"
REPLICATION_SLOT = "slot"

def get_display_name(self) -> str:
return self.name.replace("_", " ").lower()


@dataclass
class PGExtension:
name: str
Expand Down Expand Up @@ -97,6 +106,7 @@ class PGRole:

class PGCluster:
"""PGCluster is a collection of databases managed by a single PostgreSQL server instance"""
DB_OBJECT_PREFIX = "managed_db_migrate"
conn_info: Dict[str, Any]
_databases: Dict[str, PGDatabase]
_params: Dict[str, str]
Expand Down Expand Up @@ -438,6 +448,10 @@ def mangle_db_name(self, db_name: str) -> str:
return db_name
return hashlib.md5(db_name.encode()).hexdigest()

def get_replication_object_name(self, dbname: str, replication_obj_type: ReplicationObjectType) -> str:
    """Build the name of the replication object of the given type for ``dbname``.

    The result has the form ``<DB_OBJECT_PREFIX>_<mangled dbname>_<type value>``, where the
    middle part comes from ``mangle_db_name`` and the last part is the enum member's value.
    """
    name_parts = (
        self.DB_OBJECT_PREFIX,
        self.mangle_db_name(dbname),
        replication_obj_type.value,
    )
    return "_".join(name_parts)


class PGSource(PGCluster):
"""Source PostgreSQL cluster"""
Expand All @@ -454,8 +468,10 @@ def get_size(self, *, dbname, only_tables: Optional[List[str]] = None) -> float:
return result[0]["size"] or 0

def create_publication(self, *, dbname: str, only_tables: Optional[List[str]] = None) -> str:
mangled_name = self.mangle_db_name(dbname)
pubname = f"managed_db_migrate_{mangled_name}_pub"
pubname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.PUBLICATION,
)
validate_pg_identifier_length(pubname)

pub_options: Union[List[str], str]
Expand Down Expand Up @@ -498,8 +514,11 @@ def create_publication(self, *, dbname: str, only_tables: Optional[List[str]] =
return pubname

def create_replication_slot(self, *, dbname: str) -> str:
mangled_name = self.mangle_db_name(dbname)
slotname = f"managed_db_migrate_{mangled_name}_slot"
slotname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.REPLICATION_SLOT,
)

validate_pg_identifier_length(slotname)

self.log.info("Creating replication slot %r in database %r", slotname, dbname)
Expand All @@ -516,12 +535,17 @@ def create_replication_slot(self, *, dbname: str) -> str:

return slotname

def get_publication(self, *, dbname: str, pubname: str) -> Dict[str, Any]:
def get_publication(self, *, dbname: str) -> Dict[str, Any]:
    """Return the ``pg_publication`` row for this database's migration publication.

    The publication name is derived from ``dbname``; an empty dict is returned when the
    query yields no rows.
    """
    # Publications are per-database, so the query is run against the given database.
    rows = self.c(
        "SELECT * FROM pg_catalog.pg_publication WHERE pubname = %s",
        args=(
            self.get_replication_object_name(
                dbname=dbname,
                replication_obj_type=ReplicationObjectType.PUBLICATION,
            ),
        ),
        dbname=dbname,
    )
    if not rows:
        return {}
    return rows[0]

def get_replication_slot(self, *, dbname: str, slotname: str) -> Dict[str, Any]:
def get_replication_slot(self, *, dbname: str) -> Dict[str, Any]:
slotname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.REPLICATION_SLOT,
)
result = self.c(
"SELECT * from pg_catalog.pg_replication_slots WHERE database = %s AND slot_name = %s",
args=(
Expand Down Expand Up @@ -573,28 +597,61 @@ def large_objects_present(self, *, dbname: str) -> bool:
self.log.warning("Unable to determine if large objects present in database %r", dbname)
return False

def cleanup(self, *, dbname: str, pubname: str, slotname: str):
# publications as per database so connect to correct database
pub = self.get_publication(dbname=dbname, pubname=pubname)
if pub:
self.log.info("Dropping publication %r from database %r", pub, dbname)
self.c("DROP PUBLICATION {}".format(pub["pubname"]), dbname=dbname, return_rows=0)
slot = self.get_replication_slot(dbname=dbname, slotname=slotname)
if slot:
self.log.info("Dropping replication slot %r from database %r", slot, dbname)
self.c(
"SELECT 1 FROM pg_catalog.pg_drop_replication_slot(%s)",
args=(slot["slot_name"], ),
dbname=dbname,
return_rows=0
def cleanup(self, *, dbname: str):
    """Clean up the publication first and then the replication slot belonging to ``dbname``."""
    ordered_types = (
        ReplicationObjectType.PUBLICATION,
        ReplicationObjectType.REPLICATION_SLOT,
    )
    for object_type in ordered_types:
        self._cleanup_replication_object(dbname=dbname, replication_object_type=object_type)

def _cleanup_replication_object(self, dbname: str, replication_object_type: ReplicationObjectType):
    """Drop the replication object of the given type for ``dbname`` if it exists.

    The object name is derived from ``dbname`` via ``get_replication_object_name``, so the
    clean-up does not depend on anything returned by the ``create_*`` methods.

    ``ReplicationObjectType.PUBLICATION`` is dropped with ``DROP PUBLICATION``; any other
    type is handled as a replication slot.

    This is best-effort: any exception raised while looking up or dropping the object is
    logged and swallowed, so one failed drop does not prevent the caller from attempting
    the remaining clean-ups.
    """
    # The enum exposes get_display_name(); there is no display_name() method.
    rep_obj_type_display_name = replication_object_type.get_display_name()

    rep_obj_name = self.get_replication_object_name(
        dbname=dbname,
        replication_obj_type=replication_object_type,
    )
    try:
        if replication_object_type is ReplicationObjectType.PUBLICATION:
            rep_obj = self.get_publication(dbname=dbname)
            # An identifier cannot be passed as a bound parameter: assuming self.c binds
            # args psycopg2-style, the name would be rendered as a quoted string literal,
            # which DROP PUBLICATION rejects. The generated name is therefore interpolated
            # directly, as the other DROP/ALTER statements in this file do.
            delete_query = f"DROP PUBLICATION {rep_obj_name}"
            drop_kwargs: Dict[str, Any] = {}
        else:
            rep_obj = self.get_replication_slot(dbname=dbname)
            # Here the slot name is a function argument (a value), so binding is correct.
            delete_query = "SELECT 1 FROM pg_catalog.pg_drop_replication_slot(%s)"
            drop_kwargs = {"args": (rep_obj_name, )}

        if not rep_obj:
            # Nothing to drop.
            return

        self.log.info(
            "Dropping %r %r from database %r",
            rep_obj_type_display_name,
            rep_obj_name,
            dbname,
        )
        self.c(delete_query, dbname=dbname, return_rows=0, **drop_kwargs)
    except Exception as exc:
        self.log.error(
            "Failed to drop %r %r for database %r: %s",
            rep_obj_type_display_name,
            rep_obj_name,
            dbname,
            exc,
        )


class PGTarget(PGCluster):
"""Target PostgreSQL cluster"""
def create_subscription(self, *, conn_str: str, pubname: str, slotname: str, dbname: str) -> str:
mangled_name = self.mangle_db_name(dbname)
subname = f"managed_db_migrate_{mangled_name}_sub"
def create_subscription(self, *, conn_str: str, dbname: str) -> str:
pubname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.PUBLICATION,
)
slotname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.REPLICATION_SLOT,
)

subname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.SUBSCRIPTION,
)
validate_pg_identifier_length(subname)

has_aiven_extras = self.has_aiven_extras(dbname=dbname)
Expand Down Expand Up @@ -630,7 +687,11 @@ def create_subscription(self, *, conn_str: str, pubname: str, slotname: str, dbn

return subname

def get_subscription(self, *, dbname: str, subname: str) -> Dict[str, Any]:
def get_subscription(self, *, dbname: str) -> Dict[str, Any]:
subname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.SUBSCRIPTION,
)
if self.has_aiven_extras(dbname=dbname):
result = self.c(
"SELECT * FROM aiven_extras.pg_list_all_subscriptions() WHERE subname = %s", args=(subname, ), dbname=dbname
Expand Down Expand Up @@ -664,23 +725,32 @@ def replication_in_sync(self, *, dbname: str, subname: str, write_lsn: str, max_
self.log.warning("Replication status not available for %r in database %r", subname, dbname)
return False

def cleanup(self, *, dbname: str, subname: str):
sub = self.get_subscription(dbname=dbname, subname=subname)
if sub:
self.log.info("Dropping subscription %r from database %r", sub["subname"], dbname)
def cleanup(self, *, dbname: str):
subname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.SUBSCRIPTION,
)
try:
if not self.get_subscription(dbname=dbname):
return

self.log.info("Dropping subscription %r from database %r", subname, dbname)
if self.has_aiven_extras(dbname=dbname):
# NOTE: this drops also replication slot from source
self.c(
"SELECT * FROM aiven_extras.pg_drop_subscription(%s)",
args=(sub["subname"], ),
dbname=dbname,
return_rows=0
)
self.c("SELECT * FROM aiven_extras.pg_drop_subscription(%s)", args=(subname, ), dbname=dbname, return_rows=0)
else:
# requires superuser or superuser-like privileges, such as "rds_replication" role in AWS RDS
self.c("ALTER SUBSCRIPTION {} DISABLE".format(sub["subname"]), dbname=dbname, return_rows=0)
self.c("ALTER SUBSCRIPTION {} SET (slot_name = NONE)".format(sub["subname"]), dbname=dbname, return_rows=0)
self.c("DROP SUBSCRIPTION {}".format(sub["subname"]), dbname=dbname, return_rows=0)
self.c("ALTER SUBSCRIPTION {} DISABLE".format(subname), dbname=dbname, return_rows=0)
self.c("ALTER SUBSCRIPTION {} SET (slot_name = NONE)".format(subname), dbname=dbname, return_rows=0)
self.c("DROP SUBSCRIPTION {}".format(subname), dbname=dbname, return_rows=0)

except Exception as exc:
self.log.error(
"Failed to drop subscription %r for database %r: %s",
subname,
dbname,
exc,
)


@enum.unique
Expand Down Expand Up @@ -1256,7 +1326,16 @@ def _dump_data(self, *, db: PGDatabase) -> PGMigrateStatus:
raise PGDataDumpFailedError(f"Failed to dump data: {subtask!r}")
return PGMigrateStatus.done

def _wait_for_replication(self, *, dbname: str, slotname: str, subname: str, check_interval: float = 2.0):
def _wait_for_replication(self, *, dbname: str, check_interval: float = 2.0):
slotname = self.source.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.REPLICATION_SLOT,
)
subname = self.target.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.SUBSCRIPTION,
)

while True:
in_sync, write_lsn = self.source.replication_in_sync(
dbname=dbname, slotname=slotname, max_replication_lag=self.max_replication_lag
Expand All @@ -1269,29 +1348,28 @@ def _wait_for_replication(self, *, dbname: str, slotname: str, subname: str, che

def _db_replication(self, *, db: PGDatabase) -> PGMigrateStatus:
dbname = db.dbname
pubname = slotname = subname = None
try:
tables = self.filter_tables(db) or []
pubname = self.source.create_publication(dbname=dbname, only_tables=tables)
slotname = self.source.create_replication_slot(dbname=dbname)
subname = self.target.create_subscription(
conn_str=self.source.conn_str(dbname=dbname), pubname=pubname, slotname=slotname, dbname=dbname
)
self.source.create_publication(dbname=dbname, only_tables=tables)
self.source.create_replication_slot(dbname=dbname)
self.target.create_subscription(conn_str=self.source.conn_str(dbname=dbname), dbname=dbname)

except psycopg2.ProgrammingError as e:
self.log.error("Encountered error: %r, cleaning up", e)
if subname:
self.target.cleanup(dbname=dbname, subname=subname)
if pubname and slotname:
self.source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname)

# clean up replication objects; avoid leaving traces, especially in source
self.target.cleanup(dbname=dbname)
self.source.cleanup(dbname=dbname)
raise

self.log.info("Logical replication setup successful for database %r", dbname)
if self.max_replication_lag > -1:
self._wait_for_replication(dbname=dbname, slotname=slotname, subname=subname)
self._wait_for_replication(dbname=dbname)
if self.stop_replication:
self.target.cleanup(dbname=dbname, subname=subname)
self.source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname)
self.target.cleanup(dbname=dbname)
self.source.cleanup(dbname=dbname)
return PGMigrateStatus.done

# leaving replication running
return PGMigrateStatus.running

Expand Down
14 changes: 6 additions & 8 deletions test/test_pg_replication.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,17 @@ def test_replication(pg_source_and_target: Tuple[PGRunner, PGRunner], aiven_extr
pubname = pg_source.create_publication(dbname=dbname)
slotname = pg_source.create_replication_slot(dbname=dbname)
# verify that pub and replication slot exist
pub = pg_source.get_publication(dbname=dbname, pubname=pubname)
pub = pg_source.get_publication(dbname=dbname)
assert pub
assert pub["pubname"] == pubname
slot = pg_source.get_replication_slot(dbname=dbname, slotname=slotname)
slot = pg_source.get_replication_slot(dbname=dbname)
assert slot
assert slot["slot_name"] == slotname
assert slot["slot_type"] == "logical"

subname = pg_target.create_subscription(
conn_str=pg_source.conn_str(dbname=dbname), pubname=pubname, slotname=slotname, dbname=dbname
)
subname = pg_target.create_subscription(conn_str=pg_source.conn_str(dbname=dbname), dbname=dbname)
# verify that sub exists
sub = pg_target.get_subscription(dbname=dbname, subname=subname)
sub = pg_target.get_subscription(dbname=dbname)
assert sub
assert sub["subname"] == subname
assert sub["subenabled"]
Expand All @@ -81,8 +79,8 @@ def test_replication(pg_source_and_target: Tuple[PGRunner, PGRunner], aiven_extr
if int(count["count"]) == 5:
break

pg_target.cleanup(dbname=dbname, subname=subname)
pg_source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname)
pg_target.cleanup(dbname=dbname)
pg_source.cleanup(dbname=dbname)

# verify that pub, replication slot and sub are dropped
assert not source.list_pubs(dbname=dbname)
Expand Down

0 comments on commit 5dc933d

Please sign in to comment.