Skip to content

Commit

Permalink
Support multi-partitioning in AssetSlice
Browse files Browse the repository at this point in the history
remove comment

fix test
  • Loading branch information
schrockn committed Mar 12, 2024
1 parent 1077b0d commit 7e8c5cb
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,18 @@
from dagster import _check as check
from dagster._core.definitions.asset_subset import AssetSubset, ValidAssetSubset
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.partition import AllPartitionsSubset, StaticPartitionsDefinition
from dagster._core.definitions.multi_dimensional_partitions import (
MultiPartitionKey,
MultiPartitionsDefinition,
PartitionDimensionDefinition,
)
from dagster._core.definitions.partition import (
AllPartitionsSubset,
DefaultPartitionsSubset,
StaticPartitionsDefinition,
)
from dagster._core.definitions.time_window_partitions import (
PartitionKeysTimeWindowPartitionsSubset,
TimeWindow,
TimeWindowPartitionsDefinition,
TimeWindowPartitionsSubset,
Expand Down Expand Up @@ -160,7 +170,9 @@ def compute_intersection_with_partition_keys(
@property
def time_windows(self) -> Sequence[TimeWindow]:
# TODO: support this for all subset values
tw_partitions_def = _required_tw_partitions_def(self._partitions_def)
tw_partitions_def = self._time_window_partitions_def_in_context()
check.inst(tw_partitions_def, TimeWindowPartitionsDefinition, "Must be time windowed.")
assert isinstance(tw_partitions_def, TimeWindowPartitionsDefinition) # appease type checker

if isinstance(self._compatible_subset.subset_value, TimeWindowPartitionsSubset):
return self._compatible_subset.subset_value.included_time_windows
Expand All @@ -169,9 +181,41 @@ def time_windows(self) -> Sequence[TimeWindow]:
self._asset_graph_view.effective_dt
)
return [TimeWindow(datetime.min, last_tw.end)] if last_tw else []
elif isinstance(self._compatible_subset.subset_value, DefaultPartitionsSubset):
check.inst(
self._partitions_def,
MultiPartitionsDefinition,
"Must be multi-partition if we got here.",
)
tw_partition_keys = set()
for multi_partition_key in self._compatible_subset.subset_value.get_partition_keys():
check.inst(multi_partition_key, MultiPartitionKey, "Must be multi partition key.")
assert isinstance(multi_partition_key, MultiPartitionKey) # appease type checker
tm_partition_key = next(iter(multi_partition_key.keys_by_dimension.values()))
tw_partition_keys.add(tm_partition_key)
subset_from_tw = tw_partitions_def.subset_with_partition_keys(tw_partition_keys)
check.inst(
subset_from_tw,
(TimeWindowPartitionsSubset, PartitionKeysTimeWindowPartitionsSubset),
"Must be time window subset.",
)
if isinstance(subset_from_tw, TimeWindowPartitionsSubset):
return subset_from_tw.included_time_windows
elif isinstance(subset_from_tw, PartitionKeysTimeWindowPartitionsSubset):
return subset_from_tw.included_time_windows
else:
check.failed(f"Unsupported subset value: {self._compatible_subset.subset_value}")
else:
check.failed(f"Unsupported subset value: {self._compatible_subset.subset_value}")

def _time_window_partitions_def_in_context(self) -> Optional[TimeWindowPartitionsDefinition]:
pd = self._partitions_def
if isinstance(pd, TimeWindowPartitionsDefinition):
return pd
if isinstance(pd, MultiPartitionsDefinition):
return pd.time_window_partitions_def if pd.has_time_window_dimension else None
return None

@property
def is_empty(self) -> bool:
return self._compatible_subset.size == 0
Expand Down Expand Up @@ -371,7 +415,19 @@ def create_latest_time_window_slice(self, asset_key: AssetKey) -> AssetSlice:
else self.create_empty_slice(asset_key)
)

# Need to handle dynamic and multi-dimensional partitioning
if isinstance(partitions_def, MultiPartitionsDefinition):
if not partitions_def.has_time_window_dimension:
return self.get_asset_slice(asset_key)

