Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-partitioning in AssetSlice #20353

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,18 @@
from dagster import _check as check
from dagster._core.definitions.asset_subset import AssetSubset, ValidAssetSubset
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.partition import AllPartitionsSubset, StaticPartitionsDefinition
from dagster._core.definitions.multi_dimensional_partitions import (
MultiPartitionKey,
MultiPartitionsDefinition,
PartitionDimensionDefinition,
)
from dagster._core.definitions.partition import (
AllPartitionsSubset,
DefaultPartitionsSubset,
StaticPartitionsDefinition,
)
from dagster._core.definitions.time_window_partitions import (
PartitionKeysTimeWindowPartitionsSubset,
TimeWindow,
TimeWindowPartitionsDefinition,
TimeWindowPartitionsSubset,
Expand Down Expand Up @@ -160,8 +170,9 @@ def compute_intersection_with_partition_keys(
@property
def time_windows(self) -> Sequence[TimeWindow]:
"""Get the time windows for the asset slice. Only supports explicitly time-windowed partitions for now."""
# Only supports explicitly time-windows partitions for now
tw_partitions_def = _required_tw_partitions_def(self._partitions_def)
tw_partitions_def = check.not_none(
self._time_window_partitions_def_in_context(), "Must be time windowed."
)

if isinstance(self._compatible_subset.subset_value, TimeWindowPartitionsSubset):
return self._compatible_subset.subset_value.included_time_windows
Expand All @@ -170,9 +181,46 @@ def time_windows(self) -> Sequence[TimeWindow]:
self._asset_graph_view.effective_dt
)
return [TimeWindow(datetime.min, last_tw.end)] if last_tw else []
elif isinstance(self._compatible_subset.subset_value, DefaultPartitionsSubset):
check.inst(
self._partitions_def,
MultiPartitionsDefinition,
"Must be multi-partition if we got here.",
)
tw_partition_keys = set()
for multi_partition_key in check.is_list(
list(self._compatible_subset.subset_value.get_partition_keys()),
MultiPartitionKey,
"Keys must be multi partition keys.",
):
tm_partition_key = next(iter(multi_partition_key.keys_by_dimension.values()))
tw_partition_keys.add(tm_partition_key)

subset_from_tw = tw_partitions_def.subset_with_partition_keys(tw_partition_keys)
check.inst(
subset_from_tw,
(TimeWindowPartitionsSubset, PartitionKeysTimeWindowPartitionsSubset),
"Must be time window subset.",
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is going to become a core part of the system, IMO we should put the necessary time window dimension accessors on MultiPartitionsDefinition instead of having this one-off MultiDimInfo class.

EDIT: Looks like you did actually add to MultiPartitionsDefinition?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added some helper methods to MultiPartitionsDefinition. I think we should add to the core layer in a follow up, alongside an effort to represent multi-partioning in the underlying PartitionsSubset more efficiently.

if isinstance(subset_from_tw, TimeWindowPartitionsSubset):
return subset_from_tw.included_time_windows
elif isinstance(subset_from_tw, PartitionKeysTimeWindowPartitionsSubset):
return subset_from_tw.included_time_windows
else:
check.failed(
f"Unsupported subset value in generated subset {self._compatible_subset.subset_value} created by keys {tw_partition_keys}"
)
else:
check.failed(f"Unsupported subset value: {self._compatible_subset.subset_value}")

def _time_window_partitions_def_in_context(self) -> Optional[TimeWindowPartitionsDefinition]:
pd = self._partitions_def
if isinstance(pd, TimeWindowPartitionsDefinition):
return pd
if isinstance(pd, MultiPartitionsDefinition):
return pd.time_window_partitions_def if pd.has_time_window_dimension else None
return None

@property
def is_empty(self) -> bool:
return self._compatible_subset.size == 0
Expand Down Expand Up @@ -372,7 +420,19 @@ def create_latest_time_window_slice(self, asset_key: AssetKey) -> AssetSlice:
else self.create_empty_slice(asset_key)
)

# Need to handle dynamic and multi-dimensional partitioning
if isinstance(partitions_def, MultiPartitionsDefinition):
if not partitions_def.has_time_window_dimension:
return self.get_asset_slice(asset_key)

multi_dim_info = self._get_multi_dim_info(asset_key)
last_tw = multi_dim_info.tw_partition_def.get_last_partition_window(self.effective_dt)
return (
self._build_multi_partition_slice(asset_key, multi_dim_info, last_tw)
if last_tw
else self.create_empty_slice(asset_key)
)

# Need to handle dynamic partitioning
check.failed(f"Unsupported partitions_def: {partitions_def}")

def create_empty_slice(self, asset_key: AssetKey) -> AssetSlice:
Expand All @@ -381,6 +441,55 @@ def create_empty_slice(self, asset_key: AssetKey) -> AssetSlice:
AssetSubset.empty(asset_key, self._get_partitions_def(asset_key)),
)

