Skip to content

Commit

Permalink
cp
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Mar 6, 2024
1 parent 0f8a039 commit d83cda2
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from datetime import datetime
from typing import TYPE_CHECKING, AbstractSet, Mapping, NamedTuple, Optional

from typing_extensions import TypeAlias

from dagster import _check as check
from dagster._core.definitions.asset_subset import AssetSubset, ValidAssetSubset
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.events import AssetKey, AssetKeyPartitionKey
from dagster._utils.cached_method import CACHED_METHOD_FIELD_SUFFIX, cached_method

if TYPE_CHECKING:
Expand All @@ -23,12 +25,20 @@ class TemporalContext(NamedTuple):
class _AssetSliceCompatibleSubset(ValidAssetSubset):
    """ValidAssetSubset subclass used as the backing subset stored inside an AssetSlice."""


PartitionKey = Optional[str]
PartitionKey: TypeAlias = Optional[str]
AssetPartition: TypeAlias = AssetKeyPartitionKey


def _slice_from_subset(asset_graph_view: "AssetGraphView", subset: AssetSubset) -> "AssetSlice":
    """Wrap ``subset`` in an AssetSlice bound to ``asset_graph_view``.

    The subset is passed through ``as_valid`` using the partitions definition
    that the view's asset graph reports for the subset's asset key, and the
    result is re-packed as the slice-compatible subset type.
    """
    partitions_def = asset_graph_view.asset_graph.get_partitions_def(subset.asset_key)
    compatible_subset = _AssetSliceCompatibleSubset(*subset.as_valid(partitions_def))
    return AssetSlice(asset_graph_view, compatible_subset)


