From 2b6352042d6573885d21d0f4d4be4a71c194e8ac Mon Sep 17 00:00:00 2001 From: Chris DeCarolis Date: Mon, 16 Dec 2024 12:03:57 -0800 Subject: [PATCH] Include specs in asset module loaders --- .../definitions/load_assets_from_modules.py | 94 ++++++++++++------- .../asset_package/__init__.py | 4 + .../asset_package/module_with_assets.py | 4 + .../test_assets_from_modules.py | 47 +++++++++- 4 files changed, 113 insertions(+), 36 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py b/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py index 60560af0ae43a..a021e47af506c 100644 --- a/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py +++ b/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py @@ -63,8 +63,8 @@ def find_subclasses_in_module( yield value -LoadableAssetTypes = Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition] -KeyScopedAssetObjects = (AssetsDefinition, SourceAsset) +LoadableAssetTypes = Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition] +KeyScopedAssetObjects = (AssetsDefinition, AssetSpec, SourceAsset) class LoadedAssetsList: @@ -81,7 +81,8 @@ def from_modules(cls, modules: Iterable[ModuleType]) -> "LoadedAssetsList": { module.__name__: list( find_objects_in_module_of_types( - module, (AssetsDefinition, SourceAsset, CacheableAssetsDefinition) + module, + (AssetsDefinition, SourceAsset, CacheableAssetsDefinition, AssetSpec), ) ) for module in modules @@ -175,7 +176,7 @@ def _inner(spec: AssetSpec) -> AssetSpec: def key_iterator( - asset: Union[AssetsDefinition, SourceAsset], included_targeted_keys: bool = False + asset: Union[AssetsDefinition, SourceAsset, AssetSpec], included_targeted_keys: bool = False ) -> Iterator[AssetKey]: return ( iter( @@ -203,7 +204,8 @@ def load_assets_from_modules( automation_condition: Optional[AutomationCondition] = None, backfill_policy: Optional[BackfillPolicy] = None, source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None, -) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]: + include_specs: bool = False, +) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]: """Constructs a list of assets and source assets from the given modules. Args: @@ -226,6 +228,14 @@ def load_assets_from_modules( Sequence[Union[AssetsDefinition, SourceAsset]]: A list containing assets and source assets defined in the given modules. """ + + def _asset_filter(asset: LoadableAssetTypes) -> bool: + if isinstance(asset, AssetsDefinition) and not has_only_asset_checks(asset): + return True + if isinstance(asset, AssetSpec): + return include_specs + return True + return ( LoadedAssetsList.from_modules(modules) .to_post_load() @@ -245,7 +255,7 @@ def load_assets_from_modules( backfill_policy, "backfill_policy", BackfillPolicy ), ) - .assets_only + .get_objects(_asset_filter) ) @@ -258,7 +268,8 @@ def load_assets_from_current_module( automation_condition: Optional[AutomationCondition] = None, backfill_policy: Optional[BackfillPolicy] = None, source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None, -) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]: + include_specs: bool = False, +) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]: """Constructs a list of assets, source assets, and cacheable assets from the module where this function is called. @@ -296,6 +307,7 @@ def load_assets_from_current_module( ), backfill_policy=backfill_policy, source_key_prefix=source_key_prefix, + include_specs=include_specs, ) @@ -309,6 +321,7 @@ def load_assets_from_package_module( automation_condition: Optional[AutomationCondition] = None, backfill_policy: Optional[BackfillPolicy] = None, source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None, + include_specs: bool = False, ) -> Sequence[LoadableAssetTypes]: """Constructs a list of assets and source assets that includes all asset definitions, source assets, and cacheable assets in all sub-modules of the given package module. @@ -344,6 +357,7 @@ def load_assets_from_package_module( automation_condition=automation_condition, backfill_policy=backfill_policy, source_key_prefix=source_key_prefix, + include_specs=include_specs, ) @@ -356,7 +370,8 @@ def load_assets_from_package_name( auto_materialize_policy: Optional[AutoMaterializePolicy] = None, backfill_policy: Optional[BackfillPolicy] = None, source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None, -) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]: + include_specs: bool = False, +) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]: """Constructs a list of assets, source assets, and cacheable assets that includes all asset definitions and source assets in all sub-modules of the given package. @@ -389,6 +404,7 @@ def load_assets_from_package_name( auto_materialize_policy=auto_materialize_policy, backfill_policy=backfill_policy, source_key_prefix=source_key_prefix, + include_specs=include_specs, ) @@ -410,11 +426,17 @@ def find_modules_in_package(package_module: ModuleType) -> Iterable[ModuleType]: def replace_keys_in_asset( - asset: Union[AssetsDefinition, SourceAsset], + asset: Union[AssetsDefinition, AssetSpec, SourceAsset], key_replacements: Mapping[AssetKey, AssetKey], -) -> Union[AssetsDefinition, SourceAsset]: - updated_object = ( - asset.with_attributes( +) -> Union[AssetsDefinition, AssetSpec, SourceAsset]: + if isinstance(asset, SourceAsset): + return asset.with_attributes(key=key_replacements.get(asset.key, asset.key)) + if isinstance(asset, AssetSpec): + return asset.replace_attributes( + key=key_replacements.get(asset.key, asset.key), + ) + else: + updated_object = asset.with_attributes( output_asset_key_replacements={ key: key_replacements.get(key, key) for key in key_iterator(asset, included_targeted_keys=True) @@ -423,34 +445,31 @@ def replace_keys_in_asset( key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values() }, ) - if isinstance(asset, AssetsDefinition) - else asset.with_attributes(key=key_replacements.get(asset.key, asset.key)) - ) - if isinstance(asset, AssetChecksDefinition): - updated_object = cast(AssetsDefinition, updated_object) - updated_object = AssetChecksDefinition.create( - keys_by_input_name=updated_object.keys_by_input_name, - node_def=updated_object.op, - check_specs_by_output_name=updated_object.check_specs_by_output_name, - resource_defs=updated_object.resource_defs, - can_subset=updated_object.can_subset, - ) - return updated_object + if isinstance(asset, AssetsDefinition) and has_only_asset_checks(asset): + updated_object = AssetChecksDefinition.create( + keys_by_input_name=updated_object.keys_by_input_name, + node_def=updated_object.op, + check_specs_by_output_name=updated_object.check_specs_by_output_name, + resource_defs=updated_object.resource_defs, + can_subset=updated_object.can_subset, + ) + return updated_object class ResolvedAssetObjectList: def __init__( self, - loaded_objects: Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]], + loaded_objects: Sequence[LoadableAssetTypes], ): self.loaded_objects = loaded_objects @cached_property - def assets_defs(self) -> Sequence[AssetsDefinition]: + def assets_defs_and_specs(self) -> Sequence[Union[AssetsDefinition, AssetSpec]]: return [ dagster_object for dagster_object in self.loaded_objects - if isinstance(dagster_object, AssetsDefinition) and dagster_object.keys + if (isinstance(dagster_object, AssetsDefinition) and dagster_object.keys) + or isinstance(dagster_object, AssetSpec) ] @cached_property @@ -462,8 +481,10 @@ def checks_defs(self) -> Sequence[AssetChecksDefinition]: ] @cached_property - def assets_and_checks_defs(self) -> Sequence[Union[AssetsDefinition, AssetChecksDefinition]]: - return [*self.assets_defs, *self.checks_defs] + def assets_defs_specs_and_checks_defs( + self, + ) -> Sequence[Union[AssetsDefinition, AssetSpec, AssetChecksDefinition]]: + return [*self.assets_defs_and_specs, *self.checks_defs] @cached_property def source_assets(self) -> Sequence[SourceAsset]: @@ -475,9 +496,10 @@ def cacheable_assets(self) -> Sequence[CacheableAssetsDefinition]: asset for asset in self.loaded_objects if isinstance(asset, CacheableAssetsDefinition) ] - @cached_property - def assets_only(self) -> Sequence[LoadableAssetTypes]: - return [*self.source_assets, *self.assets_defs, *self.cacheable_assets] + def get_objects( + self, filter_fn: Callable[[LoadableAssetTypes], bool] + ) -> Sequence[LoadableAssetTypes]: + return [asset for asset in self.loaded_objects if filter_fn(asset)] def assets_with_loadable_prefix( self, key_prefix: CoercibleToAssetKeyPrefix @@ -489,7 +511,7 @@ def assets_with_loadable_prefix( result_list = [] all_asset_keys = { key - for asset_object in self.assets_and_checks_defs + for asset_object in self.assets_defs_specs_and_checks_defs for key in key_iterator(asset_object, included_targeted_keys=True) } key_replacements = {key: key.with_prefix(key_prefix) for key in all_asset_keys} @@ -554,6 +576,10 @@ def with_attributes( return_list.append( asset.with_attributes(group_name=group_name if group_name else asset.group_name) ) + elif isinstance(asset, AssetSpec): + return_list.append( + _spec_mapper_disallow_group_override(group_name, automation_condition)(asset) + ) else: return_list.append( asset.with_attributes_for_all( diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_package/__init__.py b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_package/__init__.py index 9eaf0d44dad77..44830afc5c1fa 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_package/__init__.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_package/__init__.py @@ -1,4 +1,5 @@ from dagster import AssetKey, SourceAsset, asset +from dagster._core.definitions.asset_spec import AssetSpec @asset @@ -29,4 +30,7 @@ def make_list_of_source_assets(): return [buddy_holly, jerry_lee_lewis] +top_level_spec = AssetSpec("top_level_spec") + + list_of_assets_and_source_assets = [*make_list_of_assets(), *make_list_of_source_assets()] diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_package/module_with_assets.py b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_package/module_with_assets.py index 0cea67687a7f0..eaa8819f9fc57 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/asset_package/module_with_assets.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/asset_package/module_with_assets.py @@ -1,4 +1,5 @@ from dagster import AssetKey, SourceAsset, asset, graph_asset, op +from dagster._core.definitions.asset_spec import AssetSpec from dagster._core.definitions.metadata import ( CodeReferencesMetadataSet, CodeReferencesMetadataValue, @@ -41,3 +42,6 @@ def multiply_by_two(input_num): @graph_asset def graph_backed_asset(): return multiply_by_two(one()) + + +my_spec = AssetSpec("my_asset_spec") diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_from_modules.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_from_modules.py index ea7f2bc28d5d0..97d9549e615da 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_from_modules.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_from_modules.py @@ -14,6 +14,7 @@ load_assets_from_package_module, load_assets_from_package_name, ) +from dagster._core.definitions.asset_spec import AssetSpec from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition @@ -101,6 +102,15 @@ def test_load_assets_from_package_name(): assert assets_1 == assets_2 + assets_3 = load_assets_from_package_name(asset_package.__name__, include_specs=True) + assert len(assets_3) == 13 + + assert next( + iter( + a for a in assets_3 if isinstance(a, AssetSpec) and a.key == AssetKey("top_level_spec") + ) + ) + def test_load_assets_from_package_module(): from dagster_tests.asset_defs_tests import asset_package @@ -117,6 +127,15 @@ def test_load_assets_from_package_module(): assert assets_1 == assets_2 + assets_3 = load_assets_from_package_name(asset_package.__name__, include_specs=True) + assert len(assets_3) == 13 + + assert next( + iter( + a for a in assets_3 if isinstance(a, AssetSpec) and a.key == AssetKey("top_level_spec") + ) + ) + def test_load_assets_from_modules(monkeypatch): from dagster_tests.asset_defs_tests import asset_package @@ -145,6 +164,19 @@ def little_richard(): ): load_assets_from_modules([asset_package, module_with_assets]) + # Create an AssetsDefinition with an identical spec to that in the module + with monkeypatch.context() as m: + + @asset + def top_level_spec(): + pass + + m.setattr(asset_package, "top_level_spec_same_assets_def", top_level_spec, raising=False) + with pytest.raises( + DagsterInvalidDefinitionError, + ): + load_assets_from_modules([asset_package, module_with_assets], include_specs=True) + @asset(group_name="my_group") def asset_in_current_module(): @@ -153,12 +185,22 @@ def asset_in_current_module(): source_asset_in_current_module = SourceAsset(AssetKey("source_asset_in_current_module")) +spec_in_current_module = AssetSpec("spec_in_current_module") + def test_load_assets_from_current_module(): assets = load_assets_from_current_module() assets = [get_unique_asset_identifier(asset) for asset in assets] - assert assets == ["asset_in_current_module", AssetKey("source_asset_in_current_module")] + assert set(assets) == {"asset_in_current_module", AssetKey("source_asset_in_current_module")} assert len(assets) == 2 + assets = load_assets_from_current_module(include_specs=True) + assets = [get_unique_asset_identifier(asset) for asset in assets] + assert len(assets) == 3 + assert set(assets) == { + "asset_in_current_module", + AssetKey("source_asset_in_current_module"), + AssetKey("spec_in_current_module"), + } def test_load_assets_from_modules_with_group_name(): @@ -176,7 +218,8 @@ def test_load_assets_from_modules_with_group_name(): def test_respect_existing_groups(): assets = load_assets_from_current_module() - assert assets[0].group_names_by_key.get(AssetKey("asset_in_current_module")) == "my_group" # pyright: ignore[reportAttributeAccessIssue] + assets_def = next(iter(a for a in assets if isinstance(a, AssetsDefinition))) + assert assets_def.group_names_by_key.get(AssetKey("asset_in_current_module")) == "my_group" # pyright: ignore[reportAttributeAccessIssue] with pytest.raises(DagsterInvalidDefinitionError): load_assets_from_current_module(group_name="yay")