Skip to content

Commit

Permalink
Make asset check coercion explicit in with_attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Dec 16, 2024
1 parent 1c711a7 commit 7edd532
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,6 @@ def get_python_identifier(self) -> str:
@property
def key(self) -> AssetCheckKey:
return AssetCheckKey(self.asset_key, self.name)

def replace_key(self, key: AssetCheckKey) -> "AssetCheckSpec":
return self._replace(asset_key=key.asset_key, name=key.name)
3 changes: 3 additions & 0 deletions python_modules/dagster/dagster/_core/definitions/asset_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ def from_db_string(db_string: str) -> Optional["AssetCheckKey"]:
def to_db_string(self) -> str:
return seven.json.dumps({"asset_key": self.asset_key.to_string(), "check_name": self.name})

def with_asset_key_prefix(self, prefix: CoercibleToAssetKeyPrefix) -> "AssetCheckKey":
return AssetCheckKey(self.asset_key.with_prefix(prefix), self.name)


EntityKey = Union[AssetKey, AssetCheckKey]
T_EntityKey = TypeVar("T_EntityKey", AssetKey, AssetCheckKey, EntityKey)
Expand Down
14 changes: 4 additions & 10 deletions python_modules/dagster/dagster/_core/definitions/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,7 @@ def with_attributes(
self,
*,
output_asset_key_replacements: Mapping[AssetKey, AssetKey] = {},
output_check_key_replacements: Mapping[AssetCheckKey, AssetCheckKey] = {},
input_asset_key_replacements: Mapping[AssetKey, AssetKey] = {},
group_names_by_key: Mapping[AssetKey, str] = {},
tags_by_key: Mapping[AssetKey, Mapping[str, str]] = {},
Expand Down Expand Up @@ -1255,21 +1256,14 @@ def update_replace_dict_and_conflicts(
)

check_specs_by_output_name = {
output_name: check_spec._replace(
asset_key=output_asset_key_replacements.get(
check_spec.asset_key, check_spec.asset_key
)
output_name: check_spec.replace_key(
key=output_check_key_replacements.get(check_spec.key, check_spec.key)
)
for output_name, check_spec in self.node_check_specs_by_output_name.items()
}

selected_asset_check_keys = {
check_key._replace(
asset_key=output_asset_key_replacements.get(
check_key.asset_key, check_key.asset_key
)
)
for check_key in self.check_keys
output_check_key_replacements.get(key, key) for key in self.check_keys
}

replaced_attributes = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,27 +332,21 @@ def transformed_assets_def(self, assets_def: AssetsDefinition) -> AssetsDefiniti
)
output_asset_key_replacements = (
{
k: AssetKey(
path=(
self._prefix_for_all_assets + list(k.path)
if self._prefix_for_all_assets
else k.path
)
)
k: k.with_prefix(self._prefix_for_all_assets) if self._prefix_for_all_assets else k
for k in assets_def.keys
}
if self._prefix_for_all_assets
else self._output_asset_key_replacements
)
check_key_replacements = {
k: k.with_asset_key_prefix(self._prefix_for_all_assets)
if self._prefix_for_all_assets
else k
for k in assets_def.check_keys
}
input_asset_key_replacements = (
{
k: AssetKey(
path=(
self._prefix_for_all_assets + list(k.path)
if self._prefix_for_all_assets
else k.path
)
)
k: k.with_prefix(self._prefix_for_all_assets) if self._prefix_for_all_assets else k
for k in assets_def.dependency_keys
}
if self._prefix_for_all_assets
Expand All @@ -369,6 +363,7 @@ def transformed_assets_def(self, assets_def: AssetsDefinition) -> AssetsDefiniti
return assets_def.with_attributes(
output_asset_key_replacements=output_asset_key_replacements,
input_asset_key_replacements=input_asset_key_replacements,
output_check_key_replacements=check_key_replacements,
group_names_by_key=group_names_by_key,
freshness_policy=self._freshness_policy,
automation_condition=automation_condition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import dagster._check as check
from dagster._core.definitions.asset_checks import AssetChecksDefinition, has_only_asset_checks
from dagster._core.definitions.asset_key import (
AssetCheckKey,
AssetKey,
CoercibleToAssetKeyPrefix,
check_opt_coercible_to_asset_key_prefix_param,
Expand Down Expand Up @@ -423,6 +424,7 @@ def find_modules_in_package(package_module: ModuleType) -> Iterable[ModuleType]:
def replace_keys_in_asset(
asset: Union[AssetsDefinition, AssetSpec, SourceAsset],
key_replacements: Mapping[AssetKey, AssetKey],
check_key_replacements: Mapping[AssetCheckKey, AssetCheckKey],
) -> Union[AssetsDefinition, AssetSpec, SourceAsset]:
if isinstance(asset, SourceAsset):
return asset.with_attributes(key=key_replacements.get(asset.key, asset.key))
Expand All @@ -433,8 +435,10 @@ def replace_keys_in_asset(
else:
updated_object = asset.with_attributes(
output_asset_key_replacements={
key: key_replacements.get(key, key)
for key in key_iterator(asset, included_targeted_keys=True)
key: key_replacements.get(key, key) for key in asset.keys
},
output_check_key_replacements={
key: check_key_replacements.get(key, key) for key in asset.check_keys
},
input_asset_key_replacements={
key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values()
Expand All @@ -458,6 +462,10 @@ def assets_defs_and_specs(self) -> Sequence[Union[AssetsDefinition, AssetSpec]]:
if (isinstance(asset, AssetsDefinition) and asset.keys) or isinstance(asset, AssetSpec)
]

@cached_property
def assets_defs(self) -> Sequence[AssetsDefinition]:
return [asset for asset in self.loaded_objects if isinstance(asset, AssetsDefinition)]

@cached_property
def checks_defs(self) -> Sequence[AssetChecksDefinition]:
return [
Expand Down Expand Up @@ -500,12 +508,21 @@ def assets_with_loadable_prefix(
for asset_object in self.assets_defs_specs_and_checks_defs
for key in key_iterator(asset_object, included_targeted_keys=True)
}
all_check_keys = {
check_key for asset_object in self.assets_defs for check_key in asset_object.check_keys
}

key_replacements = {key: key.with_prefix(key_prefix) for key in all_asset_keys}
check_key_replacements = {
check_key: check_key.with_asset_key_prefix(key_prefix) for check_key in all_check_keys
}
for asset_object in self.loaded_objects:
if isinstance(asset_object, CacheableAssetsDefinition):
result_list.append(asset_object.with_prefix_for_all(key_prefix))
elif isinstance(asset_object, AssetsDefinition):
result_list.append(replace_keys_in_asset(asset_object, key_replacements))
result_list.append(
replace_keys_in_asset(asset_object, key_replacements, check_key_replacements)
)
else:
# We don't replace the key for SourceAssets.
result_list.append(asset_object)
Expand All @@ -521,7 +538,9 @@ def assets_with_source_prefix(
}
for asset_object in self.loaded_objects:
if isinstance(asset_object, KeyScopedAssetObjects):
result_list.append(replace_keys_in_asset(asset_object, key_replacements))
result_list.append(
replace_keys_in_asset(asset_object, key_replacements, check_key_replacements={})
)
else:
result_list.append(asset_object)
return ResolvedAssetObjectList(result_list)
Expand Down

0 comments on commit 7edd532

Please sign in to comment.