class AssetSlice:
# create slots for every cached method
CACHED_METHOD_FIELDS = ["parent_slices", "child_slices"]
CACHED_METHOD_FIELDS = ["get_parent_slice", "get_child_slice", "parent_slices", "child_slices"]
# Using slots instead of NamedTuple/Dataclass so we can encapsulate and so we
# can create slots for cached methods
__slots__ = ["_asset_graph_view", "_compatible_subset"] + [
Expand All @@ -47,29 +57,55 @@ def convert_to_valid_asset_subset(self) -> ValidAssetSubset:
def materialize_partition_keys(self) -> AbstractSet[PartitionKey]:
    """Collect the partition key of every asset partition in the backing subset into a set."""
    asset_partitions = self._compatible_subset.asset_partitions
    return {asset_partition.partition_key for asset_partition in asset_partitions}

@property
def asset_key(self) -> AssetKey:
    """Asset key of the subset backing this slice."""
    backing_subset = self._compatible_subset
    return backing_subset.asset_key

@property
def parent_keys(self) -> AbstractSet[AssetKey]:
    """Keys returned by the asset graph's ``get_parents`` for this slice's asset key."""
    asset_graph = self._asset_graph_view.asset_graph
    return asset_graph.get_parents(self.asset_key)

@property
def _partitions_def(self) -> Optional["PartitionsDefinition"]:
    """Partitions definition the asset graph reports for this slice's asset key, if any."""
    asset_graph = self._asset_graph_view.asset_graph
    return asset_graph.get_partitions_def(self.asset_key)

@cached_method
def get_parent_slice(self, parent_asset_key: AssetKey) -> "AssetSlice":
    """Ask the graph view for the slice of ``parent_asset_key`` that is the parent of this slice."""
    graph_view = self._asset_graph_view
    return graph_view.get_parent_asset_slice(parent_asset_key, self)

@cached_method
def get_child_slice(self, child_asset_key: AssetKey) -> "AssetSlice":
    """Ask the graph view for the slice of ``child_asset_key`` that is the child of this slice."""
    graph_view = self._asset_graph_view
    return graph_view.get_child_asset_slice(child_asset_key, self)

# Cached property: builds a dict keyed by each parent asset key of this slice's asset,
# whose value is the parent slice computed for that key.
# NOTE(review): this span is rendered as a diff, so two versions of some lines sit side by
# side: the loop header appears as both `_asset_graph.get_parents(` and
# `asset_graph.get_parents(`, and the assignment appears both as a direct
# `get_parent_asset_slice(parent_asset_key, self)` call and as
# `self.get_parent_slice(parent_asset_key)`. Only one of each pair belongs in the real
# file -- presumably the second of each; confirm against the commit.
@property
@cached_method
def parent_slices(self) -> Mapping[AssetKey, "AssetSlice"]:
parent_slices_by_asset_key = {}
for parent_asset_key in self._asset_graph_view._asset_graph.get_parents( # noqa: SLF001
for parent_asset_key in self._asset_graph_view.asset_graph.get_parents(
self._compatible_subset.asset_key
):
parent_slices_by_asset_key[parent_asset_key] = (
self._asset_graph_view.get_parent_asset_slice(parent_asset_key, self)
)
parent_slices_by_asset_key[parent_asset_key] = self.get_parent_slice(parent_asset_key)
return parent_slices_by_asset_key

# Cached property: mapping from asset key to a slice related to this slice.
# NOTE(review): this span is rendered as a diff and holds two bodies. The loop body iterates
# `get_parents` and calls `get_parent_asset_slice`, i.e. it fills the "child" mapping with
# parent slices; the dict comprehension iterates `get_children` and calls
# `get_child_slice`. Presumably the comprehension is the current version and the loop is the
# superseded one -- confirm against the commit.
# NOTE(review): in the comprehension the loop variable is still named `parent_asset_key`
# even though it ranges over the result of `get_children`; `child_asset_key` would describe
# it more accurately.
@property
@cached_method
def child_slices(self) -> Mapping[AssetKey, "AssetSlice"]:
child_slices_by_asset_key = {}
for parent_asset_key in self._asset_graph_view._asset_graph.get_parents( # noqa: SLF001
self._compatible_subset.asset_key
):
child_slices_by_asset_key[parent_asset_key] = (
self._asset_graph_view.get_parent_asset_slice(parent_asset_key, self)
)
return child_slices_by_asset_key
return {
parent_asset_key: self.get_child_slice(parent_asset_key)
for parent_asset_key in self._asset_graph_view.asset_graph.get_children(self.asset_key)
}

def of_materialized_partition_keys(
    self, partition_keys: AbstractSet[PartitionKey]
) -> "AssetSlice":
    """Build a slice of this slice's asset from an explicit set of partition keys.

    Each key is paired with this slice's asset key to form an asset partition;
    the resulting set is turned into an AssetSubset using this asset's
    partitions definition and wrapped as a slice on the same graph view.
    """
    asset_key = self.asset_key
    asset_partitions = {AssetPartition(asset_key, key) for key in partition_keys}
    subset = AssetSubset.from_asset_partitions_set(
        asset_key,
        self._partitions_def,
        asset_partitions,
    )
    return _slice_from_subset(self._asset_graph_view, subset)


class AssetGraphView:
Expand Down Expand Up @@ -139,51 +175,49 @@ def last_event_id(self) -> Optional[int]:
return self._temporal_context.last_event_id

# Property returning the asset graph held by the stale resolver.
# NOTE(review): this span is rendered as a diff, so the def line appears twice, once as
# `_asset_graph` and once as `asset_graph`. Other lines in this view call
# `self.asset_graph`, so the un-prefixed name is presumably the current one -- confirm
# against the commit.
@property
def _asset_graph(self) -> "AssetGraph":
def asset_graph(self) -> "AssetGraph":
return self._stale_resolver.asset_graph

@property
def _queryer(self) -> "CachingInstanceQueryer":
    """Instance queryer held by the stale resolver."""
    stale_resolver = self._stale_resolver
    return stale_resolver.instance_queryer

# Looks up the partitions definition for `asset_key` on the view's asset graph.
# NOTE(review): this span is rendered as a diff, so the return statement appears twice
# (`self._asset_graph...` and `self.asset_graph...`); only one belongs in the real file --
# presumably the `asset_graph` form; confirm against the commit.
def _get_partitions_def(self, asset_key: "AssetKey") -> Optional["PartitionsDefinition"]:
return self._asset_graph.get_partitions_def(asset_key)
return self.asset_graph.get_partitions_def(asset_key)

# Builds a slice for `asset_key` from `AssetSubset.all(...)`, passing the asset's partitions
# definition, the view's queryer as the dynamic partitions store, and `effective_dt` as the
# current time.
# NOTE(review): this span is rendered as a diff, so two versions are interleaved: one binds
# `partitions_def` to a local and calls `self._slice_from_subset(...)`, the other inlines
# the lookup and calls `_slice_from_subset(self, ...)`. The doubled `return`,
# `partitions_def=` and closing-paren lines are alternatives, not a sequence; presumably
# the `_slice_from_subset(self, ...)` form is current -- confirm against the commit.
def get_asset_slice(self, asset_key: "AssetKey") -> "AssetSlice":
partitions_def = self._get_partitions_def(asset_key)
return self._slice_from_subset(
return _slice_from_subset(
self,
AssetSubset.all(
asset_key=asset_key,
partitions_def=partitions_def,
partitions_def=self._get_partitions_def(asset_key),
dynamic_partitions_store=self._queryer,
current_time=self.effective_dt,
)
),
)

# Builds the slice of `parent_asset_key` from the asset graph's `get_parent_asset_subset`,
# given `asset_slice` converted to a valid asset subset as the child, the view's queryer as
# the dynamic partitions store, and `effective_dt` as the current time.
# NOTE(review): this span is rendered as a diff, so the `return ...(` opener, the
# `..._asset_graph.get_parent_asset_subset(` line, and the closing paren each appear in two
# alternative forms; presumably the `_slice_from_subset(self, self.asset_graph...)` form is
# current -- confirm against the commit.
def get_parent_asset_slice(
self, parent_asset_key: AssetKey, asset_slice: AssetSlice
) -> AssetSlice:
return self._slice_from_subset(
self._asset_graph.get_parent_asset_subset(
return _slice_from_subset(
self,
self.asset_graph.get_parent_asset_subset(
dynamic_partitions_store=self._queryer,
parent_asset_key=parent_asset_key,
child_asset_subset=asset_slice.convert_to_valid_asset_subset(),
current_time=self.effective_dt,
)
),
)

# Builds the slice of `child_asset_key` from the asset graph's `get_child_asset_subset`,
# given `asset_slice` converted to a valid asset subset as the parent, the view's queryer as
# the dynamic partitions store, and `effective_dt` as the current time.
# NOTE(review): this span is rendered as a diff, so the `return ...(` opener, the
# `..._asset_graph.get_child_asset_subset(` line, and the closing paren each appear in two
# alternative forms; presumably the `_slice_from_subset(self, self.asset_graph...)` form is
# current -- confirm against the commit.
def get_child_asset_slice(
self, child_asset_key: "AssetKey", asset_slice: AssetSlice
) -> "AssetSlice":
return self._slice_from_subset(
self._asset_graph.get_child_asset_subset(
return _slice_from_subset(
self,
self.asset_graph.get_child_asset_subset(
dynamic_partitions_store=self._queryer,
child_asset_key=child_asset_key,
current_time=self.effective_dt,
parent_asset_subset=asset_slice.convert_to_valid_asset_subset(),
)
),
)

# Method form of the subset-to-slice conversion: passes `subset` through `as_valid` with the
# partitions definition for its asset key, then wraps the result in an AssetSlice bound to
# this view.
# NOTE(review): a module-level function with the same name and equivalent body is also
# visible in this diff, and call sites appear in both `self._slice_from_subset(...)` and
# `_slice_from_subset(self, ...)` forms, so this method is presumably the superseded
# version -- confirm against the commit.
def _slice_from_subset(self: "AssetGraphView", subset: AssetSubset) -> AssetSlice:
valid_subset = subset.as_valid(self._get_partitions_def(subset.asset_key))
return AssetSlice(self, _AssetSliceCompatibleSubset(*valid_subset))
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def an_asset() -> None: ...
# hiding stale resolver deliberately but want to test instance object identity
assert asset_graph_view_t0._stale_resolver.instance_queryer.instance is instance # noqa: SLF001

# also hiding asset graph deliberately but want to test asset keys
assert asset_graph_view_t0._asset_graph.all_asset_keys == {an_asset.key} # noqa: SLF001
assert asset_graph_view_t0.asset_graph.all_asset_keys == {an_asset.key}


def test_slice_traversal_static_partitions() -> None:
Expand Down Expand Up @@ -73,3 +72,13 @@ def down_letters() -> None: ...
"2",
"3",
}

# subset of up to subset of down
assert up_slice.of_materialized_partition_keys({"2"}).child_slices[
down_letters.key
].materialize_partition_keys() == {"b"}

# subset of down to subset of up
assert down_slice.of_materialized_partition_keys({"b"}).parent_slices[
up_numbers.key
].materialize_partition_keys() == {"2"}

0 comments on commit d83cda2

Please sign in to comment.