diff --git a/python_modules/dagster/dagster/_core/definitions/assets.py b/python_modules/dagster/dagster/_core/definitions/assets.py index 064507025685d..c345df8cb6faf 100644 --- a/python_modules/dagster/dagster/_core/definitions/assets.py +++ b/python_modules/dagster/dagster/_core/definitions/assets.py @@ -80,6 +80,7 @@ from dagster._utils.warnings import ExperimentalWarning, disable_dagster_warnings if TYPE_CHECKING: + from dagster._core.definitions.asset_checks import AssetChecksDefinition from dagster._core.definitions.graph_definition import GraphDefinition ASSET_SUBSET_INPUT_PREFIX = "__subset_input__" @@ -1180,6 +1181,30 @@ def get_op_def_for_asset_key(self, key: AssetKey) -> Optional[OpDefinition]: output_name = self.get_output_name_for_asset_key(key) return self.node_def.resolve_output_to_origin_op_def(output_name) + def coerce_to_checks_def(self) -> "AssetChecksDefinition": + from dagster._core.definitions.asset_checks import ( + AssetChecksDefinition, + has_only_asset_checks, + ) + + if not has_only_asset_checks(self): + raise DagsterInvalidDefinitionError( + "Cannot coerce an AssetsDefinition to an AssetChecksDefinition if it contains " + "non-check assets." + ) + if len(self.check_keys) == 0: + raise DagsterInvalidDefinitionError( + "Cannot coerce an AssetsDefinition to an AssetChecksDefinition if it contains no " + "checks." + ) + return AssetChecksDefinition.create( + keys_by_input_name=self.keys_by_input_name, + node_def=self.op, + check_specs_by_output_name=self.check_specs_by_output_name, + resource_defs=self.resource_defs, + can_subset=self.can_subset, + ) + def with_attributes( self, *, diff --git a/python_modules/dagster/dagster/_core/definitions/load_asset_checks_from_modules.py b/python_modules/dagster/dagster/_core/definitions/load_asset_checks_from_modules.py index a587066410f00..28b8afb031242 100644 --- a/python_modules/dagster/dagster/_core/definitions/load_asset_checks_from_modules.py +++ b/python_modules/dagster/dagster/_core/definitions/load_asset_checks_from_modules.py @@ -4,11 +4,11 @@ from typing import Iterable, Optional, Sequence import dagster._check as check -from dagster._core.definitions.asset_checks import AssetChecksDefinition from dagster._core.definitions.asset_key import ( CoercibleToAssetKeyPrefix, check_opt_coercible_to_asset_key_prefix_param, ) +from dagster._core.definitions.assets import AssetsDefinition from dagster._core.definitions.load_assets_from_modules import ( LoadedAssetsList, find_modules_in_package, @@ -18,7 +18,7 @@ def load_asset_checks_from_modules( modules: Iterable[ModuleType], asset_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None, -) -> Sequence[AssetChecksDefinition]: +) -> Sequence[AssetsDefinition]: """Constructs a list of asset checks from the given modules. This is most often used in conjunction with a call to `load_assets_from_modules`. @@ -52,7 +52,7 @@ def load_asset_checks_from_modules( def load_asset_checks_from_current_module( asset_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None, -) -> Sequence[AssetChecksDefinition]: +) -> Sequence[AssetsDefinition]: """Constructs a list of asset checks from the module where this function is called. This is most often used in conjunction with a call to `load_assets_from_current_module`. @@ -79,7 +79,7 @@ def load_asset_checks_from_current_module( def load_asset_checks_from_package_module( package_module: ModuleType, asset_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None -) -> Sequence[AssetChecksDefinition]: +) -> Sequence[AssetsDefinition]: """Constructs a list of asset checks from all sub-modules of the given package module. This is most often used in conjunction with a call to `load_assets_from_package_module`. @@ -104,7 +104,7 @@ def load_asset_checks_from_package_module( def load_asset_checks_from_package_name( package_name: str, asset_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None -) -> Sequence[AssetChecksDefinition]: +) -> Sequence[AssetsDefinition]: """Constructs a list of asset checks from all sub-modules of the given package. This is most often used in conjunction with a call to `load_assets_from_package_name`. diff --git a/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py b/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py index 63cd3730bff43..bdec14601d014 100644 --- a/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py +++ b/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py @@ -446,15 +446,11 @@ def replace_keys_in_asset( key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values() }, ) - if isinstance(asset, AssetsDefinition) and has_only_asset_checks(asset): - updated_object = AssetChecksDefinition.create( - keys_by_input_name=updated_object.keys_by_input_name, - node_def=updated_object.op, - check_specs_by_output_name=updated_object.check_specs_by_output_name, - resource_defs=updated_object.resource_defs, - can_subset=updated_object.can_subset, - ) - return updated_object + return ( + updated_object.coerce_to_checks_def() + if has_only_asset_checks(updated_object) + else updated_object + ) class ResolvedAssetObjectList: @@ -564,15 +560,11 @@ def with_attributes( ).with_attributes( backfill_policy=backfill_policy, freshness_policy=freshness_policy ) - if isinstance(asset, AssetChecksDefinition): - new_asset = AssetChecksDefinition.create( - keys_by_input_name=new_asset.keys_by_input_name, - node_def=new_asset.op, - check_specs_by_output_name=new_asset.check_specs_by_output_name, - resource_defs=new_asset.resource_defs, - can_subset=new_asset.can_subset, - ) - return_list.append(new_asset) + return_list.append( + new_asset.coerce_to_checks_def() + if has_only_asset_checks(new_asset) + else new_asset + ) elif isinstance(asset, SourceAsset): return_list.append( asset.with_attributes(group_name=group_name if group_name else asset.group_name) diff --git a/python_modules/dagster/dagster_tests/definitions_tests/asset_check_tests/test_load_from_modules.py b/python_modules/dagster/dagster_tests/definitions_tests/asset_check_tests/test_load_from_modules.py index c299c488aae4a..34d62a99d8142 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/asset_check_tests/test_load_from_modules.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/asset_check_tests/test_load_from_modules.py @@ -1,6 +1,5 @@ import pytest from dagster import ( - AssetChecksDefinition, AssetKey, Definitions, asset_check, @@ -12,6 +11,7 @@ load_assets_from_package_module, load_assets_from_package_name, ) +from dagster._core.definitions.asset_checks import has_only_asset_checks from dagster_tests.definitions_tests.decorators_tests.test_asset_check_decorator import ( execute_assets_and_checks, @@ -24,7 +24,7 @@ def test_load_asset_checks_from_modules(): checks = load_asset_checks_from_modules([checks_module]) assert len(checks) == 1 - assert all(isinstance(check, AssetChecksDefinition) for check in checks) + assert all(has_only_asset_checks(check) for check in checks) asset_check_1_key = next(iter(asset_check_1.check_keys)) @@ -51,7 +51,7 @@ def test_load_asset_checks_from_modules_prefix(): checks = load_asset_checks_from_modules([checks_module], asset_key_prefix="foo") assert len(checks) == 1 - assert all(isinstance(check, AssetChecksDefinition) for check in checks) + assert all(has_only_asset_checks(check) for check in checks) check_key = next(iter(checks[0].check_keys)) assert check_key.asset_key == AssetKey(["foo", "asset_1"]) @@ -79,7 +79,7 @@ def check_in_current_module(): def test_load_asset_checks_from_current_module(): checks = load_asset_checks_from_current_module(asset_key_prefix="foo") assert len(checks) == 1 - assert all(isinstance(check, AssetChecksDefinition) for check in checks) + assert all(has_only_asset_checks(check) for check in checks) check_key = next(iter(checks[0].check_keys)) assert check_key.name == "check_in_current_module" assert check_key.asset_key == AssetKey(["foo", "asset_1"]) @@ -104,7 +104,7 @@ def test_load_asset_checks_from_package(load_fns): checks = checks_load_fn(checks_module, asset_key_prefix="foo") assert len(checks) == 2 - assert all(isinstance(check, AssetChecksDefinition) for check in checks) + assert all(has_only_asset_checks(check) for check in checks) check_key_0 = next(iter(checks[0].check_keys)) assert check_key_0.name == "asset_check_1" assert check_key_0.asset_key == AssetKey(["foo", "asset_1"])