Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-47770: Fix threadsafety of sqlalchemy MetaData access #1132

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ..._named import NamedValueAbstractSet
from ..._timespan import Timespan
from ...timespan_database_representation import TimespanDatabaseRepresentation
from ..interfaces import Database
from ..interfaces import Database, DatabaseMetadata


class PostgresqlDatabase(Database):
Expand Down Expand Up @@ -124,7 +124,7 @@ def _init(
namespace: str | None = None,
writeable: bool = True,
dbname: str,
metadata: sqlalchemy.schema.MetaData | None,
metadata: DatabaseMetadata | None,
pg_version: tuple[int, int],
) -> None:
# Initialization logic shared between ``__init__`` and ``clone``.
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/registry/databases/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

from ... import ddl
from ..._named import NamedValueAbstractSet
from ..interfaces import Database, StaticTablesContext
from ..interfaces import Database, DatabaseMetadata, StaticTablesContext


def _onSqlite3Connect(
Expand Down Expand Up @@ -109,7 +109,7 @@ def _init(
namespace: str | None = None,
writeable: bool = True,
filename: str | None,
metadata: sqlalchemy.schema.MetaData | None,
metadata: DatabaseMetadata | None,
) -> None:
# Initialization logic shared between ``__init__`` and ``clone``.
super().__init__(origin=origin, engine=engine, namespace=namespace, metadata=metadata)
Expand Down
173 changes: 118 additions & 55 deletions python/lsst/daf/butler/registry/interfaces/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

__all__ = [
"Database",
"DatabaseMetadata",
"ReadOnlyDatabaseError",
"DatabaseConflictError",
"DatabaseInsertMode",
Expand All @@ -46,6 +47,7 @@
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from threading import Lock
from typing import Any, cast, final

import astropy.time
Expand Down Expand Up @@ -136,7 +138,6 @@

def __init__(self, db: Database, connection: sqlalchemy.engine.Connection):
self._db = db
self._foreignKeys: list[tuple[sqlalchemy.schema.Table, sqlalchemy.schema.ForeignKeyConstraint]] = []
self._inspector = sqlalchemy.inspect(connection)
self._tableNames = frozenset(self._inspector.get_table_names(schema=self._db.namespace))
self._initializers: list[Callable[[Database], None]] = []
Expand Down Expand Up @@ -164,13 +165,9 @@
to be declared in any order even in the presence of foreign key
relationships.
"""
name = self._db._mangleTableName(name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused: do you want DatabaseMetadata to handle table name mangling internally? I do not think it is done consistently now: you call add_table with the original name, but get_table is called with the mangled name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am just going to remove it entirely. It's a no-op -- the only implementation returns the name unchanged. I think it is a holdover from the Oracle implementation.

The existing implementation was already incorrect because there were several places where name mangling was applied multiple times. I noticed that the two cases I already removed were double-applying the mangling (because it is done internally in _convertTableSpec) so I had removed those, but apparently the logic is also wrong in a lot of other places.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My recollection is that it was PostgreSQL, not Oracle, that needed name shrinking. It may have been index and constraint names more than table names where it was important, but I'm pretty sure PostgreSQL is the one with a 64-char limit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two different things. The name shrinking logic is still there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, sorry for the noise.

metadata = self._db._metadata
assert metadata is not None, "Guaranteed by context manager that returns this object."
table = self._db._convertTableSpec(name, spec, metadata)
for foreignKeySpec in spec.foreignKeys:
self._foreignKeys.append((table, self._db._convertForeignKeySpec(name, foreignKeySpec, metadata)))
return table
return metadata.add_table(self._db, name, spec)

def addTableTuple(self, specs: tuple[ddl.TableSpec, ...]) -> tuple[sqlalchemy.schema.Table, ...]:
"""Add a named tuple of tables to the schema, returning their
Expand Down Expand Up @@ -273,15 +270,15 @@
origin: int,
engine: sqlalchemy.engine.Engine,
namespace: str | None = None,
metadata: sqlalchemy.schema.MetaData | None = None,
metadata: DatabaseMetadata | None = None,
):
self.origin = origin
self.name_shrinker = NameShrinker(engine.dialect.max_identifier_length)
self.namespace = namespace
self._engine = engine
self._session_connection: sqlalchemy.engine.Connection | None = None
self._metadata = metadata
self._temp_tables: set[str] = set()
self._metadata = metadata

def __repr__(self) -> str:
# Rather than try to reproduce all the parameters used to create
Expand Down Expand Up @@ -540,6 +537,7 @@
otherwise, but in that case they probably need to be modified to
support the full range of expected read-only butler behavior.
"""
assert self._metadata is not None, "Static tables must be created before temporary tables"
with self._session() as connection:
table = self._make_temporary_table(connection, spec=spec, name=name)
self._temp_tables.add(table.key)
Expand All @@ -549,6 +547,7 @@
with self._transaction():
table.drop(connection)
self._temp_tables.remove(table.key)
self._metadata.remove_table(table.name)

@contextmanager
def _session(self) -> Iterator[sqlalchemy.engine.Connection]:
Expand Down Expand Up @@ -760,7 +759,8 @@
"""
if create and not self.isWriteable():
raise ReadOnlyDatabaseError(f"Cannot create tables in read-only database {self}.")
self._metadata = sqlalchemy.MetaData(schema=self.namespace)

self._metadata = DatabaseMetadata(self.namespace)
try:
with self._transaction() as (_, connection):
context = StaticTablesContext(self, connection)
Expand All @@ -770,8 +770,6 @@
# do anything in this case
raise SchemaAlreadyDefinedError(f"Cannot create tables in non-empty database {self}.")
yield context
for table, foreignKey in context._foreignKeys:
table.append_constraint(foreignKey)
if create:
if (
self.namespace is not None
Expand Down Expand Up @@ -858,30 +856,6 @@
"""
return shrunk

def _mangleTableName(self, name: str) -> str:
"""Map a logical, user-visible table name to the true table name used
in the database.

The default implementation returns the given name unchanged.

Parameters
----------
name : `str`
Input table name. Should not include a namespace (i.e. schema)
prefix.

Returns
-------
mangled : `str`
Mangled version of the table name (still with no namespace prefix).

Notes
-----
Reimplementations of this method must be idempotent - mangling an
already-mangled name must have no effect.
"""
return name

def _makeColumnConstraints(self, table: str, spec: ddl.FieldSpec) -> list[sqlalchemy.CheckConstraint]:
"""Create constraints based on this spec.

Expand Down Expand Up @@ -974,13 +948,11 @@
SQLAlchemy representation of the constraint.
"""
name = self.shrinkDatabaseEntityName(
"_".join(
["fkey", table, self._mangleTableName(spec.table)] + list(spec.target) + list(spec.source)
)
"_".join(["fkey", table, spec.table] + list(spec.target) + list(spec.source))
)
return sqlalchemy.schema.ForeignKeyConstraint(
spec.source,
[f"{self._mangleTableName(spec.table)}.{col}" for col in spec.target],
[f"{spec.table}.{col}" for col in spec.target],
name=name,
ondelete=spec.onDelete,
)
Expand Down Expand Up @@ -1050,7 +1022,6 @@
avoid circular dependencies. These are added by higher-level logic in
`ensureTableExists`, `getExistingTable`, and `declareStaticTables`.
"""
name = self._mangleTableName(name)
args: list[sqlalchemy.schema.SchemaItem] = [
self._convertFieldSpec(name, fieldSpec, metadata) for fieldSpec in spec.fields
]
Expand Down Expand Up @@ -1141,9 +1112,8 @@
raise ReadOnlyDatabaseError(
f"Table {name} does not exist, and cannot be created because database {self} is read-only."
)
table = self._convertTableSpec(name, spec, self._metadata)
for foreignKeySpec in spec.foreignKeys:
table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata))