multi_dim_info = self._get_multi_dim_info(asset_key)
last_tw = multi_dim_info.tw_partition_def.get_last_partition_window(self.effective_dt)
return (
self._build_multi_partition_slice(asset_key, multi_dim_info, last_tw)
if last_tw
else self.create_empty_slice(asset_key)
)

# Need to handle dynamic partitioning
check.failed(f"Unsupported partitions_def: {partitions_def}")

def create_empty_slice(self, asset_key: AssetKey) -> AssetSlice:
Expand All @@ -380,6 +436,53 @@ def create_empty_slice(self, asset_key: AssetKey) -> AssetSlice:
AssetSubset.empty(asset_key, self._get_partitions_def(asset_key)),
)

class MultiDimInfo(NamedTuple):
tw_dim: PartitionDimensionDefinition
secondary_dim: PartitionDimensionDefinition

@property
def tw_partition_def(self) -> TimeWindowPartitionsDefinition:
check.inst(self.tw_dim.partitions_def, TimeWindowPartitionsDefinition)
assert isinstance(
self.tw_dim.partitions_def, TimeWindowPartitionsDefinition
) # appease pyright
return self.tw_dim.partitions_def

@property
def secondary_partition_def(self) -> "PartitionsDefinition":
return self.secondary_dim.partitions_def

def _get_multi_dim_info(self, asset_key: AssetKey) -> "MultiDimInfo":
partitions_def = self._get_partitions_def(asset_key)
check.inst(partitions_def, MultiPartitionsDefinition)
assert isinstance(partitions_def, MultiPartitionsDefinition) # appease pyright
return self.MultiDimInfo(
tw_dim=partitions_def.time_window_dimension,
secondary_dim=partitions_def.secondary_dimension,
)

def _build_multi_partition_slice(
self, asset_key: AssetKey, multi_dim_info: MultiDimInfo, last_tw: TimeWindow
) -> "AssetSlice":
# Note: Potential perf improvement here. There is no way to encode a cartesian product
# in the underlying PartitionsSet. We could add a specialized PartitionsSubset
# subclass that itself composed two PartitionsSubset to avoid materializing the entire
# partitions range.
return self.get_asset_slice(asset_key).compute_intersection_with_partition_keys(
{
MultiPartitionKey(
{
multi_dim_info.tw_dim.name: tw_pk,
multi_dim_info.secondary_dim.name: secondary_pk,
}
)
for tw_pk in multi_dim_info.tw_partition_def.get_partition_keys_in_time_window(
last_tw
)
for secondary_pk in multi_dim_info.secondary_partition_def.get_partition_keys()
}
)


