From 05ef0273b7f515f5fcc94fc6ebe9dd8cf645dfdc Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Tue, 2 Apr 2024 14:32:22 -0700 Subject: [PATCH 01/16] Add type annotation for CollectionTablesTuple --- .../lsst/daf/butler/registry/collections/_base.py | 13 +++++++++---- .../lsst/daf/butler/registry/collections/nameKey.py | 4 +++- .../daf/butler/registry/collections/synthIntKey.py | 4 +++- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 4a41016174..0eb9d79e19 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -32,9 +32,8 @@ import itertools from abc import abstractmethod -from collections import namedtuple from collections.abc import Iterable, Iterator, Set -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar, cast import sqlalchemy @@ -77,7 +76,13 @@ def _makeCollectionForeignKey( return ddl.ForeignKeySpec("collection", source=(sourceColumnName,), target=(collectionIdName,), **kwargs) -CollectionTablesTuple = namedtuple("CollectionTablesTuple", ["collection", "run", "collection_chain"]) +_T = TypeVar("_T") + + +class CollectionTablesTuple(NamedTuple, Generic[_T]): + collection: _T + run: _T + collection_chain: _T def makeRunTableSpec( @@ -188,7 +193,7 @@ class DefaultCollectionManager(CollectionManager[K]): def __init__( self, db: Database, - tables: CollectionTablesTuple, + tables: CollectionTablesTuple[sqlalchemy.Table], collectionIdName: str, *, caching_context: CachingContext, diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py index d000698373..ee354aef74 100644 --- a/python/lsst/daf/butler/registry/collections/nameKey.py +++ b/python/lsst/daf/butler/registry/collections/nameKey.py @@ -63,7 +63,9 @@ _LOG = 
logging.getLogger(__name__) -def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple: +def _makeTableSpecs( + TimespanReprClass: type[TimespanDatabaseRepresentation], +) -> CollectionTablesTuple[ddl.TableSpec]: return CollectionTablesTuple( collection=ddl.TableSpec( fields=[ diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py index 52d79491e3..bac98c2440 100644 --- a/python/lsst/daf/butler/registry/collections/synthIntKey.py +++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py @@ -63,7 +63,9 @@ _LOG = logging.getLogger(__name__) -def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> CollectionTablesTuple: +def _makeTableSpecs( + TimespanReprClass: type[TimespanDatabaseRepresentation], +) -> CollectionTablesTuple[ddl.TableSpec]: return CollectionTablesTuple( collection=ddl.TableSpec( fields=[ From 4ae068898f86eb65a4000b8998c1e949fa6fba2c Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Tue, 2 Apr 2024 14:33:44 -0700 Subject: [PATCH 02/16] Add isInTransaction() to DB This transaction check is done in a few places, and I need it for a sanity check outside Database in an upcoming commit. --- .../daf/butler/registry/interfaces/_database.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/lsst/daf/butler/registry/interfaces/_database.py b/python/lsst/daf/butler/registry/interfaces/_database.py index 8faea919c5..d29531b037 100644 --- a/python/lsst/daf/butler/registry/interfaces/_database.py +++ b/python/lsst/daf/butler/registry/interfaces/_database.py @@ -632,7 +632,7 @@ def _transaction( - ``connection`` (`sqlalchemy.engine.Connection`): the connection.
""" with self._session() as connection: - already_in_transaction = connection.in_transaction() + already_in_transaction = self.isInTransaction() assert not (interrupting and already_in_transaction), ( "Logic error in transaction nesting: an operation that would " "interrupt the active transaction context has been requested." @@ -794,6 +794,13 @@ def isWriteable(self) -> bool: """Return `True` if this database can be modified by this client.""" raise NotImplementedError() + def isInTransaction(self) -> bool: + """Return `True` if there is currently a database connection open with + an active transaction; `False` otherwise. + """ + session = self._session_connection + return session is not None and session.in_transaction() + @abstractmethod def __str__(self) -> str: """Return a human-readable identifier for this `Database`, including @@ -1120,9 +1127,7 @@ def ensureTableExists(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema """ # TODO: if _engine is used to make a table then it uses separate # connection and should not interfere with current transaction - assert ( - self._session_connection is None or not self._session_connection.in_transaction() - ), "Table creation interrupts transactions." + assert not self.isInTransaction(), "Table creation interrupts transactions." assert self._metadata is not None, "Static tables must be declared before dynamic tables." table = self.getExistingTable(name, spec) if table is not None: From ffefc7128574ced794378aef9d6a57d9eee9b32d Mon Sep 17 00:00:00 2001 From: "David H. 
Irving" Date: Wed, 3 Apr 2024 09:40:04 -0700 Subject: [PATCH 03/16] Factor out collection chain insert function --- .../daf/butler/registry/collections/_base.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 0eb9d79e19..2c43e49715 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -426,22 +426,29 @@ def update_chain( record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True) ) - rows = [] - position = itertools.count() - names = [] - for child in self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False): - rows.append( - { - "parent": chain.key, - "child": child.key, - "position": next(position), - } - ) - names.append(child.name) + child_records = self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False) + names = [child.name for child in child_records] with self._db.transaction(): self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key}) - self._db.insert(self._tables.collection_chain, *rows) + self._insert_collection_chain_rows(chain.key, 0, [child.key for child in child_records]) record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names)) self._addCachedRecord(record) return record + + def _insert_collection_chain_rows( + self, + parent_key: K, + starting_position: int, + child_keys: Iterable[K], + ) -> None: + position = itertools.count(starting_position) + rows = [ + { + "parent": parent_key, + "child": child, + "position": next(position), + } + for child in child_keys + ] + self._db.insert(self._tables.collection_chain, *rows) From def8d69753f7b5f5156bbc3d0dcf2b338e151332 Mon Sep 17 00:00:00 2001 From: "David H. 
Irving" Date: Wed, 3 Apr 2024 10:31:52 -0700 Subject: [PATCH 04/16] Implement chained collection prepend --- python/lsst/daf/butler/_butler.py | 30 ++++++ python/lsst/daf/butler/direct_butler.py | 8 ++ .../daf/butler/registry/collections/_base.py | 92 +++++++++++++++++++ .../butler/registry/collections/nameKey.py | 6 ++ .../registry/collections/synthIntKey.py | 6 ++ .../registry/interfaces/_collections.py | 29 ++++++ .../butler/remote_butler/_remote_butler.py | 5 + python/lsst/daf/butler/tests/hybrid_butler.py | 5 + tests/test_butler.py | 28 ++++++ 9 files changed, 209 insertions(+) diff --git a/python/lsst/daf/butler/_butler.py b/python/lsst/daf/butler/_butler.py index 7c232b2d43..60a339e782 100644 --- a/python/lsst/daf/butler/_butler.py +++ b/python/lsst/daf/butler/_butler.py @@ -1735,3 +1735,33 @@ def _clone( ``inferDefaults``, and default data ID. """ raise NotImplementedError() + + @abstractmethod + def prepend_collection_chain( + self, parent_collection_name: str, child_collection_names: str | Iterable[str] + ) -> None: + """Add children to the beginning of a CHAINED collection. + + Parameters + ---------- + parent_collection_name : `str` + The name of a CHAINED collection to which we will add new children. + child_collection_names : `Iterable` [ `str` ] | `str` + A child collection name or list of child collection names to be + added to the parent. + + Raises + ------ + MissingCollectionError + If any of the specified collections do not exist. + CollectionTypeError + If the parent collection is not a CHAINED collection. + + Notes + ----- + If this function is called within a call to ``Butler.transaction``, it + will hold a lock that prevents other processes from modifying the + parent collection until the end of the transaction. Keep these + transactions short.
+ """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/direct_butler.py b/python/lsst/daf/butler/direct_butler.py index 75370933e0..acac9c7bdf 100644 --- a/python/lsst/daf/butler/direct_butler.py +++ b/python/lsst/daf/butler/direct_butler.py @@ -48,6 +48,7 @@ from lsst.resources import ResourcePath, ResourcePathExpression from lsst.utils.introspection import get_class_of +from lsst.utils.iteration import ensure_iterable from lsst.utils.logging import VERBOSE, getLogger from sqlalchemy.exc import IntegrityError @@ -2141,6 +2142,13 @@ def _preload_cache(self) -> None: """Immediately load caches that are used for common operations.""" self._registry.preload_cache() + def prepend_collection_chain( + self, parent_collection_name: str, child_collection_names: str | Iterable[str] + ) -> None: + return self._registry._managers.collections.prepend_collection_chain( + parent_collection_name, list(ensure_iterable(child_collection_names)) + ) + _config: ButlerConfig """Configuration for this Butler instance.""" diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 2c43e49715..68f5400f71 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -40,6 +40,7 @@ from ..._exceptions import MissingCollectionError from ...timespan_database_representation import TimespanDatabaseRepresentation from .._collection_type import CollectionType +from .._exceptions import CollectionTypeError from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple from ..wildcards import CollectionWildcard @@ -452,3 +453,94 @@ def _insert_collection_chain_rows( for child in child_keys ] self._db.insert(self._tables.collection_chain, *rows) + + def prepend_collection_chain( + self, parent_collection_name: str, child_collection_names: list[str] + ) -> None: + child_records = self.resolve_wildcard( + 
CollectionWildcard.from_names(child_collection_names), flatten_chains=False + ) + child_keys = [child.key for child in child_records] + assert len(child_keys) == len(child_collection_names) + + with self._db.transaction(): + parent_key = self._find_and_lock_collection_chain(parent_collection_name) + starting_position = self._find_lowest_position_in_collection_chain(parent_key) - len(child_keys) + self._insert_collection_chain_rows(parent_key, starting_position, child_keys) + + def _find_lowest_position_in_collection_chain(self, chain_key: K) -> int: + """Return the lowest-numbered position in a collection chain, or 0 if + the chain is empty. + """ + table = self._tables.collection_chain + query = sqlalchemy.select(sqlalchemy.func.min(table.c.position)).where(table.c.parent == chain_key) + with self._db.query(query) as cursor: + lowest_existing_position = cursor.scalar() + + if lowest_existing_position is None: + return 0 + + return lowest_existing_position + + def _find_and_lock_collection_chain(self, collection_name: str) -> K: + """ + Take a row lock on the specified collection's row in the collections + table, and return the collection's primary key. + + This lock is used to synchronize updates to collection chains. + + The locking strategy requires cooperation from everything modifying the + collection chain table -- all operations that modify collection chains + must obtain this lock first. The database will NOT automatically + prevent modification of tables based on this lock. The only guarantee + is that only one caller will be allowed to hold this lock for a given + collection at a time. Concurrent calls will block until the caller + holding the lock has completed its transaction. + + Parameters + ---------- + collection_name : `str` + Name of the collection whose chain is being modified. + + Returns + ------- + id : ``K`` + The primary key for the given collection. 
+ + Raises + ------ + MissingCollectionError + If the specified collection is not in the database table. + CollectionTypeError + If the specified collection is not a chained collection. + """ + assert self._db.isInTransaction(), ( + "Row locks are only held until the end of the current transaction," + " so it makes no sense to take a lock outside a transaction." + ) + assert self._db.isWriteable(), "Collection row locks are only useful for write operations." + + query = self._select_pkey_by_name(collection_name).with_for_update() + with self._db.query(query) as cursor: + rows = cursor.all() + + if len(rows) == 0: + raise MissingCollectionError( + f"Parent collection {collection_name} not found when updating collection chain." + ) + assert len(rows) == 1, "There should only be one entry for each collection in collection table." + r = rows[0]._mapping + if r["type"] != CollectionType.CHAINED: + raise CollectionTypeError(f"Parent collection {collection_name} is not a chained collection.") + return r["key"] + + @abstractmethod + def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select: + """Return a SQLAlchemy select statement that will return columns from + the one row in the ``collection`` table matching the given name.
The + select statement includes two columns: + + - ``key`` : the primary key for the collection + - ``type`` : the collection type + """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py index ee354aef74..7b2103dda5 100644 --- a/python/lsst/daf/butler/registry/collections/nameKey.py +++ b/python/lsst/daf/butler/registry/collections/nameKey.py @@ -285,6 +285,12 @@ def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> li return records + def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select: + table = self._tables.collection + return sqlalchemy.select(table.c.name.label("key"), table.c.type).where( + table.c.name == collection_name + ) + @classmethod def currentVersions(cls) -> list[VersionTuple]: # Docstring inherited from VersionedExtension. diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py index bac98c2440..e329f40706 100644 --- a/python/lsst/daf/butler/registry/collections/synthIntKey.py +++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py @@ -302,6 +302,12 @@ def _rows_to_chains( return records + def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select: + table = self._tables.collection + return sqlalchemy.select(table.c.collection_id.label("key"), table.c.type).where( + table.c.name == collection_name + ) + @classmethod def currentVersions(cls) -> list[VersionTuple]: # Docstring inherited from VersionedExtension. diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py index cef7b9741f..0a18f3f687 100644 --- a/python/lsst/daf/butler/registry/interfaces/_collections.py +++ b/python/lsst/daf/butler/registry/interfaces/_collections.py @@ -621,3 +621,32 @@ def update_chain( `~CollectionType.CHAINED` collections in ``children`` first. 
""" raise NotImplementedError() + + def prepend_collection_chain( + self, parent_collection_name: str, child_collection_names: list[str] + ) -> None: + """Add children to the beginning of a CHAINED collection. + + Parameters + ---------- + parent_collection_name : `str` + The name of a CHAINED collection to which we will add new children. + child_collection_names : `list` [ `str ` ] + A child collection name or list of child collection names to be + added to the parent. + + Raises + ------ + MissingCollectionError + If any of the specified collections do not exist. + CollectionTypeError + If the parent collection is not a CHAINED collection. + + Notes + ----- + If this function is called within a call to ``Butler.transaction``, it + will hold a lock that prevents other processes from modifying the + parent collection until the end of the transaction. Keep these + transactions short. + """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/remote_butler/_remote_butler.py b/python/lsst/daf/butler/remote_butler/_remote_butler.py index 911663bee8..79d9d56b10 100644 --- a/python/lsst/daf/butler/remote_butler/_remote_butler.py +++ b/python/lsst/daf/butler/remote_butler/_remote_butler.py @@ -570,6 +570,11 @@ def collections(self) -> Sequence[str]: # Docstring inherited. return self._registry_defaults.collections + def prepend_collection_chain( + self, parent_collection_name: str, child_collection_names: str | Iterable[str] + ) -> None: + raise NotImplementedError() + @property def run(self) -> str | None: # Docstring inherited. 
diff --git a/python/lsst/daf/butler/tests/hybrid_butler.py b/python/lsst/daf/butler/tests/hybrid_butler.py index c0959ffcaa..8345f3fe44 100644 --- a/python/lsst/daf/butler/tests/hybrid_butler.py +++ b/python/lsst/daf/butler/tests/hybrid_butler.py @@ -452,3 +452,8 @@ def _extract_all_dimension_records_from_data_ids( return self._direct_butler._extract_all_dimension_records_from_data_ids( source_butler, data_ids, allowed_elements ) + + def prepend_collection_chain( + self, parent_collection_name: str, child_collection_names: str | Iterable[str] + ) -> None: + return self._direct_butler.prepend_collection_chain(parent_collection_name, child_collection_names) diff --git a/tests/test_butler.py b/tests/test_butler.py index 8d95f6adeb..970a63d491 100644 --- a/tests/test_butler.py +++ b/tests/test_butler.py @@ -1387,6 +1387,34 @@ def testGetDatasetCollectionCaching(self): get_ref = reader_butler.get_dataset(put_ref.id) self.assertEqual(get_ref.id, put_ref.id) + def testCollectionChainPrepend(self): + butler = self.create_empty_butler(writeable=True) + butler.registry.registerCollection("chain", CollectionType.CHAINED) + runs = ["a", "b", "c", "d"] + for run in runs: + butler.registry.registerCollection(run) + + def check_chain(expected: list[str]) -> None: + children = butler.registry.getCollectionChain("chain") + self.assertEqual(expected, list(children)) + + butler.prepend_collection_chain("chain", ["c", "b"]) + check_chain(["c", "b"]) + butler.prepend_collection_chain("chain", ["a"]) + check_chain(["a", "c", "b"]) + butler.prepend_collection_chain("chain", []) + check_chain(["a", "c", "b"]) + + # Missing parent collection + with self.assertRaises(MissingCollectionError): + butler.prepend_collection_chain("chain2", []) + # Missing child collection + with self.assertRaises(MissingCollectionError): + butler.prepend_collection_chain("chain", ["doesnotexist"]) + # Forbid operations on non-chained collections + with self.assertRaises(CollectionTypeError): + 
butler.prepend_collection_chain("d", ["a"]) + class FileDatastoreButlerTests(ButlerTests): """Common tests and specialization of ButlerTests for butlers backed From 601d5bcdcd62606370c007312e1e5bc25cc64a57 Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Wed, 3 Apr 2024 11:02:59 -0700 Subject: [PATCH 05/16] Handle cache invalidation for chain prepend --- python/lsst/daf/butler/registry/collections/_base.py | 11 +++++++++++ tests/test_butler.py | 8 ++++++++ 2 files changed, 19 insertions(+) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 68f5400f71..57a2e5ac29 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -468,6 +468,8 @@ def prepend_collection_chain( starting_position = self._find_lowest_position_in_collection_chain(parent_key) - len(child_keys) self._insert_collection_chain_rows(parent_key, starting_position, child_keys) + self._refresh_cache_for_key(parent_key) + def _find_lowest_position_in_collection_chain(self, chain_key: K) -> int: """Return the lowest-numbered position in a collection chain, or 0 if the chain is empty. 
@@ -544,3 +546,12 @@ def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select: - ``type`` : the collection type """ raise NotImplementedError() + + def _refresh_cache_for_key(self, key: K) -> None: + """Refresh the data in the cache for a single collection.""" + cache = self._caching_context.collection_records + if cache is not None: + records = self._fetch_by_key([key]) + if records: + assert len(records) == 1 + cache.add(records[0]) diff --git a/tests/test_butler.py b/tests/test_butler.py index 970a63d491..db03b0030f 100644 --- a/tests/test_butler.py +++ b/tests/test_butler.py @@ -1389,6 +1389,14 @@ def testGetDatasetCollectionCaching(self): def testCollectionChainPrepend(self): butler = self.create_empty_butler(writeable=True) + self._testCollectionChainPrepend(butler) + + def testCollectionChainPrependCached(self): + butler = self.create_empty_butler(writeable=True) + with butler._caching_context(): + self._testCollectionChainPrepend(butler) + + def _testCollectionChainPrepend(self, butler: Butler) -> None: butler.registry.registerCollection("chain", CollectionType.CHAINED) runs = ["a", "b", "c", "d"] for run in runs: From dc1ed9f42666bba98a1040190df0ec3769f8cdb7 Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Wed, 3 Apr 2024 11:24:30 -0700 Subject: [PATCH 06/16] Factor out collection cycle check function --- .../daf/butler/registry/collections/_base.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 57a2e5ac29..805324d4cd 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -413,18 +413,14 @@ def update_chain( self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False ) -> ChainedCollectionRecord[K]: # Docstring inherited from CollectionManager. 
- children_as_wildcard = CollectionWildcard.from_names(children) - for record in self.resolve_wildcard( - children_as_wildcard, - flatten_chains=True, - include_chains=True, - collection_types={CollectionType.CHAINED}, - ): - if record == chain: - raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.") + children = list(children) + self._sanity_check_collection_cycles(chain.name, children) if flatten: children = tuple( - record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True) + record.name + for record in self.resolve_wildcard( + CollectionWildcard.from_names(children), flatten_chains=True + ) ) child_records = self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False) @@ -437,6 +433,23 @@ def update_chain( self._addCachedRecord(record) return record + def _sanity_check_collection_cycles(self, parent_name: str, child_names: list[str]) -> None: + """Raise an exception if any of the collections in the ``child_names`` + list have ``parent_name`` as a child, creating a collection cycle. + + This is only a sanity check, and does not guarantee that no collection + cycles are possible. Concurrent updates might allow collection cycles + to be inserted. + """ + for record in self.resolve_wildcard( + CollectionWildcard.from_names(child_names), + flatten_chains=True, + include_chains=True, + collection_types={CollectionType.CHAINED}, + ): + if record.name == parent_name: + raise ValueError(f"Cycle in collection chaining when defining '{parent_name}'.") + def _insert_collection_chain_rows( self, parent_key: K, From c7ca6975b74c4ed2003ef3b3aa582ad1a84607e4 Mon Sep 17 00:00:00 2001 From: "David H. 
Irving" Date: Wed, 3 Apr 2024 11:47:50 -0700 Subject: [PATCH 07/16] Check for collection cycles in chain prepend --- python/lsst/daf/butler/_butler.py | 2 ++ python/lsst/daf/butler/_exceptions.py | 10 ++++++++++ .../daf/butler/registry/collections/_base.py | 17 ++++++++++++----- .../butler/registry/interfaces/_collections.py | 2 ++ python/lsst/daf/butler/registry/sql_registry.py | 2 +- tests/test_butler.py | 10 +++++++++- 6 files changed, 36 insertions(+), 7 deletions(-) diff --git a/python/lsst/daf/butler/_butler.py b/python/lsst/daf/butler/_butler.py index 60a339e782..de92a90d0a 100644 --- a/python/lsst/daf/butler/_butler.py +++ b/python/lsst/daf/butler/_butler.py @@ -1756,6 +1756,8 @@ def prepend_collection_chain( If any of the specified collections do not exist. CollectionTypeError If the parent collection is not a CHAINED collection. + CollectionCycleError + If this operation would create a collection cycle. Notes ----- diff --git a/python/lsst/daf/butler/_exceptions.py b/python/lsst/daf/butler/_exceptions.py index 9c5f57757e..1a0007cc86 100644 --- a/python/lsst/daf/butler/_exceptions.py +++ b/python/lsst/daf/butler/_exceptions.py @@ -28,6 +28,7 @@ """Specialized Butler exceptions.""" __all__ = ( "CalibrationLookupError", + "CollectionCycleError", "DatasetNotFoundError", "DimensionNameError", "ButlerUserError", @@ -79,6 +80,14 @@ class CalibrationLookupError(LookupError, ButlerUserError): error_type = "calibration_lookup" +class CollectionCycleError(ValueError, ButlerUserError): + """Raised when an operation would cause a chained collection to be a child + of itself. + """ + + error_type = "collection_cycle" + + class DatasetNotFoundError(LookupError, ButlerUserError): """The requested dataset could not be found.""" @@ -158,6 +167,7 @@ class UnknownButlerUserError(ButlerUserError): _USER_ERROR_TYPES: tuple[type[ButlerUserError], ...] 
= ( CalibrationLookupError, + CollectionCycleError, DimensionNameError, DimensionValueError, DatasetNotFoundError, diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 805324d4cd..2147d3510a 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -37,7 +37,7 @@ import sqlalchemy -from ..._exceptions import MissingCollectionError +from ..._exceptions import CollectionCycleError, MissingCollectionError from ...timespan_database_representation import TimespanDatabaseRepresentation from .._collection_type import CollectionType from .._exceptions import CollectionTypeError @@ -415,6 +415,7 @@ def update_chain( # Docstring inherited from CollectionManager. children = list(children) self._sanity_check_collection_cycles(chain.name, children) + if flatten: children = tuple( record.name @@ -433,7 +434,9 @@ def update_chain( self._addCachedRecord(record) return record - def _sanity_check_collection_cycles(self, parent_name: str, child_names: list[str]) -> None: + def _sanity_check_collection_cycles( + self, parent_collection_name: str, child_collection_names: list[str] + ) -> None: """Raise an exception if any of the collections in the ``child_names`` list have ``parent_name`` as a child, creating a collection cycle. @@ -442,13 +445,15 @@ def _sanity_check_collection_cycles(self, parent_name: str, child_names: list[st to be inserted. """ for record in self.resolve_wildcard( - CollectionWildcard.from_names(child_names), + CollectionWildcard.from_names(child_collection_names), flatten_chains=True, include_chains=True, collection_types={CollectionType.CHAINED}, ): - if record.name == parent_name: - raise ValueError(f"Cycle in collection chaining when defining '{parent_name}'.") + if record.name == parent_collection_name: + raise CollectionCycleError( + f"Cycle in collection chaining when defining '{parent_collection_name}'." 
+ ) def _insert_collection_chain_rows( self, @@ -470,6 +475,8 @@ def _insert_collection_chain_rows( def prepend_collection_chain( self, parent_collection_name: str, child_collection_names: list[str] ) -> None: + self._sanity_check_collection_cycles(parent_collection_name, child_collection_names) + child_records = self.resolve_wildcard( CollectionWildcard.from_names(child_collection_names), flatten_chains=False ) diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py index 0a18f3f687..ddd647f1e4 100644 --- a/python/lsst/daf/butler/registry/interfaces/_collections.py +++ b/python/lsst/daf/butler/registry/interfaces/_collections.py @@ -641,6 +641,8 @@ def prepend_collection_chain( If any of the specified collections do not exist. CollectionTypeError If the parent collection is not a CHAINED collection. + CollectionCycleError + If this operation would create a collection cycle. Notes ----- diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 7e070027a6..4e92001386 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -621,7 +621,7 @@ def setCollectionChain(self, parent: str, children: Any, *, flatten: bool = Fals lsst.daf.butler.registry.CollectionTypeError Raised if ``parent`` does not correspond to a `~CollectionType.CHAINED` collection. - ValueError + CollectionCycleError Raised if the given collections contains a cycle. 
""" record = self._managers.collections.find(parent) diff --git a/tests/test_butler.py b/tests/test_butler.py index db03b0030f..4bfe5e2022 100644 --- a/tests/test_butler.py +++ b/tests/test_butler.py @@ -82,6 +82,7 @@ def mock_aws(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-untyped-def] Butler, ButlerConfig, ButlerRepoIndex, + CollectionCycleError, CollectionType, Config, DataCoordinate, @@ -1398,6 +1399,7 @@ def testCollectionChainPrependCached(self): def _testCollectionChainPrepend(self, butler: Butler) -> None: butler.registry.registerCollection("chain", CollectionType.CHAINED) + runs = ["a", "b", "c", "d"] for run in runs: butler.registry.registerCollection(run) @@ -1415,7 +1417,7 @@ def check_chain(expected: list[str]) -> None: # Missing parent collection with self.assertRaises(MissingCollectionError): - butler.prepend_collection_chain("chain2", []) + butler.prepend_collection_chain("doesnotexist", []) # Missing child collection with self.assertRaises(MissingCollectionError): butler.prepend_collection_chain("chain", ["doesnotexist"]) @@ -1423,6 +1425,12 @@ def check_chain(expected: list[str]) -> None: with self.assertRaises(CollectionTypeError): butler.prepend_collection_chain("d", ["a"]) + # Prevent collection cycles + butler.registry.registerCollection("chain2", CollectionType.CHAINED) + butler.prepend_collection_chain("chain2", "chain") + with self.assertRaises(CollectionCycleError): + butler.prepend_collection_chain("chain", "chain2") + class FileDatastoreButlerTests(ButlerTests): """Common tests and specialization of ButlerTests for butlers backed From 910f8169391fc4287172a83abef5290561baa503 Mon Sep 17 00:00:00 2001 From: "David H. 
Irving" Date: Wed, 3 Apr 2024 13:38:12 -0700 Subject: [PATCH 08/16] Add test for concurrent collection chain prepend --- .../daf/butler/registry/collections/_base.py | 1 + .../registry/interfaces/_collections.py | 6 ++ .../daf/butler/registry/tests/_registry.py | 63 +++++++++++++++++-- tests/test_remote_butler.py | 5 ++ 4 files changed, 71 insertions(+), 4 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 2147d3510a..5cc7469388 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -486,6 +486,7 @@ def prepend_collection_chain( with self._db.transaction(): parent_key = self._find_and_lock_collection_chain(parent_collection_name) starting_position = self._find_lowest_position_in_collection_chain(parent_key) - len(child_keys) + self._block_for_concurrency_test() self._insert_collection_chain_rows(parent_key, starting_position, child_keys) self._refresh_cache_for_key(parent_key) diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py index ddd647f1e4..fb259f42ff 100644 --- a/python/lsst/daf/butler/registry/interfaces/_collections.py +++ b/python/lsst/daf/butler/registry/interfaces/_collections.py @@ -622,6 +622,7 @@ def update_chain( """ raise NotImplementedError() + @abstractmethod def prepend_collection_chain( self, parent_collection_name: str, child_collection_names: list[str] ) -> None: @@ -652,3 +653,8 @@ def prepend_collection_chain( transactions short. """ raise NotImplementedError() + + def _block_for_concurrency_test(self) -> None: + """No-op normally. Provide a place for unit tests to hook in and + verify locking behavior. 
+ """ diff --git a/python/lsst/daf/butler/registry/tests/_registry.py b/python/lsst/daf/butler/registry/tests/_registry.py index afea20457c..0b0545338f 100644 --- a/python/lsst/daf/butler/registry/tests/_registry.py +++ b/python/lsst/daf/butler/registry/tests/_registry.py @@ -34,13 +34,14 @@ import itertools import os import re +import time import unittest import uuid from abc import ABC, abstractmethod from collections import defaultdict, namedtuple from collections.abc import Iterator from datetime import timedelta -from typing import TYPE_CHECKING +from threading import Barrier, Thread import astropy.time import sqlalchemy @@ -77,9 +78,7 @@ ) from .._registry import Registry from ..interfaces import ButlerAttributeExistsError - -if TYPE_CHECKING: - from ..sql_registry import SqlRegistry +from ..sql_registry import SqlRegistry class RegistryTests(ABC): @@ -849,6 +848,62 @@ def testCollectionChainFlatten(self): registry.setCollectionChain("outer", ["inner"], flatten=True) self.assertEqual(list(registry.getCollectionChain("outer")), ["innermost"]) + def testCollectionChainConcurrency(self): + """Verify that locking via database row locks is working as + expected. + """ + registry1 = self.makeRegistry() + assert isinstance(registry1, SqlRegistry) + registry2 = self.makeRegistry(share_repo_with=registry1) + if registry2 is None: + # This will happen for in-memory SQL databases. + raise unittest.SkipTest("Testing concurrency requires two connections to the same DB.") + + registry1.registerCollection("chain", CollectionType.CHAINED) + for collection in ["a", "b"]: + registry1.registerCollection(collection) + + # Cause registry1 to block at the worst possible moment -- after it has + # decided on positions for the new children in the collection chain, + # but before inserting them. 
+ enter_barrier = Barrier(2, timeout=60) + exit_barrier = Barrier(2, timeout=60) + + def wait_for_barrier(): + enter_barrier.wait() + exit_barrier.wait() + + registry1._managers.collections._block_for_concurrency_test = wait_for_barrier + + def thread1_func(): + registry1._managers.collections.prepend_collection_chain("chain", ["a"]) + + def thread2_func(): + registry2._managers.collections.prepend_collection_chain("chain", ["b"]) + + thread1 = Thread(target=thread1_func) + thread2 = Thread(target=thread2_func) + try: + thread1.start() + enter_barrier.wait() + + # At this point registry 1 has entered the critical section and is + # waiting for us to release it. Start the other thread. + thread2.start() + # thread2 should block inside a database call, but we have no way + # to detect when it is in this state. + time.sleep(0.100) + + # Let the threads run to completion. + exit_barrier.wait() + finally: + thread1.join() + thread2.join() + + # Thread1 should have finished first, inserting "a". Thread2 should + # have finished second, prepending "b". + self.assertEqual(("b", "a"), registry1.getCollectionChain("chain")) + def testBasicTransaction(self): """Test that all operations within a single transaction block are rolled back if an exception propagates out of the block. diff --git a/tests/test_remote_butler.py b/tests/test_remote_butler.py index e0cf8cb3e8..f6abc6edb3 100644 --- a/tests/test_remote_butler.py +++ b/tests/test_remote_butler.py @@ -161,6 +161,11 @@ def testOpaque(self): # the client side. pass + def testCollectionChainConcurrency(self): + # This tests an implementation detail that requires access to the + # collection manager object. + pass + def testAttributeManager(self): # Tests a non-public API that isn't relevant on the client side. pass From e73b1cf34fa9fb1f7ff2d295ff582722ac3539a8 Mon Sep 17 00:00:00 2001 From: "David H. 
Irving" Date: Wed, 3 Apr 2024 13:42:27 -0700 Subject: [PATCH 09/16] Make CollectionTypeError a top-level exception --- python/lsst/daf/butler/_exceptions.py | 8 ++++++++ python/lsst/daf/butler/registry/__init__.py | 7 ++++++- python/lsst/daf/butler/registry/_exceptions.py | 5 ----- python/lsst/daf/butler/registry/collections/_base.py | 3 +-- .../daf/butler/registry/datasets/byDimensions/_storage.py | 3 ++- python/lsst/daf/butler/registry/tests/_registry.py | 3 +-- 6 files changed, 18 insertions(+), 11 deletions(-) diff --git a/python/lsst/daf/butler/_exceptions.py b/python/lsst/daf/butler/_exceptions.py index 1a0007cc86..cc988d1779 100644 --- a/python/lsst/daf/butler/_exceptions.py +++ b/python/lsst/daf/butler/_exceptions.py @@ -29,6 +29,7 @@ __all__ = ( "CalibrationLookupError", "CollectionCycleError", + "CollectionTypeError", "DatasetNotFoundError", "DimensionNameError", "ButlerUserError", @@ -88,6 +89,12 @@ class CollectionCycleError(ValueError, ButlerUserError): error_type = "collection_cycle" +class CollectionTypeError(CollectionError, ButlerUserError): + """Exception raised when type of a collection is incorrect.""" + + error_type = "collection_type" + + class DatasetNotFoundError(LookupError, ButlerUserError): """The requested dataset could not be found.""" @@ -168,6 +175,7 @@ class UnknownButlerUserError(ButlerUserError): _USER_ERROR_TYPES: tuple[type[ButlerUserError], ...] = ( CalibrationLookupError, CollectionCycleError, + CollectionTypeError, DimensionNameError, DimensionValueError, DatasetNotFoundError, diff --git a/python/lsst/daf/butler/registry/__init__.py b/python/lsst/daf/butler/registry/__init__.py index 573ba19772..b078d4a260 100644 --- a/python/lsst/daf/butler/registry/__init__.py +++ b/python/lsst/daf/butler/registry/__init__.py @@ -27,7 +27,12 @@ # Re-export some top-level exception types for backwards compatibility -- these # used to be part of registry. 
-from .._exceptions import DimensionNameError, MissingCollectionError, MissingDatasetTypeError +from .._exceptions import ( + CollectionTypeError, + DimensionNameError, + MissingCollectionError, + MissingDatasetTypeError, +) from .._exceptions_legacy import CollectionError, DataIdError, DatasetTypeError, RegistryError # Registry imports. diff --git a/python/lsst/daf/butler/registry/_exceptions.py b/python/lsst/daf/butler/registry/_exceptions.py index 815815dc73..82604ef170 100644 --- a/python/lsst/daf/butler/registry/_exceptions.py +++ b/python/lsst/daf/butler/registry/_exceptions.py @@ -29,7 +29,6 @@ __all__ = ( "ArgumentError", "CollectionExpressionError", - "CollectionTypeError", "ConflictingDefinitionError", "DataIdValueError", "DatasetTypeExpressionError", @@ -64,10 +63,6 @@ class InconsistentDataIdError(DataIdError): """ -class CollectionTypeError(CollectionError): - """Exception raised when type of a collection is incorrect.""" - - class CollectionExpressionError(CollectionError): """Exception raised for an incorrect collection expression.""" diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 5cc7469388..eb942f76a9 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -37,10 +37,9 @@ import sqlalchemy -from ..._exceptions import CollectionCycleError, MissingCollectionError +from ..._exceptions import CollectionCycleError, CollectionTypeError, MissingCollectionError from ...timespan_database_representation import TimespanDatabaseRepresentation from .._collection_type import CollectionType -from .._exceptions import CollectionTypeError from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple from ..wildcards import CollectionWildcard diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py 
b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py index c67d0b6fb8..792fb9bcce 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py @@ -44,11 +44,12 @@ from ...._column_type_info import LogicalColumn from ...._dataset_ref import DatasetId, DatasetIdFactory, DatasetIdGenEnum, DatasetRef from ...._dataset_type import DatasetType +from ...._exceptions import CollectionTypeError from ...._timespan import Timespan from ....dimensions import DataCoordinate from ..._collection_summary import CollectionSummary from ..._collection_type import CollectionType -from ..._exceptions import CollectionTypeError, ConflictingDefinitionError +from ..._exceptions import ConflictingDefinitionError from ...interfaces import DatasetRecordStorage from ...queries import SqlQueryContext from .tables import makeTagTableSpec diff --git a/python/lsst/daf/butler/registry/tests/_registry.py b/python/lsst/daf/butler/registry/tests/_registry.py index 0b0545338f..b28b9d6c5d 100644 --- a/python/lsst/daf/butler/registry/tests/_registry.py +++ b/python/lsst/daf/butler/registry/tests/_registry.py @@ -57,7 +57,7 @@ from ..._dataset_association import DatasetAssociation from ..._dataset_ref import DatasetIdFactory, DatasetIdGenEnum, DatasetRef from ..._dataset_type import DatasetType -from ..._exceptions import MissingCollectionError, MissingDatasetTypeError +from ..._exceptions import CollectionTypeError, MissingCollectionError, MissingDatasetTypeError from ..._exceptions_legacy import DatasetTypeError from ..._storage_class import StorageClass from ..._timespan import Timespan @@ -68,7 +68,6 @@ from .._exceptions import ( ArgumentError, CollectionError, - CollectionTypeError, ConflictingDefinitionError, DataIdValueError, DatasetTypeExpressionError, From 59b42c644e303b5e327ab8e904c171aa3e575817 Mon Sep 17 00:00:00 2001 From: "David H. 
Irving" Date: Wed, 3 Apr 2024 14:45:35 -0700 Subject: [PATCH 10/16] Make setCollectionChain more concurrency-safe Previously, setCollectionChain would sometimes throw unique index violation exceptions when there were concurrent calls to setCollectionChain. It now uses the same locking as prepend_collection_chain, so last write wins instead of throwing an exception. This also prevents potentially surprising interactions between prepend_collection_chain and setCollectionChain. --- .../daf/butler/registry/collections/_base.py | 2 + .../lsst/daf/butler/registry/sql_registry.py | 7 ++ .../daf/butler/registry/tests/_registry.py | 72 ++++++++++++++----- tests/test_remote_butler.py | 7 +- 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index eb942f76a9..3a5bd3b9b3 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -426,7 +426,9 @@ def update_chain( child_records = self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False) names = [child.name for child in child_records] with self._db.transaction(): + self._find_and_lock_collection_chain(chain.name) self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key}) + self._block_for_concurrency_test() self._insert_collection_chain_rows(chain.key, 0, [child.key for child in child_records]) record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names)) diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 4e92001386..766288887b 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -623,6 +623,13 @@ def setCollectionChain(self, parent: str, children: Any, *, flatten: bool = Fals `~CollectionType.CHAINED` collection. 
CollectionCycleError Raised if the given collections contains a cycle. + + Notes + ----- + If this function is called within a call to ``Butler.transaction``, it + will hold a lock that prevents other processes from modifying the + parent collection until the end of the transaction. Keep these + transactions short. """ record = self._managers.collections.find(parent) if record.type is not CollectionType.CHAINED: diff --git a/python/lsst/daf/butler/registry/tests/_registry.py b/python/lsst/daf/butler/registry/tests/_registry.py index b28b9d6c5d..c371ab48d7 100644 --- a/python/lsst/daf/butler/registry/tests/_registry.py +++ b/python/lsst/daf/butler/registry/tests/_registry.py @@ -39,7 +39,7 @@ import uuid from abc import ABC, abstractmethod from collections import defaultdict, namedtuple -from collections.abc import Iterator +from collections.abc import Callable, Iterator from datetime import timedelta from threading import Barrier, Thread @@ -847,10 +847,59 @@ def testCollectionChainFlatten(self): registry.setCollectionChain("outer", ["inner"], flatten=True) self.assertEqual(list(registry.getCollectionChain("outer")), ["innermost"]) - def testCollectionChainConcurrency(self): + def testCollectionChainPrependConcurrency(self): """Verify that locking via database row locks is working as expected. """ + + def blocked_thread_func(registry: SqlRegistry): + # This call will become blocked after it has decided on positions + # for the new children in the collection chain, but before + # inserting them. + registry._managers.collections.prepend_collection_chain("chain", ["a"]) + + def unblocked_thread_func(registry: SqlRegistry): + registry._managers.collections.prepend_collection_chain("chain", ["b"]) + + registry = self._do_collection_concurrency_test(blocked_thread_func, unblocked_thread_func) + + # blocked_thread_func should have finished first, inserting "a". + # unblocked_thread_func should have finished second, prepending "b". 
+ self.assertEqual(("b", "a"), registry.getCollectionChain("chain")) + + def testCollectionChainReplaceConcurrency(self): + """Verify that locking via database row locks is working as + expected. + """ + + def blocked_thread_func(registry: SqlRegistry): + # This call will become blocked after deleting children, but before + # inserting new ones. + registry.setCollectionChain("chain", ["a"]) + + def unblocked_thread_func(registry: SqlRegistry): + registry.setCollectionChain("chain", ["b"]) + + registry = self._do_collection_concurrency_test(blocked_thread_func, unblocked_thread_func) + + # blocked_thread_func should have finished first. + # unblocked_thread_func should have finished second, overwriting the + # chain with "b". + self.assertEqual(("b",), registry.getCollectionChain("chain")) + + def _do_collection_concurrency_test( + self, blocked_thread_func: Callable[[SqlRegistry]], unblocked_thread_func: Callable[[SqlRegistry]] + ) -> SqlRegistry: + # This function: + # 1. Sets up two registries pointing at the same database. + # 2. Start running 'blocked_thread_func' in a background thread, + # arranging for it to become blocked during a critical section in + # the collections manager. + # 3. Wait for 'blocked_thread_func' to reach the critical section + # 4. Start running 'unblocked_thread_func'. + # 5. Allow both functions to run to completion. + + # Set up two registries pointing to the same DB registry1 = self.makeRegistry() assert isinstance(registry1, SqlRegistry) registry2 = self.makeRegistry(share_repo_with=registry1) @@ -862,9 +911,8 @@ def testCollectionChainConcurrency(self): for collection in ["a", "b"]: registry1.registerCollection(collection) - # Cause registry1 to block at the worst possible moment -- after it has - # decided on positions for the new children in the collection chain, - # but before inserting them. + # Arrange for registry1 to block during its critical section, allowing + # us to detect this and control when it becomes unblocked. 
enter_barrier = Barrier(2, timeout=60) exit_barrier = Barrier(2, timeout=60) @@ -874,14 +922,8 @@ def wait_for_barrier(): registry1._managers.collections._block_for_concurrency_test = wait_for_barrier - def thread1_func(): - registry1._managers.collections.prepend_collection_chain("chain", ["a"]) - - def thread2_func(): - registry2._managers.collections.prepend_collection_chain("chain", ["b"]) - - thread1 = Thread(target=thread1_func) - thread2 = Thread(target=thread2_func) + thread1 = Thread(target=blocked_thread_func, args=[registry1]) + thread2 = Thread(target=unblocked_thread_func, args=[registry2]) try: thread1.start() enter_barrier.wait() @@ -899,9 +941,7 @@ def thread2_func(): thread1.join() thread2.join() - # Thread1 should have finished first, inserting "a". Thread2 should - # have finished second, prepending "b". - self.assertEqual(("b", "a"), registry1.getCollectionChain("chain")) + return registry1 def testBasicTransaction(self): """Test that all operations within a single transaction block are diff --git a/tests/test_remote_butler.py b/tests/test_remote_butler.py index f6abc6edb3..97a45cfa79 100644 --- a/tests/test_remote_butler.py +++ b/tests/test_remote_butler.py @@ -161,7 +161,12 @@ def testOpaque(self): # the client side. pass - def testCollectionChainConcurrency(self): + def testCollectionChainPrependConcurrency(self): + # This tests an implementation detail that requires access to the + # collection manager object. + pass + + def testCollectionChainReplaceConcurrency(self): # This tests an implementation detail that requires access to the # collection manager object. pass From ea71ce1e2883b1918eb1cd5d1abc4e07580b6f09 Mon Sep 17 00:00:00 2001 From: "David H. 
Irving" Date: Wed, 3 Apr 2024 15:00:14 -0700 Subject: [PATCH 11/16] Use atomic prepend in the collection chain CLI --- .../lsst/daf/butler/script/collectionChain.py | 93 +++++++++++-------- 1 file changed, 52 insertions(+), 41 deletions(-) diff --git a/python/lsst/daf/butler/script/collectionChain.py b/python/lsst/daf/butler/script/collectionChain.py index a76aee5fcc..1abf31fc0a 100644 --- a/python/lsst/daf/butler/script/collectionChain.py +++ b/python/lsst/daf/butler/script/collectionChain.py @@ -31,6 +31,7 @@ from .._butler import Butler from ..registry import CollectionType, MissingCollectionError +from ..registry.wildcards import CollectionWildcard def collectionChain( @@ -91,49 +92,59 @@ def collectionChain( f"but collection '{parent}' is not known to this registry" ) from None - current = list(butler.registry.getCollectionChain(parent)) - - if mode == "redefine": - # Given children are what we want. - pass - elif mode == "prepend": - children = tuple(children) + tuple(current) - elif mode == "extend": - current.extend(children) - children = current - elif mode == "remove": - for child in children: - current.remove(child) - children = current - elif mode == "pop": - if children: - n_current = len(current) - - def convert_index(i: int) -> int: - """Convert negative index to positive.""" - if i >= 0: - return i - return n_current + i - - # For this mode the children should be integers. - # Convert negative integers to positive ones to allow - # sorting. - indices = [convert_index(int(child)) for child in children] - - # Reverse sort order so we can remove from the end first - indices = sorted(indices, reverse=True) + if flatten: + if mode not in ("redefine", "prepend", "extend"): + raise RuntimeError(f"'flatten' flag is not allowed for {mode}") + wildcard = CollectionWildcard.from_names(children) + children = butler.registry.queryCollections(wildcard, flattenChains=True) - else: - # Nothing specified, pop from the front of the chain. 
- indices = [0] + _modify_collection_chain(butler, mode, parent, children) - for i in indices: - current.pop(i) + return tuple(butler.registry.getCollectionChain(parent)) - children = current - else: - raise ValueError(f"Unrecognized update mode: '{mode}'") - butler.registry.setCollectionChain(parent, children, flatten=flatten) +def _modify_collection_chain(butler: Butler, mode: str, parent: str, children: Iterable[str]) -> None: + if mode == "prepend": + butler.prepend_collection_chain(parent, children) + elif mode == "redefine": + butler.registry.setCollectionChain(parent, children) + else: + current = list(butler.registry.getCollectionChain(parent)) + + if mode == "extend": + current.extend(children) + children = current + elif mode == "remove": + for child in children: + current.remove(child) + children = current + elif mode == "pop": + if children: + n_current = len(current) + + def convert_index(i: int) -> int: + """Convert negative index to positive.""" + if i >= 0: + return i + return n_current + i + + # For this mode the children should be integers. + # Convert negative integers to positive ones to allow + # sorting. + indices = [convert_index(int(child)) for child in children] + + # Reverse sort order so we can remove from the end first + indices = sorted(indices, reverse=True) + + else: + # Nothing specified, pop from the front of the chain. + indices = [0] + + for i in indices: + current.pop(i) + + children = current + else: + raise ValueError(f"Unrecognized update mode: '{mode}'") - return tuple(butler.registry.getCollectionChain(parent)) + butler.registry.setCollectionChain(parent, children) From 6339f2dc57274c3c53b6de60ffe4fa2cf39445e5 Mon Sep 17 00:00:00 2001 From: "David H. 
Irving" Date: Wed, 3 Apr 2024 15:18:16 -0700 Subject: [PATCH 12/16] Add towncrier --- doc/changes/DM-43671.bugfix.md | 3 +++ doc/changes/DM-43671.feature.md | 1 + 2 files changed, 4 insertions(+) create mode 100644 doc/changes/DM-43671.bugfix.md create mode 100644 doc/changes/DM-43671.feature.md diff --git a/doc/changes/DM-43671.bugfix.md b/doc/changes/DM-43671.bugfix.md new file mode 100644 index 0000000000..4dcce25984 --- /dev/null +++ b/doc/changes/DM-43671.bugfix.md @@ -0,0 +1,3 @@ +The `flatten` flag for the `butler collection-chain` CLI command now works as documented: it only flattens the specified children instead of flattening the entire collection chain. + +`registry.setCollectionChain` will no longer throw unique constraint violation exceptions when there are concurrent calls to this function. Instead, all calls will succeed and the last write will win. As a side-effect of this change, if calls to `setCollectionChain` occur within an explicit call to `Butler.transaction`, other processes attempting to modify the same chain will block until the transaction completes. diff --git a/doc/changes/DM-43671.feature.md b/doc/changes/DM-43671.feature.md new file mode 100644 index 0000000000..7d2ad00205 --- /dev/null +++ b/doc/changes/DM-43671.feature.md @@ -0,0 +1 @@ +Added a new method `Butler.prepend_collection_chain`. This allows you to insert collections at the beginning of a chain. It is an "atomic" operation that can be safely used concurrently from multiple processes. From 11941a97a65388c040b2a3df04165dd7da3da5d9 Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Wed, 3 Apr 2024 15:53:39 -0700 Subject: [PATCH 13/16] Use futures for chain concurrency test This gives better diagnostics if the test fails, because exceptions within the threads will be raised from the main thread when result() is called. 
--- .../daf/butler/registry/tests/_registry.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/python/lsst/daf/butler/registry/tests/_registry.py b/python/lsst/daf/butler/registry/tests/_registry.py index c371ab48d7..0c7ca4f39b 100644 --- a/python/lsst/daf/butler/registry/tests/_registry.py +++ b/python/lsst/daf/butler/registry/tests/_registry.py @@ -40,8 +40,9 @@ from abc import ABC, abstractmethod from collections import defaultdict, namedtuple from collections.abc import Callable, Iterator +from concurrent.futures import ThreadPoolExecutor from datetime import timedelta -from threading import Barrier, Thread +from threading import Barrier import astropy.time import sqlalchemy @@ -922,24 +923,22 @@ def wait_for_barrier(): registry1._managers.collections._block_for_concurrency_test = wait_for_barrier - thread1 = Thread(target=blocked_thread_func, args=[registry1]) - thread2 = Thread(target=unblocked_thread_func, args=[registry2]) - try: - thread1.start() - enter_barrier.wait() - - # At this point registry 1 has entered the critical section and is - # waiting for us to release it. Start the other thread. - thread2.start() - # thread2 should block inside a database call, but we have no way - # to detect when it is in this state. - time.sleep(0.100) - - # Let the threads run to completion. - exit_barrier.wait() - finally: - thread1.join() - thread2.join() + with ThreadPoolExecutor(max_workers=1) as exec1: + with ThreadPoolExecutor(max_workers=1) as exec2: + future1 = exec1.submit(blocked_thread_func, registry1) + enter_barrier.wait() + + # At this point registry 1 has entered the critical section and + # is waiting for us to release it. Start the other thread. + future2 = exec2.submit(unblocked_thread_func, registry2) + # thread2 should block inside a database call, but we have no + # way to detect when it is in this state. + time.sleep(0.200) + + # Let the threads run to completion. 
+ exit_barrier.wait() + future1.result() + future2.result() return registry1 From 023e418a0fbd41e0cc92f0a47dbf7cfef704e250 Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Thu, 4 Apr 2024 10:01:23 -0700 Subject: [PATCH 14/16] Prevent prepend from creating duplicates Remove children from collection chains before prepending them, to ensure that there is only one copy of a child in the chain. --- python/lsst/daf/butler/_butler.py | 3 +++ .../daf/butler/registry/collections/_base.py | 13 +++++++++++-- tests/test_butler.py | 18 +++++++++++++++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/python/lsst/daf/butler/_butler.py b/python/lsst/daf/butler/_butler.py index de92a90d0a..0aa1f1da27 100644 --- a/python/lsst/daf/butler/_butler.py +++ b/python/lsst/daf/butler/_butler.py @@ -1742,6 +1742,9 @@ def prepend_collection_chain( ) -> None: """Add children to the beginning of a CHAINED collection. + If any of the children already existed in the chain, they will be moved + to the new position at the beginning of the chain. 
+ Parameters ---------- parent_collection_name : `str` diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 3a5bd3b9b3..6f1fba1908 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -460,7 +460,7 @@ def _insert_collection_chain_rows( self, parent_key: K, starting_position: int, - child_keys: Iterable[K], + child_keys: list[K], ) -> None: position = itertools.count(starting_position) rows = [ @@ -473,6 +473,15 @@ def _insert_collection_chain_rows( ] self._db.insert(self._tables.collection_chain, *rows) + def _remove_collection_chain_rows( + self, + parent_key: K, + child_keys: list[K], + ) -> None: + table = self._tables.collection_chain + where = sqlalchemy.and_(table.c.parent == parent_key, table.c.child.in_(child_keys)) + self._db.deleteWhere(table, where) + def prepend_collection_chain( self, parent_collection_name: str, child_collection_names: list[str] ) -> None: @@ -482,10 +491,10 @@ def prepend_collection_chain( CollectionWildcard.from_names(child_collection_names), flatten_chains=False ) child_keys = [child.key for child in child_records] - assert len(child_keys) == len(child_collection_names) with self._db.transaction(): parent_key = self._find_and_lock_collection_chain(parent_collection_name) + self._remove_collection_chain_rows(parent_key, child_keys) starting_position = self._find_lowest_position_in_collection_chain(parent_key) - len(child_keys) self._block_for_concurrency_test() self._insert_collection_chain_rows(parent_key, starting_position, child_keys) diff --git a/tests/test_butler.py b/tests/test_butler.py index 4bfe5e2022..5b7f49c85e 100644 --- a/tests/test_butler.py +++ b/tests/test_butler.py @@ -1404,17 +1404,30 @@ def _testCollectionChainPrepend(self, butler: Butler) -> None: for run in runs: butler.registry.registerCollection(run) + butler.registry.registerCollection("staticchain", 
CollectionType.CHAINED) + butler.registry.setCollectionChain("staticchain", ["a", "b"]) + def check_chain(expected: list[str]) -> None: children = butler.registry.getCollectionChain("chain") self.assertEqual(expected, list(children)) - butler.prepend_collection_chain("chain", ["c", "b"]) + # Duplicates are removed from the list of children + butler.prepend_collection_chain("chain", ["c", "b", "c"]) check_chain(["c", "b"]) + + # Prepend goes on the front of existing chain butler.prepend_collection_chain("chain", ["a"]) check_chain(["a", "c", "b"]) + + # Empty prepend does nothing butler.prepend_collection_chain("chain", []) check_chain(["a", "c", "b"]) + # Prepending children that already exist in the chain removes them from + # their current position. + butler.prepend_collection_chain("chain", ["d", "b", "c"]) + check_chain(["d", "b", "c", "a"]) + # Missing parent collection with self.assertRaises(MissingCollectionError): butler.prepend_collection_chain("doesnotexist", []) @@ -1431,6 +1444,9 @@ def check_chain(expected: list[str]) -> None: with self.assertRaises(CollectionCycleError): butler.prepend_collection_chain("chain", "chain2") + # Make sure none of those operations interfered with unrelated chains + self.assertEqual(["a", "b"], list(butler.registry.getCollectionChain("staticchain"))) + class FileDatastoreButlerTests(ButlerTests): """Common tests and specialization of ButlerTests for butlers backed From 9dea0acbc6631722e3ce1c4b57a1cbc52387ab05 Mon Sep 17 00:00:00 2001 From: "David H. 
Irving" Date: Thu, 4 Apr 2024 10:28:03 -0700 Subject: [PATCH 15/16] Simplify position calculation --- python/lsst/daf/butler/registry/collections/_base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 6f1fba1908..557a34e0fb 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -30,7 +30,6 @@ __all__ = () -import itertools from abc import abstractmethod from collections.abc import Iterable, Iterator, Set from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar, cast @@ -462,14 +461,13 @@ def _insert_collection_chain_rows( starting_position: int, child_keys: list[K], ) -> None: - position = itertools.count(starting_position) rows = [ { "parent": parent_key, "child": child, - "position": next(position), + "position": position, } - for child in child_keys + for position, child in enumerate(child_keys, starting_position) ] self._db.insert(self._tables.collection_chain, *rows) From ce81f8e240f4fcf7bbbfc5efb424f888d2653ab2 Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Thu, 4 Apr 2024 11:06:36 -0700 Subject: [PATCH 16/16] Forbid prepend_collection_chain in caching context The expected use cases for caching context do not require modifying collection chains, so avoid carrying around cache maintenance code that might never be used.
--- .../daf/butler/registry/collections/_base.py | 16 +++++----------- tests/test_butler.py | 11 ++++------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index 557a34e0fb..d069de2129 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -483,6 +483,11 @@ def _remove_collection_chain_rows( def prepend_collection_chain( self, parent_collection_name: str, child_collection_names: list[str] ) -> None: + if self._caching_context.is_enabled: + # Avoid having cache-maintenance code around that is unlikely to + # ever be used. + raise RuntimeError("Chained collection modification not permitted with active caching context.") + self._sanity_check_collection_cycles(parent_collection_name, child_collection_names) child_records = self.resolve_wildcard( @@ -497,8 +502,6 @@ def prepend_collection_chain( self._block_for_concurrency_test() self._insert_collection_chain_rows(parent_key, starting_position, child_keys) - self._refresh_cache_for_key(parent_key) - def _find_lowest_position_in_collection_chain(self, chain_key: K) -> int: """Return the lowest-numbered position in a collection chain, or 0 if the chain is empty. 
@@ -575,12 +578,3 @@ def _select_pkey_by_name(self, collection_name: str) -> sqlalchemy.Select: - ``type`` : the collection type """ raise NotImplementedError() - - def _refresh_cache_for_key(self, key: K) -> None: - """Refresh the data in the cache for a single collection.""" - cache = self._caching_context.collection_records - if cache is not None: - records = self._fetch_by_key([key]) - if records: - assert len(records) == 1 - cache.add(records[0]) diff --git a/tests/test_butler.py b/tests/test_butler.py index 5b7f49c85e..183d2cf274 100644 --- a/tests/test_butler.py +++ b/tests/test_butler.py @@ -1390,14 +1390,7 @@ def testGetDatasetCollectionCaching(self): def testCollectionChainPrepend(self): butler = self.create_empty_butler(writeable=True) - self._testCollectionChainPrepend(butler) - def testCollectionChainPrependCached(self): - butler = self.create_empty_butler(writeable=True) - with butler._caching_context(): - self._testCollectionChainPrepend(butler) - - def _testCollectionChainPrepend(self, butler: Butler) -> None: butler.registry.registerCollection("chain", CollectionType.CHAINED) runs = ["a", "b", "c", "d"] @@ -1447,6 +1440,10 @@ def check_chain(expected: list[str]) -> None: # Make sure none of those operations interfered with unrelated chains self.assertEqual(["a", "b"], list(butler.registry.getCollectionChain("staticchain"))) + with butler._caching_context(): + with self.assertRaisesRegex(RuntimeError, "Chained collection modification not permitted"): + butler.prepend_collection_chain("chain", "a") + class FileDatastoreButlerTests(ButlerTests): """Common tests and specialization of ButlerTests for butlers backed