From c63cd60a563215d6ae1243e997b2f4d5ec65a548 Mon Sep 17 00:00:00 2001 From: Chris DeCarolis Date: Fri, 13 Dec 2024 12:24:09 -0800 Subject: [PATCH] Simplify module asset loader code path --- .../definitions/load_assets_from_modules.py | 142 ++++++++++++------ .../test_assets_from_modules.py | 2 +- 2 files changed, 100 insertions(+), 44 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py b/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py index a4c442476e861..a71f7d5a890fd 100644 --- a/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py +++ b/python_modules/dagster/dagster/_core/definitions/load_assets_from_modules.py @@ -1,8 +1,10 @@ import inspect import pkgutil +from collections import defaultdict +from functools import cached_property from importlib import import_module from types import ModuleType -from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union, cast +from typing import Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union import dagster._check as check from dagster._core.definitions.asset_key import ( @@ -47,6 +49,96 @@ def find_subclasses_in_module( yield value +LoadableAssetTypes = Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition] +KeyScopedAssetObjects = (AssetsDefinition, SourceAsset) + + +class LoadedAssetsList: + def __init__( + self, + assets_per_module: Mapping[str, Sequence[LoadableAssetTypes]], + ): + self.assets_per_module = assets_per_module + self._do_collision_detection() + + @classmethod + def from_modules(cls, modules: Iterable[ModuleType]) -> "LoadedAssetsList": + return cls( + { + module.__name__: list( + find_objects_in_module_of_types( + module, (AssetsDefinition, SourceAsset, CacheableAssetsDefinition) + ) + ) + for module in modules + }, + ) + + @cached_property + def flat_object_list(self) -> Sequence[LoadableAssetTypes]: + return [ + asset_object for objects in self.assets_per_module.values() for asset_object in objects + ] + + @cached_property + def objects_by_id(self) -> Dict[int, LoadableAssetTypes]: + return {id(asset_object): asset_object for asset_object in self.flat_object_list} + + @cached_property + def deduped_objects(self) -> Sequence[LoadableAssetTypes]: + return list(self.objects_by_id.values()) + + @cached_property + def module_name_by_id(self) -> Dict[int, str]: + return { + id(asset_object): module_name + for module_name, objects in self.assets_per_module.items() + for asset_object in objects + } + + @cached_property + def objects_by_key(self) -> Mapping[AssetKey, Sequence[Union[SourceAsset, AssetsDefinition]]]: + objects_by_key = defaultdict(list) + for asset_object in self.flat_object_list: + if not isinstance(asset_object, KeyScopedAssetObjects): + continue + for key in key_iterator(asset_object): + objects_by_key[key].append(asset_object) + return objects_by_key + + @cached_property + def sources(self) -> Sequence[SourceAsset]: + return [asset for asset in self.deduped_objects if isinstance(asset, SourceAsset)] + + @cached_property + def assets_defs(self) -> Sequence[AssetsDefinition]: + return [ + asset + for asset in self.deduped_objects + if isinstance(asset, AssetsDefinition) and asset.keys + ] + + @cached_property + def cacheable_assets_defs(self) -> Sequence[CacheableAssetsDefinition]: + return [ + asset for asset in self.deduped_objects if isinstance(asset, CacheableAssetsDefinition) + ] + + def _do_collision_detection(self) -> None: + for key, asset_objects in self.objects_by_key.items(): + # If there is more than one asset_object in the list for a given key, and the objects do not refer to the same asset_object in memory, we have a collision. + num_distinct_objects_for_key = len( + set(id(asset_object) for asset_object in asset_objects) + ) + if len(asset_objects) > 1 and num_distinct_objects_for_key > 1: + asset_objects_str = ", ".join( + set(self.module_name_by_id[id(asset_object)] for asset_object in asset_objects) + ) + raise DagsterInvalidDefinitionError( + f"Asset key {key.to_user_string()} is defined multiple times. Definitions found in modules: {asset_objects_str}." + ) + + def assets_from_modules( modules: Iterable[ModuleType], ) -> Tuple[Sequence[AssetsDefinition], Sequence[SourceAsset], Sequence[CacheableAssetsDefinition]]: @@ -63,48 +155,12 @@ def assets_from_modules( A tuple containing a list of assets, a list of source assets, and a list of cacheable assets defined in the given modules. """ - asset_ids: Set[int] = set() - asset_keys: Dict[AssetKey, ModuleType] = dict() - source_assets: List[SourceAsset] = [] - cacheable_assets: List[CacheableAssetsDefinition] = [] - assets: Dict[AssetKey, AssetsDefinition] = {} - for module in modules: - for asset in find_objects_in_module_of_types( - module, (AssetsDefinition, SourceAsset, CacheableAssetsDefinition) - ): - asset = cast(Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition], asset) - if id(asset) not in asset_ids: - asset_ids.add(id(asset)) - if isinstance(asset, CacheableAssetsDefinition): - cacheable_assets.append(asset) - else: - keys = asset.keys if isinstance(asset, AssetsDefinition) else [asset.key] - for key in keys: - if key in asset_keys: - modules_str = ", ".join( - set([asset_keys[key].__name__, module.__name__]) - ) - error_str = ( - f"Asset key {key} is defined multiple times. Definitions found in" - f" modules: {modules_str}. " - ) - - if key in assets and isinstance(asset, AssetsDefinition): - if assets[key].node_def == asset.node_def: - error_str += ( - "One possible cause of this bug is a call to with_resources" - " outside of a repository definition, causing a duplicate" - " asset definition." - ) - - raise DagsterInvalidDefinitionError(error_str) - else: - asset_keys[key] = module - if isinstance(asset, AssetsDefinition): - assets[key] = asset - if isinstance(asset, SourceAsset): - source_assets.append(asset) - return list(set(assets.values())), source_assets, cacheable_assets + assets_list = LoadedAssetsList.from_modules(modules) + return assets_list.assets_defs, assets_list.sources, assets_list.cacheable_assets_defs + + +def key_iterator(asset: Union[AssetsDefinition, SourceAsset]) -> Iterator[AssetKey]: + return iter(asset.keys) if isinstance(asset, AssetsDefinition) else iter([asset.key]) def load_assets_from_modules( diff --git a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_from_modules.py b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_from_modules.py index 7fc579483ac2e..6ff6b65131cb4 100644 --- a/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_from_modules.py +++ b/python_modules/dagster/dagster_tests/asset_defs_tests/test_assets_from_modules.py @@ -142,7 +142,7 @@ def little_richard(): with pytest.raises( DagsterInvalidDefinitionError, match=re.escape( - "Asset key AssetKey(['little_richard']) is defined multiple times. " + "Asset key little_richard is defined multiple times. " "Definitions found in modules: dagster_tests.asset_defs_tests.asset_package." ), ):