Skip to content

Commit

Permalink
Improve replication object clean-up on failure.
Browse files Browse the repository at this point in the history
Added some improvements on error handling whenever pg_migrate is cleaning any replication object (publication, subscription, replication slot). As our current implementation does not drop objects whenever create_<repl_object_type> fails (raises an exception). Current clean-up is dependent on the object name returned by those functions, but if nothing is returned then nothing will be cleaned. The improvement makes the cleanup independent from the return of such functions, as it nows fetches the replication object name and verifies it exists or not, instead of relying on return values.
  • Loading branch information
kathia-barahona committed Jan 23, 2025
1 parent 4758a17 commit 427112f
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 78 deletions.
200 changes: 143 additions & 57 deletions aiven_db_migrate/migrate/pgmigrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@
MAX_CLI_LEN = 2097152 # getconf ARG_MAX


class ReplicationObjectType(enum.Enum):
PUBLICATION = "pub"
SUBSCRIPTION = "sub"
REPLICATION_SLOT = "slot"

def get_display_name(self) -> str:
return self.name.replace("_", " ").lower()


@dataclass
class PGExtension:
name: str
Expand Down Expand Up @@ -97,6 +106,7 @@ class PGRole:

class PGCluster:
"""PGCluster is a collection of databases managed by a single PostgreSQL server instance"""
DB_OBJECT_PREFIX = "managed_db_migrate"
conn_info: Dict[str, Any]
_databases: Dict[str, PGDatabase]
_params: Dict[str, str]
Expand Down Expand Up @@ -438,6 +448,10 @@ def mangle_db_name(self, db_name: str) -> str:
return db_name
return hashlib.md5(db_name.encode()).hexdigest()

def get_replication_object_name(self, dbname: str, replication_obj_type: ReplicationObjectType) -> str:
mangled_name = self.mangle_db_name(dbname)
return f"{self.DB_OBJECT_PREFIX}_{mangled_name}_{replication_obj_type.value}"


class PGSource(PGCluster):
"""Source PostgreSQL cluster"""
Expand All @@ -454,8 +468,10 @@ def get_size(self, *, dbname, only_tables: Optional[List[str]] = None) -> float:
return result[0]["size"] or 0

def create_publication(self, *, dbname: str, only_tables: Optional[List[str]] = None) -> str:
mangled_name = self.mangle_db_name(dbname)
pubname = f"managed_db_migrate_{mangled_name}_pub"
pubname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.PUBLICATION,
)
validate_pg_identifier_length(pubname)

pub_options: Union[List[str], str]
Expand Down Expand Up @@ -498,8 +514,11 @@ def create_publication(self, *, dbname: str, only_tables: Optional[List[str]] =
return pubname

def create_replication_slot(self, *, dbname: str) -> str:
mangled_name = self.mangle_db_name(dbname)
slotname = f"managed_db_migrate_{mangled_name}_slot"
slotname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.REPLICATION_SLOT,
)

validate_pg_identifier_length(slotname)

self.log.info("Creating replication slot %r in database %r", slotname, dbname)
Expand All @@ -516,12 +535,17 @@ def create_replication_slot(self, *, dbname: str) -> str:

return slotname

def get_publication(self, *, dbname: str, pubname: str) -> Dict[str, Any]:
def get_publication(self, *, dbname: str) -> Dict[str, Any]:
pubname = self.get_replication_object_name(dbname=dbname, replication_obj_type=ReplicationObjectType.PUBLICATION)
# publications as per database so connect to given database
result = self.c("SELECT * FROM pg_catalog.pg_publication WHERE pubname = %s", args=(pubname, ), dbname=dbname)
return result[0] if result else {}

def get_replication_slot(self, *, dbname: str, slotname: str) -> Dict[str, Any]:
def get_replication_slot(self, *, dbname: str) -> Dict[str, Any]:
slotname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.REPLICATION_SLOT,
)
result = self.c(
"SELECT * from pg_catalog.pg_replication_slots WHERE database = %s AND slot_name = %s",
args=(
Expand All @@ -532,7 +556,11 @@ def get_replication_slot(self, *, dbname: str, slotname: str) -> Dict[str, Any]:
)
return result[0] if result else {}

def replication_in_sync(self, *, dbname: str, slotname: str, max_replication_lag: int) -> Tuple[bool, str]:
def replication_in_sync(self, *, dbname: str, max_replication_lag: int) -> Tuple[bool, str]:
slotname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.REPLICATION_SLOT,
)
exists = self.c(
"SELECT 1 FROM pg_catalog.pg_replication_slots WHERE slot_name = %s", args=(slotname, ), dbname=dbname
)
Expand Down Expand Up @@ -573,28 +601,63 @@ def large_objects_present(self, *, dbname: str) -> bool:
self.log.warning("Unable to determine if large objects present in database %r", dbname)
return False

