From e3763d4612c77241b7995780cb840cf5bf8ceb79 Mon Sep 17 00:00:00 2001 From: Christopher DeCarolis Date: Thu, 19 Dec 2024 06:19:50 -0800 Subject: [PATCH] [module-loaders] Further simplify asset checks path (#26527) ## Summary & Motivation We can further simplify the load_asset_from_x code path by making the asset checks loader return AssetsDefinitions, and changing the assertions to just make sure that the returned assets only contain checks. ## How I Tested These Changes Altered existing tests --- .../dagster/_core/definitions/assets.py | 25 +++++++++++++++++ .../definitions/load_assets_from_modules.py | 28 +++++++------------ 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/assets.py b/python_modules/dagster/dagster/_core/definitions/assets.py index 064507025685d..c345df8cb6faf 100644 --- a/python_modules/dagster/dagster/_core/definitions/assets.py +++ b/python_modules/dagster/dagster/_core/definitions/assets.py @@ -80,6 +80,7 @@ from dagster._utils.warnings import ExperimentalWarning, disable_dagster_warnings if TYPE_CHECKING: + from dagster._core.definitions.asset_checks import AssetChecksDefinition from dagster._core.definitions.graph_definition import GraphDefinition ASSET_SUBSET_INPUT_PREFIX = "__subset_input__" @@ -1180,6 +1181,30 @@ def get_op_def_for_asset_key(self, key: AssetKey) -> Optional[OpDefinition]: output_name = self.get_output_name_for_asset_key(key) return self.node_def.resolve_output_to_origin_op_def(output_name) + def coerce_to_checks_def(self) -> "AssetChecksDefinition": + from dagster._core.definitions.asset_checks import ( + AssetChecksDefinition, + has_only_asset_checks, + ) + + if not has_only_asset_checks(self): + raise DagsterInvalidDefinitionError( + "Cannot coerce an AssetsDefinition to an AssetChecksDefinition if it contains " + "non-check assets." + ) + if len(self.check_keys) == 0: + raise DagsterInvalidDefinitionError( + "Cannot coerce an AssetsDefinition to an AssetChecksDefinition if it contains no " + "checks." + ) + return AssetChecksDefinition.create( + keys_by_input_name=self.keys_by_input_name, + node_def=self.op, + check_specs_by_output_name=self.check_specs_by_output_name, + resource_defs=self.resource_defs, + can_subset=self.can_subset, + ) + def with_attributes( self, *, diff --git a/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py b/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py index 63cd3730bff43..bdec14601d014 100644 --- a/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py +++ b/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py @@ -446,15 +446,11 @@ def replace_keys_in_asset( key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values() }, ) - if isinstance(asset, AssetsDefinition) and has_only_asset_checks(asset): - updated_object = AssetChecksDefinition.create( - keys_by_input_name=updated_object.keys_by_input_name, - node_def=updated_object.op, - check_specs_by_output_name=updated_object.check_specs_by_output_name, - resource_defs=updated_object.resource_defs, - can_subset=updated_object.can_subset, - ) - return updated_object + return ( + updated_object.coerce_to_checks_def() + if has_only_asset_checks(updated_object) + else updated_object + ) class ResolvedAssetObjectList: @@ -564,15 +560,11 @@ def with_attributes( ).with_attributes( backfill_policy=backfill_policy, freshness_policy=freshness_policy ) - if isinstance(asset, AssetChecksDefinition): - new_asset = AssetChecksDefinition.create( - keys_by_input_name=new_asset.keys_by_input_name, - node_def=new_asset.op, - check_specs_by_output_name=new_asset.check_specs_by_output_name, - resource_defs=new_asset.resource_defs, - can_subset=new_asset.can_subset, - ) - return_list.append(new_asset) + return_list.append( + new_asset.coerce_to_checks_def() + if has_only_asset_checks(new_asset) + else new_asset + ) elif isinstance(asset, SourceAsset): return_list.append( asset.with_attributes(group_name=group_name if group_name else asset.group_name)