class MultiDimInfo(NamedTuple):
tw_dim: PartitionDimensionDefinition
secondary_dim: PartitionDimensionDefinition

@property
def tw_partition_def(self) -> TimeWindowPartitionsDefinition:
return cast(
TimeWindowPartitionsDefinition,
check.inst(self.tw_dim.partitions_def, TimeWindowPartitionsDefinition),
)

@property
def secondary_partition_def(self) -> "PartitionsDefinition":
return self.secondary_dim.partitions_def

def _get_multi_dim_info(self, asset_key: AssetKey) -> "MultiDimInfo":
partitions_def = cast(
MultiPartitionsDefinition,
check.inst(self._get_partitions_def(asset_key), MultiPartitionsDefinition),
)
return self.MultiDimInfo(
tw_dim=partitions_def.time_window_dimension,
secondary_dim=partitions_def.secondary_dimension,
)

def _build_multi_partition_slice(
self, asset_key: AssetKey, multi_dim_info: MultiDimInfo, last_tw: TimeWindow
) -> "AssetSlice":
# Note: Potential perf improvement here. There is no way to encode a cartesian product
# in the underlying PartitionsSet. We could add a specialized PartitionsSubset
# subclass that itself composed two PartitionsSubset to avoid materializing the entire
# partitions range.
return self.get_asset_slice(asset_key).compute_intersection_with_partition_keys(
{
MultiPartitionKey(
{
multi_dim_info.tw_dim.name: tw_pk,
multi_dim_info.secondary_dim.name: secondary_pk,
}
)
for tw_pk in multi_dim_info.tw_partition_def.get_partition_keys_in_time_window(
last_tw
)
for secondary_pk in multi_dim_info.secondary_partition_def.get_partition_keys(
dynamic_partitions_store=self._queryer, current_time=self.effective_dt
)
}
)