def cleanup(self, *, dbname: str, pubname: str, slotname: str):
# publications as per database so connect to correct database
pub = self.get_publication(dbname=dbname, pubname=pubname)
if pub:
self.log.info("Dropping publication %r from database %r", pub, dbname)
self.c("DROP PUBLICATION {}".format(pub["pubname"]), dbname=dbname, return_rows=0)
slot = self.get_replication_slot(dbname=dbname, slotname=slotname)
if slot:
self.log.info("Dropping replication slot %r from database %r", slot, dbname)
self.c(
"SELECT 1 FROM pg_catalog.pg_drop_replication_slot(%s)",
args=(slot["slot_name"], ),
dbname=dbname,
return_rows=0
def cleanup(self, *, dbname: str):
self._cleanup_replication_object(dbname=dbname, replication_object_type=ReplicationObjectType.PUBLICATION)
self._cleanup_replication_object(dbname=dbname, replication_object_type=ReplicationObjectType.REPLICATION_SLOT)

def _cleanup_replication_object(self, dbname: str, replication_object_type: ReplicationObjectType):
rep_obj_type_display_name = replication_object_type.get_display_name()

rep_obj_name = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=replication_object_type,
)
try:
if ReplicationObjectType.PUBLICATION is replication_object_type:
rep_obj = self.get_publication(dbname=dbname)
delete_query = f"DROP PUBLICATION {rep_obj_name};"
args = ()
else:
rep_obj = self.get_replication_slot(dbname=dbname)
delete_query = f"SELECT 1 FROM pg_catalog.pg_drop_replication_slot(%s)"
args = (rep_obj_name, )

if not rep_obj:
return

self.log.info(
"Dropping %r %r from database %r",
rep_obj_type_display_name,
rep_obj_name,
dbname,
)
self.c(delete_query, args=args, dbname=dbname, return_rows=0)
except Exception as exc:
self.log.error(
"Failed to drop %r %r for database %r: %s",
rep_obj_type_display_name,
rep_obj_name,
dbname,
exc,
)


class PGTarget(PGCluster):
"""Target PostgreSQL cluster"""
def create_subscription(self, *, conn_str: str, pubname: str, slotname: str, dbname: str) -> str:
mangled_name = self.mangle_db_name(dbname)
subname = f"managed_db_migrate_{mangled_name}_sub"
def create_subscription(self, *, conn_str: str, dbname: str) -> str:
pubname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.PUBLICATION,
)
slotname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.REPLICATION_SLOT,
)

subname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.SUBSCRIPTION,
)
validate_pg_identifier_length(subname)

