Skip to content

Commit

Permalink
Add unsynced status to AssetSlice
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Mar 12, 2024
1 parent c7e13c3 commit 3cf56c4
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,
Expand All @@ -12,7 +13,8 @@

from dagster import _check as check
from dagster._core.definitions.asset_subset import AssetSubset, ValidAssetSubset
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.data_version import StaleStatus
from dagster._core.definitions.events import AssetKey, AssetKeyPartitionKey
from dagster._core.definitions.multi_dimensional_partitions import (
MultiPartitionKey,
MultiPartitionsDefinition,
Expand Down Expand Up @@ -79,6 +81,22 @@ def _slice_from_subset(asset_graph_view: "AssetGraphView", subset: AssetSubset)
return AssetSlice(asset_graph_view, _AssetSliceCompatibleSubset(valid_subset))


class SyncStatus(Enum):
    """Two-valued synchronization state: SYNCED or UNSYNCED."""

    SYNCED = "SYNCED"
    UNSYNCED = "UNSYNCED"

    @staticmethod
    def from_stale_status(stale_status: StaleStatus) -> "SyncStatus":
        """Collapse a StaleStatus into a SyncStatus.

        ``StaleStatus.FRESH`` becomes ``SYNCED``; any other value becomes ``UNSYNCED``.

        The conversion intentionally drops information: "stale" is being redefined as
        "unsynced", which is a binary state. The reason a partition is unsynced is meant
        to remain discoverable through the causes API.
        """
        if stale_status == StaleStatus.FRESH:
            return SyncStatus.SYNCED
        return SyncStatus.UNSYNCED


class AssetSlice:
"""An asset slice represents a set of partitions for a given asset key. It is
tied to a particular instance of an AssetGraphView, and is read-only.
Expand Down Expand Up @@ -130,6 +148,9 @@ def compute_partition_keys(self) -> AbstractSet[str]:
for akpk in self._compatible_subset.asset_partitions
}

def compute_asset_partitions(self) -> AbstractSet[AssetKeyPartitionKey]:
    """Return the asset partitions of the subset backing this slice."""
    backing_subset = self._compatible_subset
    return backing_subset.asset_partitions

@property
def asset_key(self) -> AssetKey:
    """The asset key of the subset backing this slice."""
    backing_subset = self._compatible_subset
    return backing_subset.asset_key
Expand Down Expand Up @@ -173,7 +194,6 @@ def time_windows(self) -> Sequence[TimeWindow]:
tw_partitions_def = self._time_window_partitions_def_in_context()
check.inst(tw_partitions_def, TimeWindowPartitionsDefinition, "Must be time windowed.")
assert isinstance(tw_partitions_def, TimeWindowPartitionsDefinition) # appease type checker

if isinstance(self._compatible_subset.subset_value, TimeWindowPartitionsSubset):
return self._compatible_subset.subset_value.included_time_windows
elif isinstance(self._compatible_subset.subset_value, AllPartitionsSubset):
Expand Down Expand Up @@ -207,6 +227,17 @@ def time_windows(self) -> Sequence[TimeWindow]:

check.failed(f"Unsupported partitions_def: {self._partitions_def}")

def only_asset_partitions(
    self, asset_partitions: AbstractSet[AssetKeyPartitionKey]
) -> "AssetSlice":
    """Return a new slice narrowed to the given asset partitions.

    A subset is built from ``asset_partitions`` using this slice's asset key and
    partitions definition, then combined with this slice's own subset via ``&``.
    The result is wrapped in a slice bound to the same asset graph view.
    """
    requested_subset = AssetSubset.from_asset_partitions_set(
        self.asset_key, self._partitions_def, asset_partitions
    )
    narrowed_subset = self._compatible_subset & requested_subset
    return _slice_from_subset(self._asset_graph_view, narrowed_subset)

def _time_window_partitions_def_in_context(self) -> Optional[TimeWindowPartitionsDefinition]:
pd = self._partitions_def
if isinstance(pd, TimeWindowPartitionsDefinition):
Expand All @@ -219,6 +250,17 @@ def _time_window_partitions_def_in_context(self) -> Optional[TimeWindowPartition
def is_empty(self) -> bool:
    """True when the subset backing this slice has a size of zero."""
    subset_size = self._compatible_subset.size
    return subset_size == 0

@cached_method
def compute_unsynced(self) -> "AssetSlice":
    """Return the unsynced portion of this slice.

    The work is delegated to ``compute_unsynced_slice`` on the asset graph view this
    slice is tied to.
    """
    graph_view = self._asset_graph_view
    return graph_view.compute_unsynced_slice(self)

@cached_method
def compute_sync_statuses(self) -> Mapping[AssetKeyPartitionKey, SyncStatus]:
    """Return the SyncStatus of each asset partition in this slice.

    The work is delegated to the asset graph view this slice is tied to.
    """
    graph_view = self._asset_graph_view
    # NOTE: the view-side method is spelled `compute_sync_statues` (sic).
    return graph_view.compute_sync_statues(self)

def __repr__(self) -> str:
return f"AssetSlice(subset={self._compatible_subset})"


class AssetGraphView:
"""The Asset Graph View. It is a view of the asset graph from the perspective of a specific
Expand Down Expand Up @@ -484,6 +526,25 @@ def _build_multi_partition_slice(
}
)

