Skip to content

Commit

Permalink
required_assets_and_checks_by_key
Browse files (browse the repository at this point in the history)
  • Loading branch information
johannkm committed Sep 29, 2023
1 parent 1d4fe16 commit 65d8622
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 3 deletions.
10 changes: 8 additions & 2 deletions python_modules/dagster/dagster/_core/definitions/asset_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from collections import deque
from collections import defaultdict, deque
from datetime import datetime
from heapq import heapify, heappop, heappush
from typing import (
Expand Down Expand Up @@ -169,7 +169,7 @@ def from_assets(
code_versions_by_key: Dict[AssetKey, Optional[str]] = {}
is_observable_by_key: Dict[AssetKey, bool] = {}
auto_observe_interval_minutes_by_key: Dict[AssetKey, Optional[float]] = {}
required_assets_and_checks_by_key: Dict[SpecKey, AbstractSet[SpecKey]] = {}
required_assets_and_checks_by_key: Dict[SpecKey, AbstractSet[SpecKey]] = defaultdict(set)

for asset in all_assets:
if isinstance(asset, SourceAsset):
Expand All @@ -194,7 +194,13 @@ def from_assets(
all_required_keys = {*asset.check_keys, *asset.keys}
for key in asset.keys:
required_multi_asset_sets_by_key[key] = asset.keys
for key in all_required_keys:
required_assets_and_checks_by_key[key] = all_required_keys
elif len(asset.keys) == 1 and asset.check_specs:
required_keys = {asset.key, *asset.check_keys}
for key in required_keys:
required_assets_and_checks_by_key[key] = required_keys

code_versions_by_key.update(asset.code_versions_by_key)

return InternalAssetGraph(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dagster._core.selector.subset_selector import DependencyGraph
from dagster._core.workspace.workspace import IWorkspace

from .asset_graph import AssetGraph
from .asset_graph import AssetGraph, SpecKey
from .backfill_policy import BackfillPolicy
from .events import AssetKey
from .freshness_policy import FreshnessPolicy
Expand Down Expand Up @@ -62,6 +62,7 @@ def __init__(
code_versions_by_key=code_versions_by_key,
is_observable_by_key=is_observable_by_key,
auto_observe_interval_minutes_by_key=auto_observe_interval_minutes_by_key,
required_assets_and_checks_by_key={},
)
self._repo_handles_by_key = repo_handles_by_key
self._materialization_job_names_by_key = job_names_by_key
Expand Down Expand Up @@ -300,3 +301,6 @@ def split_asset_keys_by_repository(
asset_key
)
return list(asset_keys_by_repo.values())

def get_required_asset_and_check_keys(self, key: SpecKey):
    """Look up the asset and check keys required alongside ``key``.

    This lookup is not implemented for this graph type yet, so the method
    is a stub that always raises.

    Args:
        key: The key to look up.

    Raises:
        NotImplementedError: Always.
    """
    # ``self`` is required: this is invoked as a bound method
    # (``graph.get_required_asset_and_check_keys(key)``). Without it the call
    # fails with a TypeError about an unexpected positional argument instead
    # of the intended NotImplementedError.
    raise NotImplementedError()
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import datetime
from typing import Callable, List, Optional
from unittest.mock import MagicMock
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
from dagster._core.definitions.decorators.asset_check_decorator import asset_check

import pendulum
import pytest
Expand Down Expand Up @@ -669,3 +671,62 @@ def unpartitioned3(): ...
},
non_partitioned_asset_keys={AssetKey("unpartitioned2")},
)


def test_required_assets_and_checks_by_key_check_decorator():
    """An asset plus a separately decorated check: both keys report an empty required set."""

    @asset
    def asset0(): ...

    @asset_check(asset=asset0)
    def check0(): ...

    graph = AssetGraph.from_assets([asset0], asset_checks=[check0])

    # The asset and the standalone check are each expected to map to an empty set.
    for spec_key in (asset0.key, check0.spec.key):
        assert graph.get_required_asset_and_check_keys(spec_key) == set()


def test_required_assets_and_checks_by_key_asset_decorator():
    """Checks declared via ``check_specs`` on an ``@asset`` are grouped with that asset's key."""
    foo_check = AssetCheckSpec(name="foo", asset="asset0")
    bar_check = AssetCheckSpec(name="bar", asset="asset0")

    @asset(check_specs=[foo_check, bar_check])
    def asset0(): ...

    @asset_check(asset=asset0)
    def check0(): ...

    graph = AssetGraph.from_assets([asset0], asset_checks=[check0])

    # Every member of the group must report the whole group as required.
    members = (asset0.key, foo_check.key, bar_check.key)
    expected_group = set(members)
    for member in members:
        assert graph.get_required_asset_and_check_keys(member) == expected_group

    # The separately decorated check is expected to report an empty set.
    standalone_key = check0.spec.key
    assert graph.get_required_asset_and_check_keys(standalone_key) == set()


def test_required_assets_and_checks_by_key_multi_asset():
    """Non-subsettable multi-assets group all their keys; ``can_subset=True`` ones report none."""
    foo_check = AssetCheckSpec(name="foo", asset="asset0")
    bar_check = AssetCheckSpec(name="bar", asset="asset1")

    @multi_asset(
        outs={"asset0": AssetOut(), "asset1": AssetOut()},
        check_specs=[foo_check, bar_check],
    )
    def asset_fn(): ...

    biz_check = AssetCheckSpec(name="bar", asset="subsettable_asset0")

    @multi_asset(
        outs={"subsettable_asset0": AssetOut(), "subsettable_asset1": AssetOut()},
        check_specs=[biz_check],
        can_subset=True,
    )
    def subsettable_asset_fn(): ...

    graph = AssetGraph.from_assets([asset_fn, subsettable_asset_fn])

    # Keys of the non-subsettable multi-asset: each must report the full group.
    members = (AssetKey(["asset0"]), AssetKey(["asset1"]), foo_check.key, bar_check.key)
    expected_group = set(members)
    for member in members:
        assert graph.get_required_asset_and_check_keys(member) == expected_group

    # Keys of the subsettable multi-asset: each is expected to report an empty set.
    subsettable_keys = (
        AssetKey(["subsettable_asset0"]),
        AssetKey(["subsettable_asset1"]),
        biz_check.key,
    )
    for member in subsettable_keys:
        assert graph.get_required_asset_and_check_keys(member) == set()

0 comments on commit 65d8622

Please sign in to comment.