Skip to content

Commit

Permalink
Support multi-partitioning in AssetSlice
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Mar 8, 2024
1 parent 0c64516 commit 02f35a4
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from dagster import _check as check
from dagster._core.definitions.asset_subset import AssetSubset, ValidAssetSubset
from dagster._core.definitions.events import AssetKey, AssetKeyPartitionKey
from dagster._core.definitions.multi_dimensional_partitions import (
MultiPartitionKey,
MultiPartitionsDefinition,
PartitionDimensionDefinition,
)
from dagster._core.definitions.partition import StaticPartitionsDefinition
from dagster._core.definitions.time_window_partitions import (
TimeWindow,
Expand Down Expand Up @@ -153,6 +158,30 @@ def only_partition_keys(self, partition_keys: AbstractSet[PartitionKey]) -> "Ass
& self._compatible_subset,
)

class MultiDimInfo(NamedTuple):
tw_dim: PartitionDimensionDefinition
secondary_dim: PartitionDimensionDefinition

@property
def tw_partition_def(self) -> TimeWindowPartitionsDefinition:
check.inst(self.tw_dim.partitions_def, TimeWindowPartitionsDefinition)
assert isinstance(
self.tw_dim.partitions_def, TimeWindowPartitionsDefinition
) # appease pyright
return self.tw_dim.partitions_def

@property
def secondary_partition_def(self) -> "PartitionsDefinition":
return self.secondary_dim.partitions_def

def _get_multi_dim_info(self) -> "MultiDimInfo":
check.inst(self._partitions_def, MultiPartitionsDefinition)
assert isinstance(self._partitions_def, MultiPartitionsDefinition) # appease pyright
return self.MultiDimInfo(
tw_dim=self._partitions_def.time_window_dimension,
secondary_dim=self._partitions_def.secondary_dimension,
)

