From 52bf7bd3c115cca1850bd58797398c758cbea6c7 Mon Sep 17 00:00:00 2001 From: Nick Schrock Date: Thu, 7 Mar 2024 18:56:12 -0500 Subject: [PATCH] Support multi-partitioning in AssetSlice remove comment fix test --- .../asset_graph_view/asset_graph_view.py | 109 +++++++++++++++++- .../multi_dimensional_partitions.py | 31 +++-- .../test_latest_time_window.py | 82 +++++++++++++ 3 files changed, 209 insertions(+), 13 deletions(-) diff --git a/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py b/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py index ba0eae2cc12a5..26bc9c909f5b1 100644 --- a/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py +++ b/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py @@ -13,8 +13,18 @@ from dagster import _check as check from dagster._core.definitions.asset_subset import AssetSubset, ValidAssetSubset from dagster._core.definitions.events import AssetKey -from dagster._core.definitions.partition import AllPartitionsSubset, StaticPartitionsDefinition +from dagster._core.definitions.multi_dimensional_partitions import ( + MultiPartitionKey, + MultiPartitionsDefinition, + PartitionDimensionDefinition, +) +from dagster._core.definitions.partition import ( + AllPartitionsSubset, + DefaultPartitionsSubset, + StaticPartitionsDefinition, +) from dagster._core.definitions.time_window_partitions import ( + PartitionKeysTimeWindowPartitionsSubset, TimeWindow, TimeWindowPartitionsDefinition, TimeWindowPartitionsSubset, @@ -160,7 +170,9 @@ def compute_intersection_with_partition_keys( @property def time_windows(self) -> Sequence[TimeWindow]: # TODO: support this for all subset values - tw_partitions_def = _required_tw_partitions_def(self._partitions_def) + tw_partitions_def = self._time_window_partitions_def_in_context() + check.inst(tw_partitions_def, TimeWindowPartitionsDefinition, "Must be time windowed.") + assert isinstance(tw_partitions_def, TimeWindowPartitionsDefinition) # appease type checker if isinstance(self._compatible_subset.subset_value, TimeWindowPartitionsSubset): return self._compatible_subset.subset_value.included_time_windows @@ -169,9 +181,41 @@ def time_windows(self) -> Sequence[TimeWindow]: self._asset_graph_view.effective_dt ) return [TimeWindow(datetime.min, last_tw.end)] if last_tw else [] + elif isinstance(self._compatible_subset.subset_value, DefaultPartitionsSubset): + check.inst( + self._partitions_def, + MultiPartitionsDefinition, + "Must be multi-partition if we got here.", + ) + tw_partition_keys = set() + for multi_partition_key in self._compatible_subset.subset_value.get_partition_keys(): + check.inst(multi_partition_key, MultiPartitionKey, "Must be multi partition key.") + assert isinstance(multi_partition_key, MultiPartitionKey) # appease type checker + tm_partition_key = next(iter(multi_partition_key.keys_by_dimension.values())) + tw_partition_keys.add(tm_partition_key) + subset_from_tw = tw_partitions_def.subset_with_partition_keys(tw_partition_keys) + check.inst( + subset_from_tw, + (TimeWindowPartitionsSubset, PartitionKeysTimeWindowPartitionsSubset), + "Must be time window subset.", + ) + if isinstance(subset_from_tw, TimeWindowPartitionsSubset): + return subset_from_tw.included_time_windows + elif isinstance(subset_from_tw, PartitionKeysTimeWindowPartitionsSubset): + return subset_from_tw.included_time_windows + else: + check.failed(f"Unsupported subset value: {self._compatible_subset.subset_value}") else: check.failed(f"Unsupported subset value: {self._compatible_subset.subset_value}") + def _time_window_partitions_def_in_context(self) -> Optional[TimeWindowPartitionsDefinition]: + pd = self._partitions_def + if isinstance(pd, TimeWindowPartitionsDefinition): + return pd + if isinstance(pd, MultiPartitionsDefinition): + return pd.time_window_partitions_def if pd.has_time_window_dimension else None + return None + @property def is_empty(self) -> bool: return self._compatible_subset.size == 0 @@ -371,7 +415,19 @@ def create_latest_time_window_slice(self, asset_key: AssetKey) -> AssetSlice: else self.create_empty_slice(asset_key) ) - # Need to handle dynamic and multi-dimensional partitioning + if isinstance(partitions_def, MultiPartitionsDefinition): + if not partitions_def.has_time_window_dimension: + return self.get_asset_slice(asset_key) + + multi_dim_info = self._get_multi_dim_info(asset_key) + last_tw = multi_dim_info.tw_partition_def.get_last_partition_window(self.effective_dt) + return ( + self._build_multi_partition_slice(asset_key, multi_dim_info, last_tw) + if last_tw + else self.create_empty_slice(asset_key) + ) + + # Need to handle dynamic partitioning check.failed(f"Unsupported partitions_def: {partitions_def}") def create_empty_slice(self, asset_key: AssetKey) -> AssetSlice: @@ -380,6 +436,53 @@ def create_empty_slice(self, asset_key: AssetKey) -> AssetSlice: AssetSubset.empty(asset_key, self._get_partitions_def(asset_key)), ) + class MultiDimInfo(NamedTuple): + tw_dim: PartitionDimensionDefinition + secondary_dim: PartitionDimensionDefinition + + @property + def tw_partition_def(self) -> TimeWindowPartitionsDefinition: + check.inst(self.tw_dim.partitions_def, TimeWindowPartitionsDefinition) + assert isinstance( + self.tw_dim.partitions_def, TimeWindowPartitionsDefinition + ) # appease pyright + return self.tw_dim.partitions_def + + @property + def secondary_partition_def(self) -> "PartitionsDefinition": + return self.secondary_dim.partitions_def + + def _get_multi_dim_info(self, asset_key: AssetKey) -> "MultiDimInfo": + partitions_def = self._get_partitions_def(asset_key) + check.inst(partitions_def, MultiPartitionsDefinition) + assert isinstance(partitions_def, MultiPartitionsDefinition) # appease pyright + return self.MultiDimInfo( + tw_dim=partitions_def.time_window_dimension, + secondary_dim=partitions_def.secondary_dimension, + ) + + def _build_multi_partition_slice( + self, asset_key: AssetKey, multi_dim_info: MultiDimInfo, last_tw: TimeWindow + ) -> "AssetSlice": + # Note: Potential perf improvement here. There is no way to encode a cartesian product + # in the underlying PartitionsSet. We could add a specialized PartitionsSubset + # subclass that itself composed two PartitionsSubset to avoid materializing the entire + # partitions range. + return self.get_asset_slice(asset_key).compute_intersection_with_partition_keys( + { + MultiPartitionKey( + { + multi_dim_info.tw_dim.name: tw_pk, + multi_dim_info.secondary_dim.name: secondary_pk, + } + ) + for tw_pk in multi_dim_info.tw_partition_def.get_partition_keys_in_time_window( + last_tw + ) + for secondary_pk in multi_dim_info.secondary_partition_def.get_partition_keys() + } + ) + def _required_tw_partitions_def( partitions_def: Optional["PartitionsDefinition"], diff --git a/python_modules/dagster/dagster/_core/definitions/multi_dimensional_partitions.py b/python_modules/dagster/dagster/_core/definitions/multi_dimensional_partitions.py index 6e5ee21ccae16..2cc6b6ab6cdad 100644 --- a/python_modules/dagster/dagster/_core/definitions/multi_dimensional_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/multi_dimensional_partitions.py @@ -395,11 +395,7 @@ def _get_primary_and_secondary_dimension( # the selection of primary/secondary dimension, will need to also update the # serialization of MultiPartitionsSubsets - time_dimensions = [ - dim - for dim in self.partitions_defs - if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition) - ] + time_dimensions = self._get_time_window_dims() if len(time_dimensions) == 1: primary_dimension, secondary_dimension = ( time_dimensions[0], @@ -429,16 +425,31 @@ def get_tags_for_partition_key(self, partition_key: str) -> Mapping[str, str]: @property def time_window_dimension(self) -> PartitionDimensionDefinition: - time_window_dims = [ - dim - for dim in self.partitions_defs - if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition) - ] + time_window_dims = self._get_time_window_dims() check.invariant( len(time_window_dims) == 1, "Expected exactly one time window partitioned dimension" ) return next(iter(time_window_dims)) + def _get_time_window_dims(self) -> List[PartitionDimensionDefinition]: + return [ + dim + for dim in self.partitions_defs + if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition) + ] + + @property + def has_time_window_dimension(self) -> bool: + return bool(self._get_time_window_dims()) + + @property + def time_window_partitions_def(self) -> TimeWindowPartitionsDefinition: + check.invariant(self.has_time_window_dimension, "Must have time window dimension") + assert isinstance( + self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition + ) # appease pyright + return check.inst(self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition) + def time_window_for_partition_key(self, partition_key: str) -> TimeWindow: if not isinstance(partition_key, MultiPartitionKey): partition_key = self.get_partition_key_from_str(partition_key) diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_latest_time_window.py b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_latest_time_window.py index 74462c43dcb9a..680d7e567f95e 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_latest_time_window.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_latest_time_window.py @@ -7,8 +7,13 @@ asset, ) from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView, AssetSlice +from dagster._core.definitions.multi_dimensional_partitions import ( + MultiPartitionKey, + MultiPartitionsDefinition, +) from dagster._core.definitions.partition import StaticPartitionsDefinition from dagster._core.definitions.time_window_partitions import DailyPartitionsDefinition, TimeWindow +from dagster._core.execution.context.compute import AssetExecutionContext from dagster._core.instance import DagsterInstance @@ -173,3 +178,80 @@ def up_numbers() -> None: ... asset_graph_view = AssetGraphView.for_test(defs, instance) latest_up_slice = asset_graph_view.create_latest_time_window_slice(up_numbers.key) assert latest_up_slice.compute_partition_keys() == number_keys + + +def test_multi_dimesional_with_time_partition_latest_time_window() -> None: + # starts at 2020-02-01 + daily_partitions_def = DailyPartitionsDefinition( + start_date=pendulum.datetime(2020, 1, 1), end_date=pendulum.datetime(2020, 1, 3) + ) + + static_partitions_def = StaticPartitionsDefinition(["CA", "NY", "MN"]) + + multi_partitions_definition = MultiPartitionsDefinition( + {"daily": daily_partitions_def, "static": static_partitions_def} + ) + + partition_keys = [] + jan_2_keys = [] + for daily_pk in daily_partitions_def.get_partition_keys(): + for static_pk in static_partitions_def.get_partition_keys(): + if daily_pk == "2020-01-02": + jan_2_keys.append(MultiPartitionKey({"daily": daily_pk, "static": static_pk})) + + partition_keys.append(MultiPartitionKey({"daily": daily_pk, "static": static_pk})) + + @asset(partitions_def=multi_partitions_definition) + def multi_dimensional(context: AssetExecutionContext) -> None: ... + + defs = Definitions([multi_dimensional]) + instance = DagsterInstance.ephemeral() + + asset_graph_view_within_partition = AssetGraphView.for_test( + defs, instance, effective_dt=pendulum.datetime(2020, 3, 3) + ) + + md_slice = asset_graph_view_within_partition.get_asset_slice(multi_dimensional.key) + assert md_slice.compute_partition_keys() == set(partition_keys) + last_tw_slice = asset_graph_view_within_partition.create_latest_time_window_slice( + multi_dimensional.key + ) + assert last_tw_slice.compute_partition_keys() == set(jan_2_keys) + assert _tw(last_tw_slice).start == pendulum.datetime(2020, 1, 2) + assert _tw(last_tw_slice).end == pendulum.datetime(2020, 1, 3) + + asset_graph_view_in_past = AssetGraphView.for_test( + defs, instance, effective_dt=pendulum.datetime(2019, 3, 3) + ) + + md_slice_in_past = asset_graph_view_in_past.create_latest_time_window_slice( + multi_dimensional.key + ) + assert md_slice_in_past.compute_partition_keys() == set() + assert not md_slice_in_past.time_windows + + +def test_multi_dimesional_without_time_partition_latest_time_window() -> None: + num_partitions_def = StaticPartitionsDefinition(["1", "2", "3"]) + letter_partitions_def = StaticPartitionsDefinition(["A", "B", "C"]) + + multi_partitions_definition = MultiPartitionsDefinition( + {"num": num_partitions_def, "letter": letter_partitions_def} + ) + + partition_keys = [] + for num_pk in num_partitions_def.get_partition_keys(): + for letter_pk in letter_partitions_def.get_partition_keys(): + partition_keys.append(MultiPartitionKey({"num": num_pk, "letter": letter_pk})) + + @asset(partitions_def=multi_partitions_definition) + def multi_dimensional(context: AssetExecutionContext) -> None: ... + + defs = Definitions([multi_dimensional]) + instance = DagsterInstance.ephemeral() + asset_graph_view = AssetGraphView.for_test(defs, instance) + md_slice = asset_graph_view.get_asset_slice(multi_dimensional.key) + assert md_slice.compute_partition_keys() == set(partition_keys) + assert asset_graph_view.create_latest_time_window_slice( + multi_dimensional.key + ).compute_partition_keys() == set(partition_keys)