table = self._metadata.add_table(self, name, spec)
try:
with self._transaction() as (_, connection):
table.create(connection)
Expand Down Expand Up @@ -1191,8 +1161,7 @@
Subclasses may override this method, but usually should not need to.
"""
assert self._metadata is not None, "Static tables must be declared before dynamic tables."
name = self._mangleTableName(name)
table = self._metadata.tables.get(name if self.namespace is None else f"{self.namespace}.{name}")
table = self._metadata.get_table(name)
if table is not None:
if spec.fields.names != set(table.columns.keys()):
raise DatabaseConflictError(
Expand All @@ -1206,10 +1175,7 @@
)
if name in inspector.get_table_names(schema=self.namespace):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, but this won't work if name is not mangled?

_checkExistingTableDefinition(name, spec, inspector.get_columns(name, schema=self.namespace))
table = self._convertTableSpec(name, spec, self._metadata)
for foreignKeySpec in spec.foreignKeys:
table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, self._metadata))
return table
return self._metadata.add_table(self, name, spec)
return table

def _make_temporary_table(
Expand Down Expand Up @@ -1244,19 +1210,16 @@
"""
if name is None:
name = f"tmp_{uuid.uuid4().hex}"
metadata = self._metadata
if metadata is None:
if self._metadata is None:
raise RuntimeError("Cannot create temporary table before static schema is defined.")
table = self._convertTableSpec(
name, spec, metadata, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs
table = self._metadata.add_table(
self, name, spec, prefixes=["TEMPORARY"], schema=sqlalchemy.schema.BLANK_SCHEMA, **kwargs
)
if table.key in self._temp_tables and table.key != name:
raise ValueError(
f"A temporary table with name {name} (transformed to {table.key} by "
"Database) already exists."
)
for foreignKeySpec in spec.foreignKeys:
table.append_constraint(self._convertForeignKeySpec(name, foreignKeySpec, metadata))
with self._transaction():
table.create(connection)
return table
Expand Down Expand Up @@ -2010,3 +1973,103 @@
"""An object that can be used to shrink field names to fit within the
identifier limit of the database engine (`NameShrinker`).
"""


class DatabaseMetadata:
    """Wrapper around SqlAlchemy MetaData object to ensure threadsafety.

    Parameters
    ----------
    namespace : `str` or `None`
        Name of the schema or namespace this instance is associated with.

    Notes
    -----
    `sqlalchemy.MetaData` is documented to be threadsafe for reads, but not
    with concurrent modifications.  We add tables dynamically at runtime,
    and the MetaData object is shared by all Database instances sharing
    the same connection pool.

    Every method of this class holds a single lock for its whole duration.
    Table definitions are indexed by the name given to `add_table`, and the
    wrapped `sqlalchemy.MetaData` object is only touched by the methods
    defined here.
    """

    def __init__(self, namespace: str | None) -> None:
        # One lock guards both the wrapped MetaData and the name index.
        self._lock = Lock()
        self._metadata = sqlalchemy.MetaData(schema=namespace)
        # Table definitions keyed by the name passed to ``add_table``.
        self._tables: dict[str, sqlalchemy.Table] = {}

    def add_table(
        self, db: Database, name: str, spec: ddl.TableSpec, **kwargs: Any
    ) -> sqlalchemy.schema.Table:
        """Add a new table to the MetaData object, returning its sqlalchemy
        representation.  This does not physically create the table in the
        database -- it only sets up its definition.

        Parameters
        ----------
        db : `Database`
            Database connection associated with the table definition.
        name : `str`
            The name of the table.
        spec : `ddl.TableSpec`
            The specification of the table.
        **kwargs
            Additional keyword arguments to forward to the
            `sqlalchemy.schema.Table` constructor.

        Returns
        -------
        table : `sqlalchemy.schema.Table`
            The created table, or the previously-added definition if a
            table with this name is already known.  In the latter case
            ``spec`` and ``**kwargs`` are not consulted.

        Notes
        -----
        If attaching a foreign key constraint fails, the partially-built
        table is removed from the underlying MetaData object before the
        exception propagates, so that this object is left as it was before
        the call.
        """
        with self._lock:
            if (table := self._tables.get(name)) is not None:
                return table

            table = db._convertTableSpec(name, spec, self._metadata, **kwargs)
            try:
                for foreignKeySpec in spec.foreignKeys:
                    table.append_constraint(
                        db._convertForeignKeySpec(name, foreignKeySpec, self._metadata)
                    )
            except BaseException:
                # The table was already registered with the MetaData object
                # by ``_convertTableSpec`` but is not yet in ``_tables``;
                # undo the registration so the two stay in agreement.
                self._metadata.remove(table)
                raise

            self._tables[name] = table
            return table

    def get_table(self, name: str) -> sqlalchemy.schema.Table | None:
        """Return the definition of a table that was previously added to this
        MetaData object.

        Parameters
        ----------
        name : `str`
            Name of the table.

        Returns
        -------
        table : `sqlalchemy.schema.Table` or `None`
            The table definition, or `None` if the table is not known to this
            MetaData instance.
        """
        with self._lock:
            return self._tables.get(name)

    def remove_table(self, name: str) -> None:
        """Remove a table that was previously added to this MetaData object.

        Does nothing if no table with this name is known.

        Parameters
        ----------
        name : `str`
            Name of the table.
        """
        with self._lock:
            table = self._tables.pop(name, None)
            if table is not None:
                self._metadata.remove(table)

    def create_all(self, connection: sqlalchemy.engine.Connection) -> None:
        """Create all tables known to this MetaData object in the database.
        Same as `sqlalchemy.MetaData.create_all`.

        Parameters
        ----------
        connection : `sqlalchemy.engine.Connection`
            Database connection that will be used to create tables.
        """
        with self._lock:
            self._metadata.create_all(connection)
10 changes: 5 additions & 5 deletions tests/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def isEmptyDatabaseActuallyWriteable(database: SqliteDatabase) -> bool:
"a", ddl.TableSpec(fields=[ddl.FieldSpec("b", dtype=sqlalchemy.Integer, primaryKey=True)])
)
# Drop created table so that schema remains empty.
database._metadata.drop_all(database._engine, tables=[table])
database._metadata._metadata.drop_all(database._engine, tables=[table])
return True
except Exception:
return False
Expand Down Expand Up @@ -103,13 +103,13 @@ def testConnection(self):
_, filename = tempfile.mkstemp(dir=self.root, suffix=".sqlite3")
# Create a read-write database by passing in the filename.
rwFromFilename = SqliteDatabase.fromEngine(SqliteDatabase.makeEngine(filename=filename), origin=0)
self.assertEqual(rwFromFilename.filename, filename)
self.assertEqual(os.path.realpath(rwFromFilename.filename), os.path.realpath(filename))
self.assertEqual(rwFromFilename.origin, 0)
self.assertTrue(rwFromFilename.isWriteable())
self.assertTrue(isEmptyDatabaseActuallyWriteable(rwFromFilename))
# Create a read-write database via a URI.
rwFromUri = SqliteDatabase.fromUri(f"sqlite:///{filename}", origin=0)
self.assertEqual(rwFromUri.filename, filename)
self.assertEqual(os.path.realpath(rwFromUri.filename), os.path.realpath(filename))
self.assertEqual(rwFromUri.origin, 0)
self.assertTrue(rwFromUri.isWriteable())
self.assertTrue(isEmptyDatabaseActuallyWriteable(rwFromUri))
Expand All @@ -123,13 +123,13 @@ def testConnection(self):
roFromFilename = SqliteDatabase.fromEngine(
SqliteDatabase.makeEngine(filename=filename), origin=0, writeable=False
)
self.assertEqual(roFromFilename.filename, filename)
self.assertEqual(os.path.realpath(roFromFilename.filename), os.path.realpath(filename))
self.assertEqual(roFromFilename.origin, 0)
self.assertFalse(roFromFilename.isWriteable())
self.assertFalse(isEmptyDatabaseActuallyWriteable(roFromFilename))
# Create a read-only database via a URI.
roFromUri = SqliteDatabase.fromUri(f"sqlite:///{filename}", origin=0, writeable=False)
self.assertEqual(roFromUri.filename, filename)
self.assertEqual(os.path.realpath(roFromUri.filename), os.path.realpath(filename))
self.assertEqual(roFromUri.origin, 0)
self.assertFalse(roFromUri.isWriteable())
self.assertFalse(isEmptyDatabaseActuallyWriteable(roFromUri))
Expand Down
Loading