From d83cda26d4b247a97c79c24f228d79f65f657fda Mon Sep 17 00:00:00 2001 From: Nick Schrock Date: Wed, 6 Mar 2024 16:40:21 -0500 Subject: [PATCH] cp --- .../asset_graph_view/asset_graph_view.py | 96 +++++++++++++------ .../test_basic_asset_graph_view.py | 13 ++- 2 files changed, 76 insertions(+), 33 deletions(-) diff --git a/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py b/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py index 70b0b5762da0f..12e0d9555e163 100644 --- a/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py +++ b/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py @@ -1,9 +1,11 @@ from datetime import datetime from typing import TYPE_CHECKING, AbstractSet, Mapping, NamedTuple, Optional +from typing_extensions import TypeAlias + from dagster import _check as check from dagster._core.definitions.asset_subset import AssetSubset, ValidAssetSubset -from dagster._core.definitions.events import AssetKey +from dagster._core.definitions.events import AssetKey, AssetKeyPartitionKey from dagster._utils.cached_method import CACHED_METHOD_FIELD_SUFFIX, cached_method if TYPE_CHECKING: @@ -23,12 +25,20 @@ class TemporalContext(NamedTuple): class _AssetSliceCompatibleSubset(ValidAssetSubset): ... -PartitionKey = Optional[str] +PartitionKey: TypeAlias = Optional[str] +AssetPartition: TypeAlias = AssetKeyPartitionKey + + +def _slice_from_subset(asset_graph_view: "AssetGraphView", subset: AssetSubset) -> "AssetSlice": + valid_subset = subset.as_valid( + asset_graph_view.asset_graph.get_partitions_def(subset.asset_key) + ) + return AssetSlice(asset_graph_view, _AssetSliceCompatibleSubset(*valid_subset)) class AssetSlice: # create slots for every cached method - CACHED_METHOD_FIELDS = ["parent_slices", "child_slices"] + CACHED_METHOD_FIELDS = ["get_parent_slice", "get_child_slice", "parent_slices", "child_slices"] # Using slots instead of NamedTuple/Dataclass so we can encapsulate and so we # can create slots for cached methods __slots__ = ["_asset_graph_view", "_compatible_subset"] + [ @@ -47,29 +57,55 @@ def convert_to_valid_asset_subset(self) -> ValidAssetSubset: def materialize_partition_keys(self) -> AbstractSet[PartitionKey]: return {ap.partition_key for ap in self._compatible_subset.asset_partitions} + @property + def asset_key(self) -> AssetKey: + return self._compatible_subset.asset_key + + @property + def parent_keys(self) -> AbstractSet[AssetKey]: + return self._asset_graph_view.asset_graph.get_parents(self.asset_key) + + @property + def _partitions_def(self) -> Optional["PartitionsDefinition"]: + return self._asset_graph_view.asset_graph.get_partitions_def(self.asset_key) + + @cached_method + def get_parent_slice(self, parent_asset_key: AssetKey) -> "AssetSlice": + return self._asset_graph_view.get_parent_asset_slice(parent_asset_key, self) + + @cached_method + def get_child_slice(self, child_asset_key: AssetKey) -> "AssetSlice": + return self._asset_graph_view.get_child_asset_slice(child_asset_key, self) + @property @cached_method def parent_slices(self) -> Mapping[AssetKey, "AssetSlice"]: parent_slices_by_asset_key = {} - for parent_asset_key in self._asset_graph_view._asset_graph.get_parents( # noqa: SLF001 + for parent_asset_key in self._asset_graph_view.asset_graph.get_parents( self._compatible_subset.asset_key ): - parent_slices_by_asset_key[parent_asset_key] = ( - self._asset_graph_view.get_parent_asset_slice(parent_asset_key, self) - ) + parent_slices_by_asset_key[parent_asset_key] = self.get_parent_slice(parent_asset_key) return parent_slices_by_asset_key @property @cached_method def child_slices(self) -> Mapping[AssetKey, "AssetSlice"]: - child_slices_by_asset_key = {} - for parent_asset_key in self._asset_graph_view._asset_graph.get_parents( # noqa: SLF001 - self._compatible_subset.asset_key - ): - child_slices_by_asset_key[parent_asset_key] = ( - self._asset_graph_view.get_parent_asset_slice(parent_asset_key, self) - ) - return child_slices_by_asset_key + return { + parent_asset_key: self.get_child_slice(parent_asset_key) + for parent_asset_key in self._asset_graph_view.asset_graph.get_children(self.asset_key) + } + + def of_materialized_partition_keys( + self, partition_keys: AbstractSet[PartitionKey] + ) -> "AssetSlice": + return _slice_from_subset( + self._asset_graph_view, + AssetSubset.from_asset_partitions_set( + self.asset_key, + self._partitions_def, + {AssetPartition(self.asset_key, partition_key) for partition_key in partition_keys}, + ), + ) class AssetGraphView: @@ -139,7 +175,7 @@ def last_event_id(self) -> Optional[int]: return self._temporal_context.last_event_id @property - def _asset_graph(self) -> "AssetGraph": + def asset_graph(self) -> "AssetGraph": return self._stale_resolver.asset_graph @property @@ -147,43 +183,41 @@ def _queryer(self) -> "CachingInstanceQueryer": return self._stale_resolver.instance_queryer def _get_partitions_def(self, asset_key: "AssetKey") -> Optional["PartitionsDefinition"]: - return self._asset_graph.get_partitions_def(asset_key) + return self.asset_graph.get_partitions_def(asset_key) def get_asset_slice(self, asset_key: "AssetKey") -> "AssetSlice": - partitions_def = self._get_partitions_def(asset_key) - return self._slice_from_subset( + return _slice_from_subset( + self, AssetSubset.all( asset_key=asset_key, - partitions_def=partitions_def, + partitions_def=self._get_partitions_def(asset_key), dynamic_partitions_store=self._queryer, current_time=self.effective_dt, - ) + ), ) def get_parent_asset_slice( self, parent_asset_key: AssetKey, asset_slice: AssetSlice ) -> AssetSlice: - return self._slice_from_subset( - self._asset_graph.get_parent_asset_subset( + return _slice_from_subset( + self, + self.asset_graph.get_parent_asset_subset( dynamic_partitions_store=self._queryer, parent_asset_key=parent_asset_key, child_asset_subset=asset_slice.convert_to_valid_asset_subset(), current_time=self.effective_dt, - ) + ), ) def get_child_asset_slice( self, child_asset_key: "AssetKey", asset_slice: AssetSlice ) -> "AssetSlice": - return self._slice_from_subset( - self._asset_graph.get_child_asset_subset( + return _slice_from_subset( + self, + self.asset_graph.get_child_asset_subset( dynamic_partitions_store=self._queryer, child_asset_key=child_asset_key, current_time=self.effective_dt, parent_asset_subset=asset_slice.convert_to_valid_asset_subset(), - ) + ), ) - - def _slice_from_subset(self: "AssetGraphView", subset: AssetSubset) -> AssetSlice: - valid_subset = subset.as_valid(self._get_partitions_def(subset.asset_key)) - return AssetSlice(self, _AssetSliceCompatibleSubset(*valid_subset)) diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_basic_asset_graph_view.py b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_basic_asset_graph_view.py index 965eb6d0e1c8e..0a60c7dc24e60 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_basic_asset_graph_view.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_basic_asset_graph_view.py @@ -23,8 +23,7 @@ def an_asset() -> None: ... # hiding stale resolver deliberately but want to test instance object identity assert asset_graph_view_t0._stale_resolver.instance_queryer.instance is instance # noqa: SLF001 - # also hiding asset graph deliberately but want to test asset keys - assert asset_graph_view_t0._asset_graph.all_asset_keys == {an_asset.key} # noqa: SLF001 + assert asset_graph_view_t0.asset_graph.all_asset_keys == {an_asset.key} def test_slice_traversal_static_partitions() -> None: @@ -73,3 +72,13 @@ def down_letters() -> None: ... "2", "3", } + + # subset of up to subset of down + assert up_slice.of_materialized_partition_keys({"2"}).child_slices[ + down_letters.key + ].materialize_partition_keys() == {"b"} + + # subset of down to subset of up + assert down_slice.of_materialized_partition_keys({"b"}).parent_slices[ + up_numbers.key + ].materialize_partition_keys() == {"2"}