Skip to content

Commit

Permalink
[module-loaders] Make asset check coercion explicit in with_attributes (#26529)
Browse files Browse the repository at this point in the history

## Summary & Motivation
This was one of the most confusing parts of the original code path:
asset check keys were being changed _implicitly_ within
with_attributes whenever you passed output_asset_key_replacements.

Instead, provide a top-level, explicit way to remap asset check keys. I
think the old behavior was highly mysterious; I found myself questioning
what the heck was happening. This is much more "what it says on the
tin."

## How I Tested These Changes
Existing tests.
  • Loading branch information
dpeng817 committed Dec 19, 2024
1 parent eefb904 commit 5ddc747
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,6 @@ def get_python_identifier(self) -> str:
# Read-only property: the AssetCheckKey built from this spec's asset_key and name.
@property
def key(self) -> AssetCheckKey:
return AssetCheckKey(self.asset_key, self.name)

# Return the result of self._replace with both identifying fields taken from `key`:
# asset_key becomes key.asset_key and name becomes key.name. Other fields are not
# passed to _replace here.
def replace_key(self, key: AssetCheckKey) -> "AssetCheckSpec":
return self._replace(asset_key=key.asset_key, name=key.name)
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def normalize_assets(
# AssetKey not subject to any further manipulation.
resolved_deps = ResolvedAssetDependencies(assets_defs, [])

input_asset_key_replacements = [
asset_key_replacements = [
{
raw_key: normalized_key
for input_name, raw_key in ad.keys_by_input_name.items()
Expand All @@ -218,8 +218,8 @@ def normalize_assets(

# Only update the assets defs if we're actually replacing input asset keys
assets_defs = [
ad.with_attributes(input_asset_key_replacements=reps) if reps else ad
for ad, reps in zip(assets_defs, input_asset_key_replacements)
ad.with_attributes(asset_key_replacements=reps) if reps else ad
for ad, reps in zip(assets_defs, asset_key_replacements)
]

# Create unexecutable external assets definitions for any referenced keys for which no
Expand Down
6 changes: 6 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/asset_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ def from_db_string(db_string: str) -> Optional["AssetCheckKey"]:
def to_db_string(self) -> str:
return seven.json.dumps({"asset_key": self.asset_key.to_string(), "check_name": self.name})

# Return a new AssetCheckKey with the same check name, whose asset key is
# self.asset_key.with_prefix(prefix). A new key is constructed; self is not modified here.
def with_asset_key_prefix(self, prefix: CoercibleToAssetKeyPrefix) -> "AssetCheckKey":
return AssetCheckKey(self.asset_key.with_prefix(prefix), self.name)

# Return a new AssetCheckKey that keeps this key's check name but targets `asset_key`.
def replace_asset_key(self, asset_key: AssetKey) -> "AssetCheckKey":
return AssetCheckKey(asset_key, self.name)


EntityKey = Union[AssetKey, AssetCheckKey]
T_EntityKey = TypeVar("T_EntityKey", AssetKey, AssetCheckKey, EntityKey)
Expand Down
32 changes: 13 additions & 19 deletions python_modules/dagster/dagster/_core/definitions/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,8 +1208,7 @@ def coerce_to_checks_def(self) -> "AssetChecksDefinition":
def with_attributes(
self,
*,
output_asset_key_replacements: Mapping[AssetKey, AssetKey] = {},
input_asset_key_replacements: Mapping[AssetKey, AssetKey] = {},
asset_key_replacements: Mapping[AssetKey, AssetKey] = {},
group_names_by_key: Mapping[AssetKey, str] = {},
tags_by_key: Mapping[AssetKey, Mapping[str, str]] = {},
freshness_policy: Optional[
Expand Down Expand Up @@ -1254,16 +1253,13 @@ def update_replace_dict_and_conflicts(
default_value=DEFAULT_GROUP_NAME,
)

if key in output_asset_key_replacements:
replace_dict["key"] = output_asset_key_replacements[key]
if key in asset_key_replacements:
replace_dict["key"] = asset_key_replacements[key]

if input_asset_key_replacements or output_asset_key_replacements:
if asset_key_replacements:
new_deps = []
for dep in spec.deps:
replacement_key = input_asset_key_replacements.get(
dep.asset_key,
output_asset_key_replacements.get(dep.asset_key),
)
replacement_key = asset_key_replacements.get(dep.asset_key, dep.asset_key)
if replacement_key is not None:
new_deps.append(dep._replace(asset_key=replacement_key))
else:
Expand All @@ -1280,33 +1276,31 @@ def update_replace_dict_and_conflicts(
)

check_specs_by_output_name = {
output_name: check_spec._replace(
asset_key=output_asset_key_replacements.get(
check_spec.asset_key, check_spec.asset_key
output_name: check_spec.replace_key(
key=check_spec.key.replace_asset_key(
asset_key_replacements.get(check_spec.asset_key, check_spec.asset_key)
)
)
for output_name, check_spec in self.node_check_specs_by_output_name.items()
}

selected_asset_check_keys = {
check_key._replace(
asset_key=output_asset_key_replacements.get(
check_key.asset_key, check_key.asset_key
)
check_key.replace_asset_key(
asset_key_replacements.get(check_key.asset_key, check_key.asset_key)
)
for check_key in self.check_keys
}

replaced_attributes = dict(
keys_by_input_name={
input_name: input_asset_key_replacements.get(key, key)
input_name: asset_key_replacements.get(key, key)
for input_name, key in self.node_keys_by_input_name.items()
},
keys_by_output_name={
output_name: output_asset_key_replacements.get(key, key)
output_name: asset_key_replacements.get(key, key)
for output_name, key in self.node_keys_by_output_name.items()
},
selected_asset_keys={output_asset_key_replacements.get(key, key) for key in self.keys},
selected_asset_keys={asset_key_replacements.get(key, key) for key in self.keys},
backfill_policy=backfill_policy if backfill_policy else self.backfill_policy,
is_subset=self.is_subset,
check_specs_by_output_name=check_specs_by_output_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,15 @@ def with_resources(

def with_attributes(
self,
output_asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
input_asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
group_names_by_key: Optional[Mapping[AssetKey, str]] = None,
freshness_policy: Optional[
Union[FreshnessPolicy, Mapping[AssetKey, FreshnessPolicy]]
] = None,
) -> "CacheableAssetsDefinition":
return PrefixOrGroupWrappedCacheableAssetsDefinition(
self,
output_asset_key_replacements=output_asset_key_replacements,
input_asset_key_replacements=input_asset_key_replacements,
asset_key_replacements=asset_key_replacements,
group_names_by_key=group_names_by_key,
freshness_policy=freshness_policy,
)
Expand Down Expand Up @@ -247,8 +245,7 @@ class PrefixOrGroupWrappedCacheableAssetsDefinition(WrappedCacheableAssetsDefini
def __init__(
self,
wrapped: CacheableAssetsDefinition,
output_asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
input_asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
asset_key_replacements: Optional[Mapping[AssetKey, AssetKey]] = None,
group_names_by_key: Optional[Mapping[AssetKey, str]] = None,
group_name_for_all_assets: Optional[str] = None,
prefix_for_all_assets: Optional[List[str]] = None,
Expand All @@ -260,8 +257,7 @@ def __init__(
] = None,
backfill_policy: Optional[BackfillPolicy] = None,
):
self._output_asset_key_replacements = output_asset_key_replacements or {}
self._input_asset_key_replacements = input_asset_key_replacements or {}
self._asset_key_replacements = asset_key_replacements or {}
self._group_names_by_key = group_names_by_key or {}
self._group_name_for_all_assets = group_name_for_all_assets
self._prefix_for_all_assets = prefix_for_all_assets
Expand All @@ -274,12 +270,8 @@ def __init__(
"Cannot set both group_name_for_all_assets and group_names_by_key",
)
check.invariant(
not (
prefix_for_all_assets
and (output_asset_key_replacements or input_asset_key_replacements)
),
"Cannot set both prefix_for_all_assets and output_asset_key_replacements or"
" input_asset_key_replacements",
not (prefix_for_all_assets and (asset_key_replacements)),
"Cannot set both prefix_for_all_assets and asset_key_replacements",
)

super().__init__(
Expand All @@ -290,22 +282,10 @@ def __init__(
def _get_hash(self) -> str:
"""Generate a stable hash of the various prefix/group mappings."""
contents = hashlib.sha1()
if self._output_asset_key_replacements:
if self._asset_key_replacements:
contents.update(
_map_to_hashable(
{
tuple(k.path): tuple(v.path)
for k, v in self._output_asset_key_replacements.items()
}
)
)
if self._input_asset_key_replacements:
contents.update(
_map_to_hashable(
{
tuple(k.path): tuple(v.path)
for k, v in self._input_asset_key_replacements.items()
}
{tuple(k.path): tuple(v.path) for k, v in self._asset_key_replacements.items()}
)
)
if self._group_names_by_key:
Expand All @@ -330,33 +310,13 @@ def transformed_assets_def(self, assets_def: AssetsDefinition) -> AssetsDefiniti
if self._group_name_for_all_assets
else self._group_names_by_key
)
output_asset_key_replacements = (
{
k: AssetKey(
path=(
self._prefix_for_all_assets + list(k.path)
if self._prefix_for_all_assets
else k.path
)
)
for k in assets_def.keys
}
if self._prefix_for_all_assets
else self._output_asset_key_replacements
)
input_asset_key_replacements = (
asset_key_replacements = (
{
k: AssetKey(
path=(
self._prefix_for_all_assets + list(k.path)
if self._prefix_for_all_assets
else k.path
)
)
for k in assets_def.dependency_keys
k: k.with_prefix(self._prefix_for_all_assets) if self._prefix_for_all_assets else k
for k in assets_def.keys | set(assets_def.dependency_keys)
}
if self._prefix_for_all_assets
else self._input_asset_key_replacements
else self._asset_key_replacements
)
if isinstance(self._auto_materialize_policy, dict):
automation_condition = {
Expand All @@ -367,8 +327,7 @@ def transformed_assets_def(self, assets_def: AssetsDefinition) -> AssetsDefiniti
else:
automation_condition = None
return assets_def.with_attributes(
output_asset_key_replacements=output_asset_key_replacements,
input_asset_key_replacements=input_asset_key_replacements,
asset_key_replacements=asset_key_replacements,
group_names_by_key=group_names_by_key,
freshness_policy=self._freshness_policy,
automation_condition=automation_condition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import dagster._check as check
from dagster._core.definitions.asset_checks import AssetChecksDefinition, has_only_asset_checks
from dagster._core.definitions.asset_key import (
AssetCheckKey,
AssetKey,
CoercibleToAssetKeyPrefix,
check_opt_coercible_to_asset_key_prefix_param,
Expand Down Expand Up @@ -429,6 +430,7 @@ def find_modules_in_package(package_module: ModuleType) -> Iterable[ModuleType]:
def replace_keys_in_asset(
asset: Union[AssetsDefinition, AssetSpec, SourceAsset],
key_replacements: Mapping[AssetKey, AssetKey],
check_key_replacements: Mapping[AssetCheckKey, AssetCheckKey],
) -> Union[AssetsDefinition, AssetSpec, SourceAsset]:
if isinstance(asset, SourceAsset):
return asset.with_attributes(key=key_replacements.get(asset.key, asset.key))
Expand All @@ -438,13 +440,7 @@ def replace_keys_in_asset(
)
else:
updated_object = asset.with_attributes(
output_asset_key_replacements={
key: key_replacements.get(key, key)
for key in key_iterator(asset, included_targeted_keys=True)
},
input_asset_key_replacements={
key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values()
},
asset_key_replacements=key_replacements,
)
return (
updated_object.coerce_to_checks_def()
Expand All @@ -469,6 +465,10 @@ def assets_defs_and_specs(self) -> Sequence[Union[AssetsDefinition, AssetSpec]]:
or isinstance(dagster_object, AssetSpec)
]

# Cached property: the entries of self.loaded_objects that are AssetsDefinition
# instances, in the order they appear in loaded_objects.
@cached_property
def assets_defs(self) -> Sequence[AssetsDefinition]:
return [asset for asset in self.loaded_objects if isinstance(asset, AssetsDefinition)]

@cached_property
def checks_defs(self) -> Sequence[AssetChecksDefinition]:
return [
Expand Down Expand Up @@ -511,12 +511,21 @@ def assets_with_loadable_prefix(
for asset_object in self.assets_defs_specs_and_checks_defs
for key in key_iterator(asset_object, included_targeted_keys=True)
}
all_check_keys = {
check_key for asset_object in self.assets_defs for check_key in asset_object.check_keys
}

key_replacements = {key: key.with_prefix(key_prefix) for key in all_asset_keys}
check_key_replacements = {
check_key: check_key.with_asset_key_prefix(key_prefix) for check_key in all_check_keys
}
for asset_object in self.loaded_objects:
if isinstance(asset_object, CacheableAssetsDefinition):
result_list.append(asset_object.with_prefix_for_all(key_prefix))
elif isinstance(asset_object, AssetsDefinition):
result_list.append(replace_keys_in_asset(asset_object, key_replacements))
result_list.append(
replace_keys_in_asset(asset_object, key_replacements, check_key_replacements)
)
else:
# We don't replace the key for SourceAssets.
result_list.append(asset_object)
Expand All @@ -532,7 +541,9 @@ def assets_with_source_prefix(
}
for asset_object in self.loaded_objects:
if isinstance(asset_object, KeyScopedAssetObjects):
result_list.append(replace_keys_in_asset(asset_object, key_replacements))
result_list.append(
replace_keys_in_asset(asset_object, key_replacements, check_key_replacements={})
)
else:
result_list.append(asset_object)
return ResolvedAssetObjectList(result_list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2363,10 +2363,10 @@ def test_asset_group_build_subset_job(job_selection, expected_assets, use_multi,
for prefix in reversed(prefixes or []):
all_assets = [
assets_def.with_attributes(
input_asset_key_replacements={
k: k.with_prefix(prefix) for k in assets_def.keys_by_input_name.values()
asset_key_replacements={
k: k.with_prefix(prefix)
for k in set(assets_def.keys_by_input_name.values()) | set(assets_def.keys)
},
output_asset_key_replacements={k: k.with_prefix(prefix) for k in assets_def.keys},
)
for assets_def in all_assets
]
Expand Down
Loading

0 comments on commit 5ddc747

Please sign in to comment.