From a98b94440c73c221662a05eea19e29ea8f35a139 Mon Sep 17 00:00:00 2001 From: Nick Schrock Date: Wed, 13 Mar 2024 14:02:45 -0400 Subject: [PATCH] Support multi-partitioning in AssetSlice (#20353) ## Summary & Motivation Support multi-partitioning in `AssetSlice`. The most notable business logic here is the treatment of the "latest" time window and time windows. Notably this code is aware of whether or not the underlying multi partition definition has a time-windowed partition. If so, methods like latest time window return the full set of partitions resident in the _other_ partition dimension affiliated with the time. ## How I Tested These Changes BK --- .../asset_graph_view/asset_graph_view.py | 117 +++++++++++++++++- .../multi_dimensional_partitions.py | 31 +++-- .../test_latest_time_window.py | 82 ++++++++++++ 3 files changed, 216 insertions(+), 14 deletions(-) diff --git a/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py b/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py index e4368e5270d90..6dd0ef01b99d1 100644 --- a/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py +++ b/python_modules/dagster/dagster/_core/asset_graph_view/asset_graph_view.py @@ -13,8 +13,18 @@ from dagster import _check as check from dagster._core.definitions.asset_subset import AssetSubset, ValidAssetSubset from dagster._core.definitions.events import AssetKey -from dagster._core.definitions.partition import AllPartitionsSubset, StaticPartitionsDefinition +from dagster._core.definitions.multi_dimensional_partitions import ( + MultiPartitionKey, + MultiPartitionsDefinition, + PartitionDimensionDefinition, +) +from dagster._core.definitions.partition import ( + AllPartitionsSubset, + DefaultPartitionsSubset, + StaticPartitionsDefinition, +) from dagster._core.definitions.time_window_partitions import ( + PartitionKeysTimeWindowPartitionsSubset, TimeWindow, TimeWindowPartitionsDefinition, TimeWindowPartitionsSubset, @@ 
-160,8 +170,9 @@ def compute_intersection_with_partition_keys( @property def time_windows(self) -> Sequence[TimeWindow]: """Get the time windows for the asset slice. Only supports explicitly time-windowed partitions for now.""" - # Only supports explicitly time-windows partitions for now - tw_partitions_def = _required_tw_partitions_def(self._partitions_def) + tw_partitions_def = check.not_none( + self._time_window_partitions_def_in_context(), "Must be time windowed." + ) if isinstance(self._compatible_subset.subset_value, TimeWindowPartitionsSubset): return self._compatible_subset.subset_value.included_time_windows @@ -170,9 +181,46 @@ def time_windows(self) -> Sequence[TimeWindow]: self._asset_graph_view.effective_dt ) return [TimeWindow(datetime.min, last_tw.end)] if last_tw else [] + elif isinstance(self._compatible_subset.subset_value, DefaultPartitionsSubset): + check.inst( + self._partitions_def, + MultiPartitionsDefinition, + "Must be multi-partition if we got here.", + ) + tw_partition_keys = set() + for multi_partition_key in check.is_list( + list(self._compatible_subset.subset_value.get_partition_keys()), + MultiPartitionKey, + "Keys must be multi partition keys.", + ): + tm_partition_key = next(iter(multi_partition_key.keys_by_dimension.values())) + tw_partition_keys.add(tm_partition_key) + + subset_from_tw = tw_partitions_def.subset_with_partition_keys(tw_partition_keys) + check.inst( + subset_from_tw, + (TimeWindowPartitionsSubset, PartitionKeysTimeWindowPartitionsSubset), + "Must be time window subset.", + ) + if isinstance(subset_from_tw, TimeWindowPartitionsSubset): + return subset_from_tw.included_time_windows + elif isinstance(subset_from_tw, PartitionKeysTimeWindowPartitionsSubset): + return subset_from_tw.included_time_windows + else: + check.failed( + f"Unsupported subset value in generated subset {self._compatible_subset.subset_value} created by keys {tw_partition_keys}" + ) else: check.failed(f"Unsupported subset value: 
{self._compatible_subset.subset_value}") + def _time_window_partitions_def_in_context(self) -> Optional[TimeWindowPartitionsDefinition]: + pd = self._partitions_def + if isinstance(pd, TimeWindowPartitionsDefinition): + return pd + if isinstance(pd, MultiPartitionsDefinition): + return pd.time_window_partitions_def if pd.has_time_window_dimension else None + return None + @property def is_empty(self) -> bool: return self._compatible_subset.size == 0 @@ -372,7 +420,19 @@ def create_latest_time_window_slice(self, asset_key: AssetKey) -> AssetSlice: else self.create_empty_slice(asset_key) ) - # Need to handle dynamic and multi-dimensional partitioning + if isinstance(partitions_def, MultiPartitionsDefinition): + if not partitions_def.has_time_window_dimension: + return self.get_asset_slice(asset_key) + + multi_dim_info = self._get_multi_dim_info(asset_key) + last_tw = multi_dim_info.tw_partition_def.get_last_partition_window(self.effective_dt) + return ( + self._build_multi_partition_slice(asset_key, multi_dim_info, last_tw) + if last_tw + else self.create_empty_slice(asset_key) + ) + + # Need to handle dynamic partitioning check.failed(f"Unsupported partitions_def: {partitions_def}") def create_empty_slice(self, asset_key: AssetKey) -> AssetSlice: @@ -381,6 +441,55 @@ def create_empty_slice(self, asset_key: AssetKey) -> AssetSlice: AssetSubset.empty(asset_key, self._get_partitions_def(asset_key)), ) + class MultiDimInfo(NamedTuple): + tw_dim: PartitionDimensionDefinition + secondary_dim: PartitionDimensionDefinition + + @property + def tw_partition_def(self) -> TimeWindowPartitionsDefinition: + return cast( + TimeWindowPartitionsDefinition, + check.inst(self.tw_dim.partitions_def, TimeWindowPartitionsDefinition), + ) + + @property + def secondary_partition_def(self) -> "PartitionsDefinition": + return self.secondary_dim.partitions_def + + def _get_multi_dim_info(self, asset_key: AssetKey) -> "MultiDimInfo": + partitions_def = cast( + MultiPartitionsDefinition, + 
check.inst(self._get_partitions_def(asset_key), MultiPartitionsDefinition), + ) + return self.MultiDimInfo( + tw_dim=partitions_def.time_window_dimension, + secondary_dim=partitions_def.secondary_dimension, + ) + + def _build_multi_partition_slice( + self, asset_key: AssetKey, multi_dim_info: MultiDimInfo, last_tw: TimeWindow + ) -> "AssetSlice": + # Note: Potential perf improvement here. There is no way to encode a cartesian product + # in the underlying PartitionsSet. We could add a specialized PartitionsSubset + # subclass that itself composed two PartitionsSubset to avoid materializing the entire + # partitions range. + return self.get_asset_slice(asset_key).compute_intersection_with_partition_keys( + { + MultiPartitionKey( + { + multi_dim_info.tw_dim.name: tw_pk, + multi_dim_info.secondary_dim.name: secondary_pk, + } + ) + for tw_pk in multi_dim_info.tw_partition_def.get_partition_keys_in_time_window( + last_tw + ) + for secondary_pk in multi_dim_info.secondary_partition_def.get_partition_keys( + dynamic_partitions_store=self._queryer, current_time=self.effective_dt + ) + } + ) + def _required_tw_partitions_def( partitions_def: Optional["PartitionsDefinition"], diff --git a/python_modules/dagster/dagster/_core/definitions/multi_dimensional_partitions.py b/python_modules/dagster/dagster/_core/definitions/multi_dimensional_partitions.py index 6e5ee21ccae16..2cc6b6ab6cdad 100644 --- a/python_modules/dagster/dagster/_core/definitions/multi_dimensional_partitions.py +++ b/python_modules/dagster/dagster/_core/definitions/multi_dimensional_partitions.py @@ -395,11 +395,7 @@ def _get_primary_and_secondary_dimension( # the selection of primary/secondary dimension, will need to also update the # serialization of MultiPartitionsSubsets - time_dimensions = [ - dim - for dim in self.partitions_defs - if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition) - ] + time_dimensions = self._get_time_window_dims() if len(time_dimensions) == 1: primary_dimension, 
secondary_dimension = ( time_dimensions[0], @@ -429,16 +425,31 @@ def get_tags_for_partition_key(self, partition_key: str) -> Mapping[str, str]: @property def time_window_dimension(self) -> PartitionDimensionDefinition: - time_window_dims = [ - dim - for dim in self.partitions_defs - if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition) - ] + time_window_dims = self._get_time_window_dims() check.invariant( len(time_window_dims) == 1, "Expected exactly one time window partitioned dimension" ) return next(iter(time_window_dims)) + def _get_time_window_dims(self) -> List[PartitionDimensionDefinition]: + return [ + dim + for dim in self.partitions_defs + if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition) + ] + + @property + def has_time_window_dimension(self) -> bool: + return bool(self._get_time_window_dims()) + + @property + def time_window_partitions_def(self) -> TimeWindowPartitionsDefinition: + check.invariant(self.has_time_window_dimension, "Must have time window dimension") + assert isinstance( + self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition + ) # appease pyright + return check.inst(self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition) + def time_window_for_partition_key(self, partition_key: str) -> TimeWindow: if not isinstance(partition_key, MultiPartitionKey): partition_key = self.get_partition_key_from_str(partition_key) diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_latest_time_window.py b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_latest_time_window.py index 74462c43dcb9a..680d7e567f95e 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_latest_time_window.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_graph_view_tests/test_latest_time_window.py @@ -7,8 +7,13 @@ asset, ) from dagster._core.asset_graph_view.asset_graph_view import 
AssetGraphView, AssetSlice +from dagster._core.definitions.multi_dimensional_partitions import ( + MultiPartitionKey, + MultiPartitionsDefinition, +) from dagster._core.definitions.partition import StaticPartitionsDefinition from dagster._core.definitions.time_window_partitions import DailyPartitionsDefinition, TimeWindow +from dagster._core.execution.context.compute import AssetExecutionContext from dagster._core.instance import DagsterInstance @@ -173,3 +178,80 @@ def up_numbers() -> None: ... asset_graph_view = AssetGraphView.for_test(defs, instance) latest_up_slice = asset_graph_view.create_latest_time_window_slice(up_numbers.key) assert latest_up_slice.compute_partition_keys() == number_keys + + +def test_multi_dimesional_with_time_partition_latest_time_window() -> None: + # starts at 2020-01-01 + daily_partitions_def = DailyPartitionsDefinition( + start_date=pendulum.datetime(2020, 1, 1), end_date=pendulum.datetime(2020, 1, 3) + ) + + static_partitions_def = StaticPartitionsDefinition(["CA", "NY", "MN"]) + + multi_partitions_definition = MultiPartitionsDefinition( + {"daily": daily_partitions_def, "static": static_partitions_def} + ) + + partition_keys = [] + jan_2_keys = [] + for daily_pk in daily_partitions_def.get_partition_keys(): + for static_pk in static_partitions_def.get_partition_keys(): + if daily_pk == "2020-01-02": + jan_2_keys.append(MultiPartitionKey({"daily": daily_pk, "static": static_pk})) + + partition_keys.append(MultiPartitionKey({"daily": daily_pk, "static": static_pk})) + + @asset(partitions_def=multi_partitions_definition) + def multi_dimensional(context: AssetExecutionContext) -> None: ... 
+ + defs = Definitions([multi_dimensional]) + instance = DagsterInstance.ephemeral() + + asset_graph_view_within_partition = AssetGraphView.for_test( + defs, instance, effective_dt=pendulum.datetime(2020, 3, 3) + ) + + md_slice = asset_graph_view_within_partition.get_asset_slice(multi_dimensional.key) + assert md_slice.compute_partition_keys() == set(partition_keys) + last_tw_slice = asset_graph_view_within_partition.create_latest_time_window_slice( + multi_dimensional.key + ) + assert last_tw_slice.compute_partition_keys() == set(jan_2_keys) + assert _tw(last_tw_slice).start == pendulum.datetime(2020, 1, 2) + assert _tw(last_tw_slice).end == pendulum.datetime(2020, 1, 3) + + asset_graph_view_in_past = AssetGraphView.for_test( + defs, instance, effective_dt=pendulum.datetime(2019, 3, 3) + ) + + md_slice_in_past = asset_graph_view_in_past.create_latest_time_window_slice( + multi_dimensional.key + ) + assert md_slice_in_past.compute_partition_keys() == set() + assert not md_slice_in_past.time_windows + + +def test_multi_dimesional_without_time_partition_latest_time_window() -> None: + num_partitions_def = StaticPartitionsDefinition(["1", "2", "3"]) + letter_partitions_def = StaticPartitionsDefinition(["A", "B", "C"]) + + multi_partitions_definition = MultiPartitionsDefinition( + {"num": num_partitions_def, "letter": letter_partitions_def} + ) + + partition_keys = [] + for num_pk in num_partitions_def.get_partition_keys(): + for letter_pk in letter_partitions_def.get_partition_keys(): + partition_keys.append(MultiPartitionKey({"num": num_pk, "letter": letter_pk})) + + @asset(partitions_def=multi_partitions_definition) + def multi_dimensional(context: AssetExecutionContext) -> None: ... 
+ + defs = Definitions([multi_dimensional]) + instance = DagsterInstance.ephemeral() + asset_graph_view = AssetGraphView.for_test(defs, instance) + md_slice = asset_graph_view.get_asset_slice(multi_dimensional.key) + assert md_slice.compute_partition_keys() == set(partition_keys) + assert asset_graph_view.create_latest_time_window_slice( + multi_dimensional.key + ).compute_partition_keys() == set(partition_keys)