def compute_unsynced_slice(self, asset_slice: AssetSlice) -> "AssetSlice":
    """Return the part of ``asset_slice`` whose partitions have status UNSYNCED."""
    statuses = self.compute_sync_statues(asset_slice)
    unsynced_partitions = {
        asset_partition
        for asset_partition, sync_status in statuses.items()
        if sync_status == SyncStatus.UNSYNCED
    }
    return asset_slice.only_asset_partitions(unsynced_partitions)

def compute_sync_statuses(
    self, asset_slice: "AssetSlice"
) -> Mapping[AssetKeyPartitionKey, SyncStatus]:
    """Map every asset partition in ``asset_slice`` to its SyncStatus.

    For each asset partition, the stale resolver is asked for the status of
    (asset key, partition key) and the answer is converted with
    ``SyncStatus.from_stale_status``.

    Args:
        asset_slice: The slice whose asset partitions should be evaluated.

    Returns:
        A mapping from each asset partition in the slice to SYNCED or UNSYNCED.
    """
    # The asset key is the same for every partition in the slice, so read it once.
    asset_key = asset_slice.asset_key
    return {
        ak_pk: SyncStatus.from_stale_status(
            self._stale_resolver.get_status(asset_key, ak_pk.partition_key)
        )
        for ak_pk in asset_slice.compute_asset_partitions()
    }


def compute_sync_statues(
    self, asset_slice: "AssetSlice"
) -> Mapping[AssetKeyPartitionKey, SyncStatus]:
    """Misspelled legacy name for ``compute_sync_statuses``.

    Kept so existing callers continue to work; new code should call
    ``compute_sync_statuses``.
    """
    return self.compute_sync_statuses(asset_slice)


