From f1dfc2f122e3cd1d3d8eb2cb411ad62fbc861fef Mon Sep 17 00:00:00 2001 From: prha <1040172+prha@users.noreply.github.com> Date: Thu, 24 Oct 2024 20:30:17 -0700 Subject: [PATCH] add partition filter to get_latest_storage_id_by_partition (#25510) ## Summary & Motivation We want to filter the number of partitions we scan for, so that we can optimize the restricted partition case. ## How I Tested These Changes BK --- .../dagster/dagster/_core/instance/__init__.py | 9 +++++++-- .../dagster/dagster/_core/storage/event_log/base.py | 5 ++++- .../dagster/_core/storage/event_log/sql_event_log.py | 7 +++++-- .../dagster/dagster/_core/storage/legacy_storage.py | 7 +++++-- .../storage_tests/utils/event_log_storage.py | 10 ++++++++-- 5 files changed, 29 insertions(+), 9 deletions(-) diff --git a/python_modules/dagster/dagster/_core/instance/__init__.py b/python_modules/dagster/dagster/_core/instance/__init__.py index aff4da0c78b53..1aa9494a85db2 100644 --- a/python_modules/dagster/dagster/_core/instance/__init__.py +++ b/python_modules/dagster/dagster/_core/instance/__init__.py @@ -2265,13 +2265,18 @@ def get_materialized_partitions( @traced def get_latest_storage_id_by_partition( - self, asset_key: AssetKey, event_type: "DagsterEventType" + self, + asset_key: AssetKey, + event_type: "DagsterEventType", + partitions: Optional[Set[str]] = None, ) -> Mapping[str, int]: """Fetch the latest materialzation storage id for each partition for a given asset key. Returns a mapping of partition to storage id. """ - return self._event_storage.get_latest_storage_id_by_partition(asset_key, event_type) + return self._event_storage.get_latest_storage_id_by_partition( + asset_key, event_type, partitions + ) @traced def get_latest_planned_materialization_info( diff --git a/python_modules/dagster/dagster/_core/storage/event_log/base.py b/python_modules/dagster/dagster/_core/storage/event_log/base.py index 2c823ef1b7358..0bc7beae766af 100644 --- a/python_modules/dagster/dagster/_core/storage/event_log/base.py +++ b/python_modules/dagster/dagster/_core/storage/event_log/base.py @@ -463,7 +463,10 @@ def get_materialized_partitions( @abstractmethod def get_latest_storage_id_by_partition( - self, asset_key: AssetKey, event_type: DagsterEventType + self, + asset_key: AssetKey, + event_type: DagsterEventType, + partitions: Optional[Set[str]] = None, ) -> Mapping[str, int]: pass diff --git a/python_modules/dagster/dagster/_core/storage/event_log/sql_event_log.py b/python_modules/dagster/dagster/_core/storage/event_log/sql_event_log.py index 16b8004c35de5..62a0463e45933 100644 --- a/python_modules/dagster/dagster/_core/storage/event_log/sql_event_log.py +++ b/python_modules/dagster/dagster/_core/storage/event_log/sql_event_log.py @@ -1877,7 +1877,10 @@ def _latest_event_ids_by_partition_subquery( ) def get_latest_storage_id_by_partition( - self, asset_key: AssetKey, event_type: DagsterEventType + self, + asset_key: AssetKey, + event_type: DagsterEventType, + partitions: Optional[Set[str]] = None, ) -> Mapping[str, int]: """Fetch the latest materialzation storage id for each partition for a given asset key. @@ -1886,7 +1889,7 @@ def get_latest_storage_id_by_partition( check.inst_param(asset_key, "asset_key", AssetKey) latest_event_ids_by_partition_subquery = self._latest_event_ids_by_partition_subquery( - asset_key, [event_type] + asset_key, [event_type], asset_partitions=list(partitions) if partitions else None ) latest_event_ids_by_partition = db_select( [ diff --git a/python_modules/dagster/dagster/_core/storage/legacy_storage.py b/python_modules/dagster/dagster/_core/storage/legacy_storage.py index 268f2aef499a0..0a13ebe42956a 100644 --- a/python_modules/dagster/dagster/_core/storage/legacy_storage.py +++ b/python_modules/dagster/dagster/_core/storage/legacy_storage.py @@ -553,10 +553,13 @@ def get_materialized_partitions( ) def get_latest_storage_id_by_partition( - self, asset_key: "AssetKey", event_type: "DagsterEventType" + self, + asset_key: "AssetKey", + event_type: "DagsterEventType", + partitions: Optional[Set[str]] = None, ) -> Mapping[str, int]: return self._storage.event_log_storage.get_latest_storage_id_by_partition( - asset_key, event_type + asset_key, event_type, partitions ) def get_latest_tags_by_partition( diff --git a/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py b/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py index bef2efdb64ef7..940a3511f822b 100644 --- a/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py +++ b/python_modules/dagster/dagster_tests/storage_tests/utils/event_log_storage.py @@ -2634,10 +2634,12 @@ def test_get_latest_storage_ids_by_partition(self, storage, instance): b = AssetKey(["b"]) run_id = make_new_run_id() - def _assert_storage_matches(expected): + def _assert_storage_matches(expected, partition: Optional[str] = None): assert ( storage.get_latest_storage_id_by_partition( - a, DagsterEventType.ASSET_MATERIALIZATION + a, + DagsterEventType.ASSET_MATERIALIZATION, + partitions={partition} if partition else None, ) == expected ) @@ -2679,6 +2681,10 @@ def _store_partition_event(asset_key, partition) -> int: latest_storage_ids["p2"] = _store_partition_event(a, "p2") _assert_storage_matches(latest_storage_ids) + # check that we can filter for specific partitions + _assert_storage_matches({"p1": latest_storage_ids["p1"]}, partition="p1") + _assert_storage_matches({"p2": latest_storage_ids["p2"]}, partition="p2") + # unrelated asset materialized _store_partition_event(b, "p1") _store_partition_event(b, "p2")