def _required_tw_partitions_def(
partitions_def: Optional["PartitionsDefinition"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,7 @@ def _get_primary_and_secondary_dimension(
# the selection of primary/secondary dimension, will need to also update the
# serialization of MultiPartitionsSubsets

time_dimensions = [
dim
for dim in self.partitions_defs
if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
]
time_dimensions = self._get_time_window_dims()
if len(time_dimensions) == 1:
primary_dimension, secondary_dimension = (
time_dimensions[0],
Expand Down Expand Up @@ -429,16 +425,31 @@ def get_tags_for_partition_key(self, partition_key: str) -> Mapping[str, str]:

@property
def time_window_dimension(self) -> PartitionDimensionDefinition:
time_window_dims = [
dim
for dim in self.partitions_defs
if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
]
time_window_dims = self._get_time_window_dims()
check.invariant(
len(time_window_dims) == 1, "Expected exactly one time window partitioned dimension"
)
return next(iter(time_window_dims))

def _get_time_window_dims(self) -> List[PartitionDimensionDefinition]:
return [
dim
for dim in self.partitions_defs
if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
]

@property
def has_time_window_dimension(self) -> bool:
return bool(self._get_time_window_dims())

@property
def time_window_partitions_def(self) -> TimeWindowPartitionsDefinition:
check.invariant(self.has_time_window_dimension, "Must have time window dimension")
assert isinstance(
self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition
) # appease pyright
return check.inst(self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition)

def time_window_for_partition_key(self, partition_key: str) -> TimeWindow:
if not isinstance(partition_key, MultiPartitionKey):
partition_key = self.get_partition_key_from_str(partition_key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@
asset,
)
from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView, AssetSlice
from dagster._core.definitions.multi_dimensional_partitions import (
MultiPartitionKey,
MultiPartitionsDefinition,
)
from dagster._core.definitions.partition import StaticPartitionsDefinition
from dagster._core.definitions.time_window_partitions import DailyPartitionsDefinition, TimeWindow
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.instance import DagsterInstance


Expand Down Expand Up @@ -173,3 +178,80 @@ def up_numbers() -> None: ...
asset_graph_view = AssetGraphView.for_test(defs, instance)
latest_up_slice = asset_graph_view.create_latest_time_window_slice(up_numbers.key)
assert latest_up_slice.compute_partition_keys() == number_keys


def test_multi_dimesional_with_time_partition_latest_time_window() -> None:
# starts at 2020-02-01
daily_partitions_def = DailyPartitionsDefinition(
start_date=pendulum.datetime(2020, 1, 1), end_date=pendulum.datetime(2020, 1, 3)
)

static_partitions_def = StaticPartitionsDefinition(["CA", "NY", "MN"])

multi_partitions_definition = MultiPartitionsDefinition(
{"daily": daily_partitions_def, "static": static_partitions_def}
)

partition_keys = []
jan_2_keys = []
for daily_pk in daily_partitions_def.get_partition_keys():
for static_pk in static_partitions_def.get_partition_keys():
if daily_pk == "2020-01-02":
jan_2_keys.append(MultiPartitionKey({"daily": daily_pk, "static": static_pk}))

partition_keys.append(MultiPartitionKey({"daily": daily_pk, "static": static_pk}))

@asset(partitions_def=multi_partitions_definition)
def multi_dimensional(context: AssetExecutionContext) -> None: ...

defs = Definitions([multi_dimensional])
instance = DagsterInstance.ephemeral()

asset_graph_view_within_partition = AssetGraphView.for_test(
defs, instance, effective_dt=pendulum.datetime(2020, 3, 3)
)

md_slice = asset_graph_view_within_partition.get_asset_slice(multi_dimensional.key)
assert md_slice.compute_partition_keys() == set(partition_keys)
last_tw_slice = asset_graph_view_within_partition.create_latest_time_window_slice(
multi_dimensional.key
)
assert last_tw_slice.compute_partition_keys() == set(jan_2_keys)
assert _tw(last_tw_slice).start == pendulum.datetime(2020, 1, 2)
assert _tw(last_tw_slice).end == pendulum.datetime(2020, 1, 3)

asset_graph_view_in_past = AssetGraphView.for_test(
defs, instance, effective_dt=pendulum.datetime(2019, 3, 3)
)

md_slice_in_past = asset_graph_view_in_past.create_latest_time_window_slice(
multi_dimensional.key
)
assert md_slice_in_past.compute_partition_keys() == set()
assert not md_slice_in_past.time_windows


def test_multi_dimesional_without_time_partition_latest_time_window() -> None:
num_partitions_def = StaticPartitionsDefinition(["1", "2", "3"])
letter_partitions_def = StaticPartitionsDefinition(["A", "B", "C"])

multi_partitions_definition = MultiPartitionsDefinition(
{"num": num_partitions_def, "letter": letter_partitions_def}
)

partition_keys = []
for num_pk in num_partitions_def.get_partition_keys():
for letter_pk in letter_partitions_def.get_partition_keys():
partition_keys.append(MultiPartitionKey({"num": num_pk, "letter": letter_pk}))

@asset(partitions_def=multi_partitions_definition)
def multi_dimensional(context: AssetExecutionContext) -> None: ...

defs = Definitions([multi_dimensional])
instance = DagsterInstance.ephemeral()
asset_graph_view = AssetGraphView.for_test(defs, instance)
md_slice = asset_graph_view.get_asset_slice(multi_dimensional.key)
assert md_slice.compute_partition_keys() == set(partition_keys)
assert asset_graph_view.create_latest_time_window_slice(
multi_dimensional.key
).compute_partition_keys() == set(partition_keys)

0 comments on commit 7e8c5cb

Please sign in to comment.