Skip to content

Commit

Permalink
Genericize object list utils
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Dec 19, 2024
1 parent b5065bf commit d4430a7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
CoercibleToAssetKeyPrefix,
check_opt_coercible_to_asset_key_prefix_param,
)
from dagster._core.definitions.module_loaders.object_list import LoadedAssetsList
from dagster._core.definitions.module_loaders.object_list import ModuleScopedDagsterObjects
from dagster._core.definitions.module_loaders.utils import find_modules_in_package


Expand All @@ -34,8 +34,8 @@ def load_asset_checks_from_modules(
asset_key_prefix, "asset_key_prefix"
)
return (
LoadedAssetsList.from_modules(modules)
.to_post_load()
ModuleScopedDagsterObjects.from_modules(modules)
.get_object_list()
.with_attributes(
key_prefix=asset_key_prefix,
source_key_prefix=None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from importlib import import_module
from types import ModuleType
from typing import Iterable, Iterator, Optional, Sequence, Tuple, Type, Union
from typing import Iterable, Iterator, Optional, Sequence, Tuple, Type, Union, cast

import dagster._check as check
from dagster._core.definitions.asset_checks import has_only_asset_checks
Expand All @@ -13,17 +13,17 @@
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.backfill_policy import BackfillPolicy
from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition
from dagster._core.definitions.declarative_automation.automation_condition import (
AutomationCondition,
)
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.module_loaders.object_list import LoadedAssetsList
from dagster._core.definitions.module_loaders.object_list import ModuleScopedDagsterObjects
from dagster._core.definitions.module_loaders.utils import (
LoadableAssetTypes,
LoadableAssetObject,
LoadableDagsterObject,
RuntimeAssetObjectTypes,
find_modules_in_package,
)
from dagster._core.definitions.source_asset import SourceAsset
from dagster._core.definitions.utils import resolve_automation_condition


Expand Down Expand Up @@ -62,7 +62,7 @@ def load_assets_from_modules(
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
) -> Sequence[LoadableAssetObject]:
"""Constructs a list of assets and source assets from the given modules.
Args:
Expand All @@ -86,17 +86,18 @@ def load_assets_from_modules(
A list containing assets and source assets defined in the given modules.
"""

def _asset_filter(asset: LoadableAssetTypes) -> bool:
if isinstance(asset, AssetsDefinition):
def _asset_filter(dagster_object: LoadableDagsterObject) -> bool:
if isinstance(dagster_object, AssetsDefinition):
# We don't load asset checks with asset module loaders.
return not has_only_asset_checks(asset)
if isinstance(asset, AssetSpec):
return not has_only_asset_checks(dagster_object)
if isinstance(dagster_object, AssetSpec):
return include_specs
return True
return isinstance(dagster_object, RuntimeAssetObjectTypes)

return (
LoadedAssetsList.from_modules(modules)
.to_post_load()
return cast(
Sequence[LoadableAssetObject],
ModuleScopedDagsterObjects.from_modules(modules)
.get_object_list()
.with_attributes(
key_prefix=check_opt_coercible_to_asset_key_prefix_param(key_prefix, "key_prefix"),
source_key_prefix=check_opt_coercible_to_asset_key_prefix_param(
Expand All @@ -113,7 +114,7 @@ def _asset_filter(asset: LoadableAssetTypes) -> bool:
backfill_policy, "backfill_policy", BackfillPolicy
),
)
.get_objects(_asset_filter)
.get_objects(_asset_filter),
)


Expand All @@ -127,7 +128,7 @@ def load_assets_from_current_module(
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
) -> Sequence[LoadableAssetObject]:
"""Constructs a list of assets, source assets, and cacheable assets from the module where
this function is called.
Expand Down Expand Up @@ -180,7 +181,7 @@ def load_assets_from_package_module(
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
include_specs: bool = False,
) -> Sequence[LoadableAssetTypes]:
) -> Sequence[LoadableAssetObject]:
"""Constructs a list of assets and source assets that includes all asset
definitions, source assets, and cacheable assets in all sub-modules of the given package module.
Expand Down Expand Up @@ -229,7 +230,7 @@ def load_assets_from_package_name(
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
) -> Sequence[LoadableAssetObject]:
"""Constructs a list of assets, source assets, and cacheable assets that includes all asset
definitions and source assets in all sub-modules of the given package.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
)
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.module_loaders.utils import (
KeyScopedAssetObjects,
LoadableAssetTypes,
LoadableDagsterObject,
RuntimeAssetObjectTypes,
RuntimeKeyScopedAssetObjectTypes,
find_objects_in_module_of_types,
key_iterator,
replace_keys_in_asset,
Expand All @@ -25,16 +26,16 @@
from dagster._core.errors import DagsterInvalidDefinitionError


class LoadedAssetsList:
class ModuleScopedDagsterObjects:
def __init__(
self,
assets_per_module: Mapping[str, Sequence[LoadableAssetTypes]],
objects_per_module: Mapping[str, Sequence[LoadableDagsterObject]],
):
self.assets_per_module = assets_per_module
self.objects_per_module = objects_per_module
self._do_collision_detection()

@classmethod
def from_modules(cls, modules: Iterable[ModuleType]) -> "LoadedAssetsList":
def from_modules(cls, modules: Iterable[ModuleType]) -> "ModuleScopedDagsterObjects":
return cls(
{
module.__name__: list(
Expand All @@ -48,17 +49,17 @@ def from_modules(cls, modules: Iterable[ModuleType]) -> "LoadedAssetsList":
)

@cached_property
def flat_object_list(self) -> Sequence[LoadableAssetTypes]:
def flat_object_list(self) -> Sequence[LoadableDagsterObject]:
return [
asset_object for objects in self.assets_per_module.values() for asset_object in objects
asset_object for objects in self.objects_per_module.values() for asset_object in objects
]

@cached_property
def objects_by_id(self) -> Dict[int, LoadableAssetTypes]:
def objects_by_id(self) -> Dict[int, LoadableDagsterObject]:
return {id(asset_object): asset_object for asset_object in self.flat_object_list}

@cached_property
def deduped_objects(self) -> Sequence[LoadableAssetTypes]:
def deduped_objects(self) -> Sequence[LoadableDagsterObject]:
return list(self.objects_by_id.values())

@cached_property
Expand All @@ -73,22 +74,24 @@ def source_assets(self) -> Sequence[SourceAsset]:
def module_name_by_id(self) -> Dict[int, str]:
return {
id(asset_object): module_name
for module_name, objects in self.assets_per_module.items()
for module_name, objects in self.objects_per_module.items()
for asset_object in objects
}

@cached_property
def objects_by_key(self) -> Mapping[AssetKey, Sequence[Union[SourceAsset, AssetsDefinition]]]:
def asset_objects_by_key(
self,
) -> Mapping[AssetKey, Sequence[Union[SourceAsset, AssetSpec, AssetsDefinition]]]:
objects_by_key = defaultdict(list)
for asset_object in self.flat_object_list:
if not isinstance(asset_object, KeyScopedAssetObjects):
if not isinstance(asset_object, RuntimeKeyScopedAssetObjectTypes):
continue
for key in key_iterator(asset_object):
objects_by_key[key].append(asset_object)
return objects_by_key

def _do_collision_detection(self) -> None:
for key, asset_objects in self.objects_by_key.items():
for key, asset_objects in self.asset_objects_by_key.items():
# If there is more than one asset_object in the list for a given key, and the objects do not refer to the same asset_object in memory, we have a collision.
num_distinct_objects_for_key = len(
set(id(asset_object) for asset_object in asset_objects)
Expand All @@ -101,14 +104,14 @@ def _do_collision_detection(self) -> None:
f"Asset key {key.to_user_string()} is defined multiple times. Definitions found in modules: {asset_objects_str}."
)

def to_post_load(self) -> "ResolvedAssetObjectList":
return ResolvedAssetObjectList(self.deduped_objects)
def get_object_list(self) -> "DagsterObjectsList":
return DagsterObjectsList(self.deduped_objects)


class ResolvedAssetObjectList:
class DagsterObjectsList:
def __init__(
self,
loaded_objects: Sequence[LoadableAssetTypes],
loaded_objects: Sequence[LoadableDagsterObject],
):
self.loaded_objects = loaded_objects

Expand Down Expand Up @@ -149,13 +152,15 @@ def cacheable_assets(self) -> Sequence[CacheableAssetsDefinition]:
]

def get_objects(
self, filter_fn: Callable[[LoadableAssetTypes], bool]
) -> Sequence[LoadableAssetTypes]:
return [asset for asset in self.loaded_objects if filter_fn(asset)]
self, filter_fn: Callable[[LoadableDagsterObject], bool]
) -> Sequence[LoadableDagsterObject]:
return [
dagster_object for dagster_object in self.loaded_objects if filter_fn(dagster_object)
]

def assets_with_loadable_prefix(
self, key_prefix: CoercibleToAssetKeyPrefix
) -> "ResolvedAssetObjectList":
) -> "DagsterObjectsList":
# There is a tricky edge case here where if a non-cacheable asset depends on a cacheable asset,
# and the assets are prefixed, the non-cacheable asset's dependency will not be prefixed since
# at prefix-time it is not known that its dependency is one of the cacheable assets.
Expand All @@ -174,34 +179,40 @@ def assets_with_loadable_prefix(
check_key_replacements = {
check_key: check_key.with_asset_key_prefix(key_prefix) for check_key in all_check_keys
}
for asset_object in self.loaded_objects:
if isinstance(asset_object, CacheableAssetsDefinition):
result_list.append(asset_object.with_prefix_for_all(key_prefix))
elif isinstance(asset_object, AssetsDefinition):
for dagster_object in self.loaded_objects:
if not isinstance(dagster_object, RuntimeAssetObjectTypes):
result_list.append(dagster_object)
if isinstance(dagster_object, CacheableAssetsDefinition):
result_list.append(dagster_object.with_prefix_for_all(key_prefix))
elif isinstance(dagster_object, AssetsDefinition):
result_list.append(
replace_keys_in_asset(asset_object, key_replacements, check_key_replacements)
replace_keys_in_asset(dagster_object, key_replacements, check_key_replacements)
)
else:
# We don't replace the key for SourceAssets.
result_list.append(asset_object)
return ResolvedAssetObjectList(result_list)
result_list.append(dagster_object)
return DagsterObjectsList(result_list)

def assets_with_source_prefix(
self, key_prefix: CoercibleToAssetKeyPrefix
) -> "ResolvedAssetObjectList":
) -> "DagsterObjectsList":
result_list = []
key_replacements = {
source_asset.key: source_asset.key.with_prefix(key_prefix)
for source_asset in self.source_assets
}
for asset_object in self.loaded_objects:
if isinstance(asset_object, KeyScopedAssetObjects):
for dagster_object in self.loaded_objects:
if not isinstance(dagster_object, RuntimeAssetObjectTypes):
result_list.append(dagster_object)
if isinstance(dagster_object, RuntimeKeyScopedAssetObjectTypes):
result_list.append(
replace_keys_in_asset(asset_object, key_replacements, check_key_replacements={})
replace_keys_in_asset(
dagster_object, key_replacements, check_key_replacements={}
)
)
else:
result_list.append(asset_object)
return ResolvedAssetObjectList(result_list)
result_list.append(dagster_object)
return DagsterObjectsList(result_list)

def with_attributes(
self,
Expand All @@ -211,15 +222,15 @@ def with_attributes(
freshness_policy: Optional[FreshnessPolicy],
automation_condition: Optional[AutomationCondition],
backfill_policy: Optional[BackfillPolicy],
) -> "ResolvedAssetObjectList":
assets_list = self.assets_with_loadable_prefix(key_prefix) if key_prefix else self
assets_list = (
assets_list.assets_with_source_prefix(source_key_prefix)
) -> "DagsterObjectsList":
dagster_object_list = self.assets_with_loadable_prefix(key_prefix) if key_prefix else self
dagster_object_list = (
dagster_object_list.assets_with_source_prefix(source_key_prefix)
if source_key_prefix
else assets_list
else dagster_object_list
)
return_list = []
for asset in assets_list.loaded_objects:
for asset in dagster_object_list.loaded_objects:
if isinstance(asset, AssetsDefinition):
new_asset = asset.map_asset_specs(
_spec_mapper_disallow_group_override(group_name, automation_condition)
Expand Down Expand Up @@ -250,7 +261,7 @@ def with_attributes(
backfill_policy=backfill_policy,
)
)
return ResolvedAssetObjectList(return_list)
return DagsterObjectsList(return_list)


def _spec_mapper_disallow_group_override(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition
from dagster._core.definitions.source_asset import SourceAsset

LoadableAssetTypes = Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]
KeyScopedAssetObjects = (AssetsDefinition, AssetSpec, SourceAsset)
LoadableAssetObject = Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]
LoadableDagsterObject = LoadableAssetObject # For now
RuntimeKeyScopedAssetObjectTypes = (AssetsDefinition, AssetSpec, SourceAsset)
RuntimeAssetObjectTypes = (AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition)
RuntimeDagsterObjectTypes = RuntimeAssetObjectTypes # For now


def find_objects_in_module_of_types(
Expand Down

0 comments on commit d4430a7

Please sign in to comment.