@cached_property
def latest_time_window_slice(self) -> "AssetSlice":
"""Returns the latest time window for the asset slice.
Expand Down Expand Up @@ -183,9 +212,53 @@ def latest_time_window_slice(self) -> "AssetSlice":
else self._asset_graph_view.create_empty_slice(self.asset_key)
)

# Need to handle dynamic and multi-dimensional partitioning
if isinstance(self._partitions_def, MultiPartitionsDefinition):
if not self._partitions_def.has_time_window_dimension:
return self

multi_dim_info = self._get_multi_dim_info()
last_tw = multi_dim_info.tw_partition_def.get_last_partition_window(
self._asset_graph_view.effective_dt
)
return (
self._build_multi_partition_slice(multi_dim_info, last_tw)
if last_tw
else self._asset_graph_view.create_empty_slice(self.asset_key)
)

# Need to handle dynamic partitioning
check.failed(f"Unsupported partitions_def: {self._partitions_def}")

def _build_multi_partition_slice(
self, multi_dim_info: MultiDimInfo, last_tw: TimeWindow
) -> "AssetSlice":
# Note: Potential perf improvement here. There is no way to encode a cartesian product
# in the underlying PartitionsSet. We could add a specialized PartitionsSubset
# subclass that itself composed two PartitionsSubset to avoid materializing the entire
# partitions range.
return self._asset_graph_view.get_asset_slice(self.asset_key).only_partition_keys(
{
MultiPartitionKey(
{
multi_dim_info.tw_dim.name: tw_pk,
multi_dim_info.secondary_dim.name: secondary_pk,
}
)
for tw_pk in multi_dim_info.tw_partition_def.get_partition_keys_in_time_window(
last_tw
)
for secondary_pk in multi_dim_info.secondary_partition_def.get_partition_keys()
}
)

def _time_window_partitions_def_in_context(self) -> Optional[TimeWindowPartitionsDefinition]:
pd = self._partitions_def
if isinstance(pd, TimeWindowPartitionsDefinition):
return pd
if isinstance(pd, MultiPartitionsDefinition):
return pd.time_window_partitions_def if pd.has_time_window_dimension else None
return None

@cached_property
def latest_time_window(self) -> TimeWindow:
"""If the underlying asset is time-window partitioned, this will return the latest complete
Expand All @@ -196,25 +269,18 @@ def latest_time_window(self) -> TimeWindow:
If the underlying asset is unpartitioned or static partitioned and it is not empty,
this will return a time window from the beginning of time to the effective date. If
it is empty it will return the empty time window.
TODO: add language for multi-dimensional partitioning when we support it
TODO: add language for dynamic partitioning when we support it
"""
if isinstance(self._partitions_def, TimeWindowPartitionsDefinition):
tw = self._partitions_def.get_last_partition_window(self._asset_graph_view.effective_dt)
return tw if tw else TimeWindow.empty()
tw_partitions_def = self._time_window_partitions_def_in_context()

if self._partitions_def is None or isinstance(
self._partitions_def, StaticPartitionsDefinition
):
if not tw_partitions_def:
return (
TimeWindow.empty()
if self.is_empty
else TimeWindow(datetime.min, self._asset_graph_view.effective_dt)
)

# Need to handle dynamic and multi-dimensional partitioning
check.failed(f"Unsupported partitions_def: {self._partitions_def}")
tw = tw_partitions_def.get_last_partition_window(self._asset_graph_view.effective_dt)
return tw if tw else TimeWindow.empty()

@property
def is_empty(self) -> bool:
Expand Down Expand Up @@ -356,7 +422,7 @@ def create_from_time_window(self, asset_key: AssetKey, time_window: TimeWindow)
TimeWindowPartitionsDefinition,
"Must be a time-windowed partition definition",
)
assert isinstance(partitions_def, TimeWindowPartitionsDefinition) # appease type checker
assert isinstance(partitions_def, TimeWindowPartitionsDefinition) # appease pyright
return _slice_from_subset(
self,
AssetSubset(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,7 @@ def _get_primary_and_secondary_dimension(
# the selection of primary/secondary dimension, will need to also update the
# serialization of MultiPartitionsSubsets

time_dimensions = [
dim
for dim in self.partitions_defs
if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
]
time_dimensions = self._get_time_window_dims()
if len(time_dimensions) == 1:
primary_dimension, secondary_dimension = (
time_dimensions[0],
Expand Down Expand Up @@ -429,15 +425,32 @@ def get_tags_for_partition_key(self, partition_key: str) -> Mapping[str, str]:

@property
def time_window_dimension(self) -> PartitionDimensionDefinition:
time_window_dims = [
check.invariant(self.has_time_window_dimension, "Must have time window dimension")
check.inst(
self.primary_dimension.partitions_def,
TimeWindowPartitionsDefinition,
"Sanity check that assumption that primary is time-windowed if it exists",
)
return self.primary_dimension

def _get_time_window_dims(self) -> List[PartitionDimensionDefinition]:
return [
dim
for dim in self.partitions_defs
if isinstance(dim.partitions_def, TimeWindowPartitionsDefinition)
]
check.invariant(
len(time_window_dims) == 1, "Expected exactly one time window partitioned dimension"
)
return next(iter(time_window_dims))

@property
def has_time_window_dimension(self) -> bool:
return bool(self._get_time_window_dims())

@property
def time_window_partitions_def(self) -> TimeWindowPartitionsDefinition:
check.invariant(self.has_time_window_dimension, "Must have time window dimension")
assert isinstance(
self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition
) # appease pyright
return check.inst(self.primary_dimension.partitions_def, TimeWindowPartitionsDefinition)

def time_window_for_partition_key(self, partition_key: str) -> TimeWindow:
if not isinstance(partition_key, MultiPartitionKey):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import pendulum
from dagster import Definitions, asset
from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView
from dagster._core.definitions.multi_dimensional_partitions import (
MultiPartitionKey,
MultiPartitionsDefinition,
)
from dagster._core.definitions.partition import StaticPartitionsDefinition
from dagster._core.definitions.time_window_partitions import DailyPartitionsDefinition
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.instance import DagsterInstance


Expand Down Expand Up @@ -165,3 +170,80 @@ def up_numbers() -> None: ...
assert up_slice.latest_time_window_slice.compute_partition_keys() == number_keys
assert not up_slice.is_empty
assert asset_graph_view.create_empty_slice(up_numbers.key).latest_time_window_slice.is_empty


def test_multi_dimesional_with_time_partition_latest_time_window() -> None:
# starts at 2020-02-01
daily_partitions_def = DailyPartitionsDefinition(
start_date=pendulum.datetime(2020, 1, 1), end_date=pendulum.datetime(2020, 1, 3)
)

static_partitions_def = StaticPartitionsDefinition(["CA", "NY", "MN"])

multi_partitions_definition = MultiPartitionsDefinition(
{"daily": daily_partitions_def, "static": static_partitions_def}
)

partition_keys = []
jan_2_keys = []
for daily_pk in daily_partitions_def.get_partition_keys():
for static_pk in static_partitions_def.get_partition_keys():
if daily_pk == "2020-01-02":
jan_2_keys.append(MultiPartitionKey({"daily": daily_pk, "static": static_pk}))

partition_keys.append(MultiPartitionKey({"daily": daily_pk, "static": static_pk}))

@asset(partitions_def=multi_partitions_definition)
def multi_dimensional(context: AssetExecutionContext) -> None: ...

defs = Definitions([multi_dimensional])
instance = DagsterInstance.ephemeral()

asset_graph_view_within_partition = AssetGraphView.for_test(
defs, instance, effective_dt=pendulum.datetime(2020, 3, 3)
)

md_slice = asset_graph_view_within_partition.get_asset_slice(multi_dimensional.key)
assert md_slice.compute_partition_keys() == set(partition_keys)
assert md_slice.latest_time_window_slice.compute_partition_keys() == set(jan_2_keys)
assert md_slice.latest_time_window.start == pendulum.datetime(2020, 1, 2)
assert md_slice.latest_time_window.end == pendulum.datetime(2020, 1, 3)

asset_graph_view_in_past = AssetGraphView.for_test(
defs, instance, effective_dt=pendulum.datetime(2019, 3, 3)
)

md_slice_in_past = asset_graph_view_in_past.get_asset_slice(multi_dimensional.key)
assert md_slice_in_past.compute_partition_keys() == set()
assert md_slice_in_past.latest_time_window_slice.compute_partition_keys() == set()
assert md_slice_in_past.latest_time_window.is_empty


def test_multi_dimesional_without_time_partition_latest_time_window() -> None:
num_partitions_def = StaticPartitionsDefinition(["1", "2", "3"])
letter_partitions_def = StaticPartitionsDefinition(["A", "B", "C"])

multi_partitions_definition = MultiPartitionsDefinition(
{"num": num_partitions_def, "letter": letter_partitions_def}
)

partition_keys = []
for num_pk in num_partitions_def.get_partition_keys():
for letter_pk in letter_partitions_def.get_partition_keys():
partition_keys.append(MultiPartitionKey({"num": num_pk, "letter": letter_pk}))

@asset(partitions_def=multi_partitions_definition)
def multi_dimensional(context: AssetExecutionContext) -> None: ...

defs = Definitions([multi_dimensional])
instance = DagsterInstance.ephemeral()
asset_graph_view = AssetGraphView.for_test(defs, instance)
md_slice = asset_graph_view.get_asset_slice(multi_dimensional.key)
assert md_slice.compute_partition_keys() == set(partition_keys)
assert md_slice.latest_time_window_slice.compute_partition_keys() == set(partition_keys)
assert not md_slice.latest_time_window.is_empty

assert asset_graph_view.create_empty_slice(
multi_dimensional.key
).latest_time_window_slice.is_empty
assert asset_graph_view.create_empty_slice(multi_dimensional.key).latest_time_window.is_empty

0 comments on commit 02f35a4

Please sign in to comment.