def _required_tw_partitions_def(
partitions_def: Optional["PartitionsDefinition"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,7 @@ def _get_primary_and_secondary_dimension(
# the selection of primary/secondary dimension, will need to also update the
# serialization of MultiPartitionsSubsets

time_dimensions = [
dim
for dim in self.partitions_defs
if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
]
time_dimensions = self._get_time_window_dims()
if len(time_dimensions) == 1:
primary_dimension, secondary_dimension = (
time_dimensions[0],
Expand Down Expand Up @@ -429,16 +425,31 @@ def get_tags_for_partition_key(self, partition_key: str) -> Mapping[str, str]:

@property
def time_window_dimension(self) -> PartitionDimensionDefinition:
time_window_dims = [
dim
for dim in self.partitions_defs
if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
]
time_window_dims = self._get_time_window_dims()
check.invariant(
len(time_window_dims) == 1, "Expected exactly one time window partitioned dimension"
)
return next(iter(time_window_dims))

def _get_time_window_dims(self) -> List[PartitionDimensionDefinition]:
return [
dim
for dim in self.partitions_defs
if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
]

@property
def has_time_window_dimension(self) -> bool:
return bool(self._get_time_window_dims())

@property
def time_window_partitions_def(self) -> TimeWindowPartitionsDefinition:
check.invariant(self.has_time_window_dimension, "Must have time window dimension")
assert isinstance(
self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition
) # appease pyright
return check.inst(self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition)

def time_window_for_partition_key(self, partition_key: str) -> TimeWindow:
if not isinstance(partition_key, MultiPartitionKey):
partition_key = self.get_partition_key_from_str(partition_key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@
asset,
)
from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView, AssetSlice
from dagster._core.definitions.multi_dimensional_partitions import (
MultiPartitionKey,
MultiPartitionsDefinition,
)
from dagster._core.definitions.partition import StaticPartitionsDefinition
from dagster._core.definitions.time_window_partitions import DailyPartitionsDefinition, TimeWindow
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.instance import DagsterInstance


Expand Down Expand Up @@ -173,3 +178,80 @@ def up_numbers() -> None: ...
asset_graph_view = AssetGraphView.for_test(defs, instance)
latest_up_slice = asset_graph_view.create_latest_time_window_slice(up_numbers.key)
assert latest_up_slice.compute_partition_keys() == number_keys


def test_multi_dimesional_with_time_partition_latest_time_window() -> None:
# starts at 2020-02-01
daily_partitions_def = DailyPartitionsDefinition(
start_date=pendulum.datetime(2020, 1, 1), end_date=pendulum.datetime(2020, 1, 3)
)

static_partitions_def = StaticPartitionsDefinition(["CA", "NY", "MN"])

multi_partitions_definition = MultiPartitionsDefinition(
{"daily": daily_partitions_def, "static": static_partitions_def}
)

partition_keys = []
jan_2_keys = []
for daily_pk in daily_partitions_def.get_partition_keys():
for static_pk in static_partitions_def.get_partition_keys():
if daily_pk == "2020-01-02":
jan_2_keys.append(MultiPartitionKey({"daily": daily_pk, "static": static_pk}))

partition_keys.append(MultiPartitionKey({"daily": daily_pk, "static": static_pk}))

@asset(partitions_def=multi_partitions_definition)
def multi_dimensional(context: AssetExecutionContext) -> None: ...

defs = Definitions([multi_dimensional])
instance = DagsterInstance.ephemeral()

asset_graph_view_within_partition = AssetGraphView.for_test(
defs, instance, effective_dt=pendulum.datetime(2020, 3, 3)
)

md_slice = asset_graph_view_within_partition.get_asset_slice(multi_dimensional.key)
assert md_slice.compute_partition_keys() == set(partition_keys)
last_tw_slice = asset_graph_view_within_partition.create_latest_time_window_slice(
multi_dimensional.key
)
assert last_tw_slice.compute_partition_keys() == set(jan_2_keys)
assert _tw(last_tw_slice).start == pendulum.datetime(2020, 1, 2)
assert _tw(last_tw_slice).end == pendulum.datetime(2020, 1, 3)

asset_graph_view_in_past = AssetGraphView.for_test(
defs, instance, effective_dt=pendulum.datetime(2019, 3, 3)
)

md_slice_in_past = asset_graph_view_in_past.create_latest_time_window_slice(
multi_dimensional.key
)
assert md_slice_in_past.compute_partition_keys() == set()
assert not md_slice_in_past.time_windows


def test_multi_dimesional_without_time_partition_latest_time_window() -> None:
num_partitions_def = StaticPartitionsDefinition(["1", "2", "3"])
letter_partitions_def = StaticPartitionsDefinition(["A", "B", "C"])

multi_partitions_definition = MultiPartitionsDefinition(
{"num": num_partitions_def, "letter": letter_partitions_def}
)

partition_keys = []
for num_pk in num_partitions_def.get_partition_keys():
for letter_pk in letter_partitions_def.get_partition_keys():
partition_keys.append(MultiPartitionKey({"num": num_pk, "letter": letter_pk}))

@asset(partitions_def=multi_partitions_definition)
def multi_dimensional(context: AssetExecutionContext) -> None: ...

defs = Definitions([multi_dimensional])
instance = DagsterInstance.ephemeral()
asset_graph_view = AssetGraphView.for_test(defs, instance)
md_slice = asset_graph_view.get_asset_slice(multi_dimensional.key)
assert md_slice.compute_partition_keys() == set(partition_keys)
assert asset_graph_view.create_latest_time_window_slice(
multi_dimensional.key
).compute_partition_keys() == set(partition_keys)