Skip to content

Commit

Permalink
Make asset check coercion explicit in with_attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Dec 19, 2024
1 parent 26d808d commit 7b0e692
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,6 @@ def get_python_identifier(self) -> str:
@property
def key(self) -> AssetCheckKey:
return AssetCheckKey(self.asset_key, self.name)

def replace_key(self, key: AssetCheckKey) -> "AssetCheckSpec":
return self._replace(asset_key=key.asset_key, name=key.name)
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def normalize_assets(
# AssetKey not subject to any further manipulation.
resolved_deps = ResolvedAssetDependencies(assets_defs, [])

input_asset_key_replacements = [
asset_key_replacements = [
{
raw_key: normalized_key
for input_name, raw_key in ad.keys_by_input_name.items()
Expand All @@ -218,8 +218,8 @@ def normalize_assets(

# Only update the assets defs if we're actually replacing input asset keys
assets_defs = [
ad.with_attributes(input_asset_key_replacements=reps) if reps else ad
for ad, reps in zip(assets_defs, input_asset_key_replacements)
ad.with_attributes(asset_key_replacements=reps) if reps else ad
for ad, reps in zip(assets_defs, asset_key_replacements)
]

# Create unexecutable external assets definitions for any referenced keys for which no
Expand Down
6 changes: 6 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/asset_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ def from_db_string(db_string: str) -> Optional["AssetCheckKey"]:
def to_db_string(self) -> str:
return seven.json.dumps({"asset_key": self.asset_key.to_string(), "check_name": self.name})

def with_asset_key_prefix(self, prefix: CoercibleToAssetKeyPrefix) -> "AssetCheckKey":
return AssetCheckKey(self.asset_key.with_prefix(prefix), self.name)

def replace_asset_key(self, asset_key: AssetKey) -> "AssetCheckKey":
return AssetCheckKey(asset_key, self.name)


EntityKey = Union[AssetKey, AssetCheckKey]
T_EntityKey = TypeVar("T_EntityKey", AssetKey, AssetCheckKey, EntityKey)
Expand Down
32 changes: 13 additions & 19 deletions python_modules/dagster/dagster/_core/definitions/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,8 +1208,7 @@ def coerce_to_checks_def(self) -> "AssetChecksDefinition":
def with_attributes(
self,
*,
output_asset_key_replacements: Mapping[AssetKey, AssetKey] = {},
input_asset_key_replacements: Mapping[AssetKey, AssetKey] = {},
asset_key_replacements: Mapping[AssetKey, AssetKey] = {},
group_names_by_key: Mapping[AssetKey, str] = {},
tags_by_key: Mapping[AssetKey, Mapping[str, str]] = {},
freshness_policy: Optional[
Expand Down Expand Up @@ -1254,16 +1253,13 @@ def update_replace_dict_and_conflicts(
default_value=DEFAULT_GROUP_NAME,
)

if key in output_asset_key_replacements:
replace_dict["key"] = output_asset_key_replacements[key]
if key in asset_key_replacements:
replace_dict["key"] = asset_key_replacements[key]

if input_asset_key_replacements or output_asset_key_replacements:
if asset_key_replacements:
new_deps = []
for dep in spec.deps:
replacement_key = input_asset_key_replacements.get(
dep.asset_key,
output_asset_key_replacements.get(dep.asset_key),
)
replacement_key = asset_key_replacements.get(dep.asset_key, dep.asset_key)
if replacement_key is not None:
new_deps.append(dep._replace(asset_key=replacement_key))
else:
Expand All @@ -1280,33 +1276,31 @@ def update_replace_dict_and_conflicts(
)

check_specs_by_output_name = {
output_name: check_spec._replace(
asset_key=output_asset_key_replacements.get(
check_spec.asset_key, check_spec.asset_key
output_name: check_spec.replace_key(
key=check_spec.key.replace_asset_key(
asset_key_replacements.get(check_spec.asset_key, check_spec.asset_key)
)
)
for output_name, check_spec in self.node_check_specs_by_output_name.items()
}

selected_asset_check_keys = {
check_key._replace(
asset_key=output_asset_key_replacements.get(
check_key.asset_key, check_key.asset_key
)
check_key.replace_asset_key(
asset_key_replacements.get(check_key.asset_key, check_key.asset_key)
)
for check_key in self.check_keys
}

replaced_attributes = dict(
keys_by_input_name={
input_name: input_asset_key_replacements.get(key, key)
input_name: asset_key_replacements.get(key, key)
for input_name, key in self.node_keys_by_input_name.items()
},
keys_by_output_name={
output_name: output_asset_key_replacements.get(key, key)
output_name: asset_key_replacements.get(key, key)
for output_name, key in self.node_keys_by_output_name.items()
},
selected_asset_keys={output_asset_key_replacements.get(key, key) for key in self.keys},
selected_asset_keys={asset_key_replacements.get(key, key) for key in self.keys},
backfill_policy=backfill_policy if backfill_policy else self.backfill_policy,
is_subset=self.is_subset,
check_specs_by_output_name=check_specs_by_output_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,15 @@ def with_resources(

def with_attributes(
self,
output_asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
input_asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
group_names_by_key: Optional[Mapping[AssetKey, str]] = None,
freshness_policy: Optional[
Union[FreshnessPolicy, Mapping[AssetKey, FreshnessPolicy]]
] = None,
) -> "CacheableAssetsDefinition":
return PrefixOrGroupWrappedCacheableAssetsDefinition(
self,
output_asset_key_replacements=output_asset_key_replacements,
input_asset_key_replacements=input_asset_key_replacements,
asset_key_replacements=asset_key_replacements,
group_names_by_key=group_names_by_key,
freshness_policy=freshness_policy,
)
Expand Down Expand Up @@ -247,8 +245,7 @@ class PrefixOrGroupWrappedCacheableAssetsDefinition(WrappedCacheableAssetsDefini
def __init__(
self,
wrapped: CacheableAssetsDefinition,
output_asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
input_asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
group_names_by_key: Optional[Mapping[AssetKey, str]] = None,
group_name_for_all_assets: Optional[str] = None,
prefix_for_all_assets: Optional[List[str]] = None,
Expand All @@ -260,8 +257,7 @@ def __init__(
] = None,
backfill_policy: Optional[BackfillPolicy] = None,
):
self._output_asset_key_replacements = output_asset_key_replacements or {}
self._input_asset_key_replacements = input_asset_key_replacements or {}
self._asset_key_replacements = asset_key_replacements or {}
self._group_names_by_key = group_names_by_key or {}
self._group_name_for_all_assets = group_name_for_all_assets
self._prefix_for_all_assets = prefix_for_all_assets
Expand All @@ -274,12 +270,8 @@ def __init__(
"Cannot set both group_name_for_all_assets and group_names_by_key",
)
check.invariant(
not (
prefix_for_all_assets
and (output_asset_key_replacements or input_asset_key_replacements)
),
"Cannot set both prefix_for_all_assets and output_asset_key_replacements or"
" input_asset_key_replacements",
not (prefix_for_all_assets and (asset_key_replacements)),
"Cannot set both prefix_for_all_assets and asset_key_replacements",
)

super().__init__(
Expand All @@ -290,22 +282,10 @@ def __init__(
def _get_hash(self) -> str:
"""Generate a stable hash of the various prefix/group mappings."""
contents = hashlib.sha1()
if self._output_asset_key_replacements:
if self._asset_key_replacements:
contents.update(
_map_to_hashable(
{
tuple(k.path): tuple(v.path)
for k, v in self._output_asset_key_replacements.items()
}
)
)
if self._input_asset_key_replacements:
contents.update(
_map_to_hashable(
{
tuple(k.path): tuple(v.path)
for k, v in self._input_asset_key_replacements.items()
}
{tuple(k.path): tuple(v.path) for k, v in self._asset_key_replacements.items()}
)
)
if self._group_names_by_key:
Expand All @@ -330,33 +310,13 @@ def transformed_assets_def(self, assets_def: AssetsDefinition) -> AssetsDefiniti
if self._group_name_for_all_assets
else self._group_names_by_key
)
output_asset_key_replacements = (
{
k: AssetKey(
path=(
self._prefix_for_all_assets + list(k.path)
if self._prefix_for_all_assets
else k.path
)
)
for k in assets_def.keys
}
if self._prefix_for_all_assets
else self._output_asset_key_replacements
)
input_asset_key_replacements = (
asset_key_replacements = (
{
k: AssetKey(
path=(
self._prefix_for_all_assets + list(k.path)
if self._prefix_for_all_assets
else k.path
)
)
for k in assets_def.dependency_keys
k: k.with_prefix(self._prefix_for_all_assets) if self._prefix_for_all_assets else k
for k in assets_def.keys | set(assets_def.dependency_keys)
}
if self._prefix_for_all_assets
else self._input_asset_key_replacements
else self._asset_key_replacements
)
if isinstance(self._auto_materialize_policy, dict):
automation_condition = {
Expand All @@ -367,8 +327,7 @@ def transformed_assets_def(self, assets_def: AssetsDefinition) -> AssetsDefiniti
else:
automation_condition = None
return assets_def.with_attributes(
output_asset_key_replacements=output_asset_key_replacements,
input_asset_key_replacements=input_asset_key_replacements,
asset_key_replacements=asset_key_replacements,
group_names_by_key=group_names_by_key,
freshness_policy=self._freshness_policy,
automation_condition=automation_condition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import dagster._check as check
from dagster._core.definitions.asset_checks import AssetChecksDefinition, has_only_asset_checks
from dagster._core.definitions.asset_key import (
AssetCheckKey,
AssetKey,
CoercibleToAssetKeyPrefix,
check_opt_coercible_to_asset_key_prefix_param,
Expand Down Expand Up @@ -429,6 +430,7 @@ def find_modules_in_package(package_module: ModuleType) -> Iterable[ModuleType]:
def replace_keys_in_asset(
asset: Union[AssetsDefinition, AssetSpec, SourceAsset],
key_replacements: Mapping[AssetKey, AssetKey],
check_key_replacements: Mapping[AssetCheckKey, AssetCheckKey],
) -> Union[AssetsDefinition, AssetSpec, SourceAsset]:
if isinstance(asset, SourceAsset):
return asset.with_attributes(key=key_replacements.get(asset.key, asset.key))
Expand All @@ -438,13 +440,7 @@ def replace_keys_in_asset(
)
else:
updated_object = asset.with_attributes(
output_asset_key_replacements={
key: key_replacements.get(key, key)
for key in key_iterator(asset, included_targeted_keys=True)
},
input_asset_key_replacements={
key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values()
},
asset_key_replacements=key_replacements,
)
return (
updated_object.coerce_to_checks_def()
Expand All @@ -469,6 +465,10 @@ def assets_defs_and_specs(self) -> Sequence[Union[AssetsDefinition, AssetSpec]]:
or isinstance(dagster_object, AssetSpec)
]

@cached_property
def assets_defs(self) -> Sequence[AssetsDefinition]:
return [asset for asset in self.loaded_objects if isinstance(asset, AssetsDefinition)]

@cached_property
def checks_defs(self) -> Sequence[AssetChecksDefinition]:
return [
Expand Down Expand Up @@ -511,12 +511,21 @@ def assets_with_loadable_prefix(
for asset_object in self.assets_defs_specs_and_checks_defs
for key in key_iterator(asset_object, included_targeted_keys=True)
}
all_check_keys = {
check_key for asset_object in self.assets_defs for check_key in asset_object.check_keys
}

key_replacements = {key: key.with_prefix(key_prefix) for key in all_asset_keys}
check_key_replacements = {
check_key: check_key.with_asset_key_prefix(key_prefix) for check_key in all_check_keys
}
for asset_object in self.loaded_objects:
if isinstance(asset_object, CacheableAssetsDefinition):
result_list.append(asset_object.with_prefix_for_all(key_prefix))
elif isinstance(asset_object, AssetsDefinition):
result_list.append(replace_keys_in_asset(asset_object, key_replacements))
result_list.append(
replace_keys_in_asset(asset_object, key_replacements, check_key_replacements)
)
else:
# We don't replace the key for SourceAssets.
result_list.append(asset_object)
Expand All @@ -532,7 +541,9 @@ def assets_with_source_prefix(
}
for asset_object in self.loaded_objects:
if isinstance(asset_object, KeyScopedAssetObjects):
result_list.append(replace_keys_in_asset(asset_object, key_replacements))
result_list.append(
replace_keys_in_asset(asset_object, key_replacements, check_key_replacements={})
)
else:
result_list.append(asset_object)
return ResolvedAssetObjectList(result_list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,9 @@ def test_define_selection_job(job_selection, expected_assets, use_multi, prefixe
for prefix in reversed(prefixes or []):
prefixed_assets = [
assets_def.with_attributes(
input_asset_key_replacements={
key: key.with_prefix(prefix) for key in assets_def.keys_by_input_name.values()
},
output_asset_key_replacements={
key: key.with_prefix(prefix) for key in assets_def.keys
asset_key_replacements={
key: key.with_prefix(prefix)
for key in set(assets_def.keys_by_input_name.values()) | set(assets_def.keys)
},
)
for assets_def in prefixed_assets
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from typing import List, Sequence, Union, cast

import dagster._check as check
import pytest
Expand Down Expand Up @@ -151,7 +151,9 @@ def test_resolve_wrong_data():
recon_repo.get_definition()


def define_uncacheable_and_resource_dependent_cacheable_assets():
def define_uncacheable_and_resource_dependent_cacheable_assets() -> (
Sequence[Union[CacheableAssetsDefinition, AssetsDefinition]]
):
class ResourceDependentCacheableAsset(CacheableAssetsDefinition):
def __init__(self):
super().__init__("res_midstream")
Expand Down Expand Up @@ -295,7 +297,7 @@ def _op():
)


def test_multiple_wrapped_cached_assets():
def test_multiple_wrapped_cached_assets() -> None:
"""Test that multiple wrappers (with_attributes, with_resources) work properly on cacheable assets."""

@resource
Expand All @@ -304,9 +306,7 @@ def foo_resource():

my_cacheable_assets_with_group_and_asset = [
x.with_attributes(
output_asset_key_replacements={
AssetKey("res_downstream"): AssetKey("res_downstream_too")
}
asset_key_replacements={AssetKey("res_downstream"): AssetKey("res_downstream_too")}
)
for x in with_resources(
[
Expand All @@ -333,13 +333,19 @@ def resource_dependent_repo_with_resources():
assert isinstance(repo.get_job("all_asset_job"), JobDefinition)

my_cool_group_sel = AssetSelection.groups("my_cool_group")
cacheable_resource_asset = cast(
CacheableAssetsDefinition, my_cacheable_assets_with_group_and_asset[0]
)
resolved_defs = list(
cacheable_resource_asset.build_definitions(
cacheable_resource_asset.compute_cacheable_data()
)
)
assert (
len(
my_cool_group_sel.resolve(
my_cacheable_assets_with_group_and_asset[0].build_definitions(
my_cacheable_assets_with_group_and_asset[0].compute_cacheable_data()
)
+ my_cacheable_assets_with_group_and_asset[1:]
resolved_defs
+ cast(List[AssetsDefinition], my_cacheable_assets_with_group_and_asset[1:])
)
)
== 1
Expand Down

0 comments on commit 7b0e692

Please sign in to comment.