Skip to content

Commit

Permalink
PartitionKey is newtype
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Mar 8, 2024
1 parent 4dc2a1b commit b92518d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from datetime import datetime
from typing import TYPE_CHECKING, AbstractSet, Mapping, NamedTuple, Optional

from typing_extensions import TypeAlias
from typing import TYPE_CHECKING, AbstractSet, Mapping, NamedTuple, NewType, Optional

from dagster import _check as check
from dagster._core.definitions.asset_subset import AssetSubset, ValidAssetSubset
Expand Down Expand Up @@ -44,8 +42,12 @@ class TemporalContext(NamedTuple):
class _AssetSliceCompatibleSubset(ValidAssetSubset): ...


PartitionKey: TypeAlias = Optional[str]
AssetPartition: TypeAlias = AssetKeyPartitionKey
PartitionKey = NewType("PartitionKey", str)


class AssetPartitionKey(NamedTuple):
asset_key: AssetKey
partition_key: Optional[PartitionKey]


def _slice_from_subset(asset_graph_view: "AssetGraphView", subset: AssetSubset) -> "AssetSlice":
Expand Down Expand Up @@ -98,7 +100,10 @@ def convert_to_valid_asset_subset(self) -> ValidAssetSubset:
return self._compatible_subset

def compute_partition_keys(self) -> AbstractSet[PartitionKey]:
return {ap.partition_key for ap in self._compatible_subset.asset_partitions}
return {
PartitionKey(check.not_none(ap.partition_key, "Must have named partitions"))
for ap in self._compatible_subset.asset_partitions
}

@property
def asset_key(self) -> AssetKey:
Expand Down Expand Up @@ -132,16 +137,26 @@ def compute_parent_slices(self) -> Mapping[AssetKey, "AssetSlice"]:
def compute_child_slices(self) -> Mapping[AssetKey, "AssetSlice"]:
return {ak: self.compute_child_slice(ak) for ak in self.child_keys}

def only_partition_keys(self, partition_keys: AbstractSet[PartitionKey]) -> "AssetSlice":
def only_partition_keys(self, partition_keys: AbstractSet[str]) -> "AssetSlice":
"""Return a new AssetSlice with only the given partition keys if they are in the slice."""
partitions_def = check.not_none(self._partitions_def, "Must have partitions def")
for partition_key in partition_keys:
if not partitions_def.has_partition_key(partition_key):
check.failed(
f"Partition key {partition_key} not in partitions def {self._partitions_def}"
)

return _slice_from_subset(
self._asset_graph_view,
AssetSubset.from_asset_partitions_set(
self._compatible_subset
& AssetSubset.from_asset_partitions_set(
self.asset_key,
self._partitions_def,
{AssetPartition(self.asset_key, partition_key) for partition_key in partition_keys},
)
& self._compatible_subset,
partitions_def,
{
AssetKeyPartitionKey(self.asset_key, partition_key)
for partition_key in partition_keys
},
),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,6 @@ def up_numbers() -> None: ...
{"1", "2"}
).compute_partition_keys() == {"1", "2"}

assert (
asset_graph_view_t0.get_asset_slice(up_numbers.key)
.only_partition_keys({"4"})
.compute_partition_keys()
== set()
)
assert asset_graph_view_t0.get_asset_slice(up_numbers.key).only_partition_keys(
{"3"}
).compute_partition_keys() == set(["3"])

0 comments on commit b92518d

Please sign in to comment.