has_aiven_extras = self.has_aiven_extras(dbname=dbname)
Expand Down Expand Up @@ -630,7 +693,11 @@ def create_subscription(self, *, conn_str: str, pubname: str, slotname: str, dbn

return subname

def get_subscription(self, *, dbname: str, subname: str) -> Dict[str, Any]:
def get_subscription(self, *, dbname: str) -> Dict[str, Any]:
subname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.SUBSCRIPTION,
)
if self.has_aiven_extras(dbname=dbname):
result = self.c(
"SELECT * FROM aiven_extras.pg_list_all_subscriptions() WHERE subname = %s", args=(subname, ), dbname=dbname
Expand All @@ -640,7 +707,11 @@ def get_subscription(self, *, dbname: str, subname: str) -> Dict[str, Any]:
result = self.c("SELECT * FROM pg_catalog.pg_subscription WHERE subname = %s", args=(subname, ), dbname=dbname)
return result[0] if result else {}

def replication_in_sync(self, *, dbname: str, subname: str, write_lsn: str, max_replication_lag: int) -> bool:
def replication_in_sync(self, *, dbname: str, write_lsn: str, max_replication_lag: int) -> bool:
subname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.SUBSCRIPTION,
)
status = self.c(
"""
SELECT stat.*,
Expand All @@ -664,23 +735,32 @@ def replication_in_sync(self, *, dbname: str, subname: str, write_lsn: str, max_
self.log.warning("Replication status not available for %r in database %r", subname, dbname)
return False

def cleanup(self, *, dbname: str, subname: str):
sub = self.get_subscription(dbname=dbname, subname=subname)
if sub:
self.log.info("Dropping subscription %r from database %r", sub["subname"], dbname)
def cleanup(self, *, dbname: str):
subname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.SUBSCRIPTION,
)
try:
if not self.get_subscription(dbname=dbname):
return

self.log.info("Dropping subscription %r from database %r", subname, dbname)
if self.has_aiven_extras(dbname=dbname):
# NOTE: this drops also replication slot from source
self.c(
"SELECT * FROM aiven_extras.pg_drop_subscription(%s)",
args=(sub["subname"], ),
dbname=dbname,
return_rows=0
)
self.c("SELECT * FROM aiven_extras.pg_drop_subscription(%s)", args=(subname, ), dbname=dbname, return_rows=0)
else:
# requires superuser or superuser-like privileges, such as "rds_replication" role in AWS RDS
self.c("ALTER SUBSCRIPTION {} DISABLE".format(sub["subname"]), dbname=dbname, return_rows=0)
self.c("ALTER SUBSCRIPTION {} SET (slot_name = NONE)".format(sub["subname"]), dbname=dbname, return_rows=0)
self.c("DROP SUBSCRIPTION {}".format(sub["subname"]), dbname=dbname, return_rows=0)
self.c("ALTER SUBSCRIPTION {} DISABLE".format(subname), dbname=dbname, return_rows=0)
self.c("ALTER SUBSCRIPTION {} SET (slot_name = NONE)".format(subname), dbname=dbname, return_rows=0)
self.c("DROP SUBSCRIPTION {}".format(subname), dbname=dbname, return_rows=0)

except Exception as exc:
self.log.error(
"Failed to drop subscription %r for database %r: %s",
subname,
dbname,
exc,
)


@enum.unique
Expand Down Expand Up @@ -1256,42 +1336,48 @@ def _dump_data(self, *, db: PGDatabase) -> PGMigrateStatus:
raise PGDataDumpFailedError(f"Failed to dump data: {subtask!r}")
return PGMigrateStatus.done

def _wait_for_replication(self, *, dbname: str, slotname: str, subname: str, check_interval: float = 2.0):
def _wait_for_replication(self, *, dbname: str, check_interval: float = 2.0):
slotname = self.source.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.REPLICATION_SLOT,
)
subname = self.target.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.SUBSCRIPTION,
)

while True:
in_sync, write_lsn = self.source.replication_in_sync(
dbname=dbname, slotname=slotname, max_replication_lag=self.max_replication_lag
)
in_sync, write_lsn = self.source.replication_in_sync(dbname=dbname, max_replication_lag=self.max_replication_lag)
if in_sync and self.target.replication_in_sync(
dbname=dbname, subname=subname, write_lsn=write_lsn, max_replication_lag=self.max_replication_lag
dbname=dbname, write_lsn=write_lsn, max_replication_lag=self.max_replication_lag
):
break
time.sleep(check_interval)

def _db_replication(self, *, db: PGDatabase) -> PGMigrateStatus:
dbname = db.dbname
pubname = slotname = subname = None
try:
tables = self.filter_tables(db) or []
pubname = self.source.create_publication(dbname=dbname, only_tables=tables)
slotname = self.source.create_replication_slot(dbname=dbname)
subname = self.target.create_subscription(
conn_str=self.source.conn_str(dbname=dbname), pubname=pubname, slotname=slotname, dbname=dbname
)
self.source.create_publication(dbname=dbname, only_tables=tables)
self.source.create_replication_slot(dbname=dbname)
self.target.create_subscription(conn_str=self.source.conn_str(dbname=dbname), dbname=dbname)

except psycopg2.ProgrammingError as e:
self.log.error("Encountered error: %r, cleaning up", e)
if subname:
self.target.cleanup(dbname=dbname, subname=subname)
if pubname and slotname:
self.source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname)

# clean-up replication objects, avoid leaving traces specially in source
self.target.cleanup(dbname=dbname)
self.source.cleanup(dbname=dbname)
raise

self.log.info("Logical replication setup successful for database %r", dbname)
if self.max_replication_lag > -1:
self._wait_for_replication(dbname=dbname, slotname=slotname, subname=subname)
self._wait_for_replication(dbname=dbname)
if self.stop_replication:
self.target.cleanup(dbname=dbname, subname=subname)
self.source.cleanup(dbname=dbname, pubname=pubname, slotname=slotname)
self.target.cleanup(dbname=dbname)
self.source.cleanup(dbname=dbname)
return PGMigrateStatus.done

# leaving replication running
return PGMigrateStatus.running

Expand Down
2 changes: 1 addition & 1 deletion aiven_db_migrate/migrate/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.5-2-gfd83d7c"
__version__ = "0.1.5-2-ga13c553"
10 changes: 7 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from _pytest.fixtures import FixtureRequest
from _pytest.tmpdir import TempPathFactory
from aiven_db_migrate.migrate.pgmigrate import PGTarget
from aiven_db_migrate.migrate.pgmigrate import PGTarget, ReplicationObjectType
from contextlib import contextmanager
from copy import copy
from functools import partial, wraps
Expand Down Expand Up @@ -154,8 +154,12 @@ def _drop_replication_slot(pg_runner_: PGRunner, slot_name: str) -> None:
break # Found it, no need to try other databases.

@wraps(function)
def wrapper(self: PGTarget, *args, slotname: str, **kwargs) -> R:
subname = function(self, *args, slotname=slotname, **kwargs)
def wrapper(self: PGTarget, *args, dbname: str, **kwargs) -> R:
subname = function(self, *args, dbname=dbname, **kwargs)
slotname = self.get_replication_object_name(
dbname=dbname,
replication_obj_type=ReplicationObjectType.REPLICATION_SLOT,
)

pg_runner.cleanups.append(partial(_drop_replication_slot, pg_runner_=pg_runner, slot_name=slotname))

Expand Down
1 change: 0 additions & 1 deletion test/test_pg_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,6 @@ def test_migrate_source_aiven_extras(self, createdb: bool):

result: PGMigrateResult = pg_mig.migrate()

assert len(result.pg_databases) == 2
self.assert_result(
result=result.pg_databases[dbname],
dbname=dbname,
Expand Down
Loading

0 comments on commit 427112f

Please sign in to comment.