def _required_tw_partitions_def(
partitions_def: Optional["PartitionsDefinition"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Iterable

from dagster import Definitions, asset
from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView, AssetSlice, SyncStatus
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.events import AssetKey, AssetKeyPartitionKey
from dagster._core.definitions.materialize import materialize
from dagster._core.definitions.partition import StaticPartitionsDefinition
from dagster._core.instance import DagsterInstance


class AssetGraphViewTester:
    """Test helper that pairs definitions and an instance with a current AssetGraphView.

    The view is rebuilt after materializations so that later reads go through a view
    created after those runs.
    """

    def __init__(self, defs: Definitions, instance: DagsterInstance) -> None:
        self.defs = defs
        self.instance = instance
        # Builds the initial view from the defs/instance stored just above.
        self.next_asset_graph_view()

    def slice(self, asset_key: AssetKey) -> AssetSlice:
        """Return the slice for ``asset_key`` from the current view."""
        current_view = self.asset_graph_view
        return current_view.get_asset_slice(asset_key)

    def materialize_partitions(self, assets_def: AssetsDefinition, pks: Iterable[str]) -> None:
        """Materialize each partition key of ``assets_def`` in turn, then rebuild the view.

        Each run is asserted to have succeeded.
        """
        for partition_key in pks:
            result = materialize([assets_def], partition_key=partition_key, instance=self.instance)
            assert result.success
        self.next_asset_graph_view()

    def next_asset_graph_view(self) -> None:
        """Replace the current view with a new one built from the stored defs and instance."""
        self.asset_graph_view = AssetGraphView.for_test(self.defs, self.instance)


def test_static_partitioning_unsynced() -> None:
    """Exercise sync-status computation for two statically partitioned assets, up -> down.

    The test steps through: nothing materialized, only ``up`` materialized, both
    materialized, a single ``up`` partition re-materialized, and finally the matching
    ``down`` partition re-materialized. After each step it checks ``compute_unsynced``
    and ``compute_sync_statuses``.
    """
    letter_keys = {"a", "b", "c"}
    # Iterating a set of strings has no guaranteed order, so sort the keys to give the
    # partitions definition (and the materialization order below) a stable order.
    ordered_letter_keys = sorted(letter_keys)
    letter_static_partitions_def = StaticPartitionsDefinition(ordered_letter_keys)

    @asset(partitions_def=letter_static_partitions_def)
    def up() -> None: ...

    @asset(
        deps=[up],
        partitions_def=letter_static_partitions_def,
    )
    def down() -> None: ...

    defs = Definitions([up, down])
    instance = DagsterInstance.ephemeral()
    ag_tester = AssetGraphViewTester(defs, instance)

    def _synced_dict(asset_key: AssetKey, status: SyncStatus, pks: Iterable[str]) -> dict:
        # Expected status mapping in which every given partition has the same status.
        return {AssetKeyPartitionKey(asset_key, pk): status for pk in pks}

    # all missing, all unsynced
    assert ag_tester.slice(up.key).compute_unsynced().compute_partition_keys() == letter_keys
    assert ag_tester.slice(up.key).compute_sync_statuses() == _synced_dict(
        up.key, SyncStatus.UNSYNCED, letter_keys
    )
    assert ag_tester.slice(down.key).compute_unsynced().compute_partition_keys() == letter_keys
    assert ag_tester.slice(down.key).compute_sync_statuses() == _synced_dict(
        down.key, SyncStatus.UNSYNCED, letter_keys
    )

    # materialize all of up
    ag_tester.materialize_partitions(up, ordered_letter_keys)

    # all up in sync, all down unsynced
    assert ag_tester.slice(up.key).compute_unsynced().compute_partition_keys() == set()
    assert ag_tester.slice(up.key).compute_sync_statuses() == _synced_dict(
        up.key, SyncStatus.SYNCED, letter_keys
    )
    assert ag_tester.slice(down.key).compute_unsynced().compute_partition_keys() == letter_keys
    assert ag_tester.slice(down.key).compute_unsynced().compute_sync_statuses() == _synced_dict(
        down.key, SyncStatus.UNSYNCED, letter_keys
    )

    # materialize all down. all back in sync
    ag_tester.materialize_partitions(down, ordered_letter_keys)
    assert ag_tester.slice(up.key).compute_unsynced().compute_partition_keys() == set()
    assert ag_tester.slice(down.key).compute_unsynced().compute_partition_keys() == set()

    def _of_down(partition_key: str) -> AssetKeyPartitionKey:
        return AssetKeyPartitionKey(down.key, partition_key)

    # materialize only up.b
    ag_tester.materialize_partitions(up, ["b"])
    assert ag_tester.slice(up.key).compute_unsynced().compute_partition_keys() == set()
    assert ag_tester.slice(down.key).compute_unsynced().compute_partition_keys() == {"b"}
    assert ag_tester.slice(down.key).compute_sync_statuses() == {
        _of_down("a"): SyncStatus.SYNCED,
        _of_down("b"): SyncStatus.UNSYNCED,
        _of_down("c"): SyncStatus.SYNCED,
    }

    # the unsynced sub-slice reports statuses only for the partitions it contains
    assert ag_tester.slice(down.key).compute_unsynced().compute_sync_statuses() == {
        _of_down("b"): SyncStatus.UNSYNCED
    }

    # materialize only down.b
    # everything in sync
    ag_tester.materialize_partitions(down, ["b"])
    assert ag_tester.slice(up.key).compute_unsynced().compute_partition_keys() == set()
    assert ag_tester.slice(down.key).compute_unsynced().compute_partition_keys() == set()

0 comments on commit 3cf56c4

Please sign in to comment.