Skip to content

Commit

Permalink
Include specs in asset module loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Dec 18, 2024
1 parent 5be4d04 commit 2219a42
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def find_subclasses_in_module(
yield value


LoadableAssetTypes = Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]
KeyScopedAssetObjects = (AssetsDefinition, SourceAsset)
LoadableAssetTypes = Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]
KeyScopedAssetObjects = (AssetsDefinition, AssetSpec, SourceAsset)


class LoadedAssetsList:
Expand All @@ -81,7 +81,8 @@ def from_modules(cls, modules: Iterable[ModuleType]) -> "LoadedAssetsList":
{
module.__name__: list(
find_objects_in_module_of_types(
module, (AssetsDefinition, SourceAsset, CacheableAssetsDefinition)
module,
(AssetsDefinition, SourceAsset, CacheableAssetsDefinition, AssetSpec),
)
)
for module in modules
Expand Down Expand Up @@ -175,7 +176,7 @@ def _inner(spec: AssetSpec) -> AssetSpec:


def key_iterator(
asset: Union[AssetsDefinition, SourceAsset], included_targeted_keys: bool = False
asset: Union[AssetsDefinition, SourceAsset, AssetSpec], included_targeted_keys: bool = False
) -> Iterator[AssetKey]:
return (
iter(
Expand Down Expand Up @@ -203,7 +204,8 @@ def load_assets_from_modules(
automation_condition: Optional[AutomationCondition] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]:
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
"""Constructs a list of assets and source assets from the given modules.
Args:
Expand All @@ -226,6 +228,14 @@ def load_assets_from_modules(
Sequence[Union[AssetsDefinition, SourceAsset]]:
A list containing assets and source assets defined in the given modules.
"""

def _asset_filter(asset: LoadableAssetTypes) -> bool:
if isinstance(asset, AssetsDefinition) and not has_only_asset_checks(asset):
return True
if isinstance(asset, AssetSpec):
return include_specs
return True

return (
LoadedAssetsList.from_modules(modules)
.to_post_load()
Expand All @@ -245,7 +255,7 @@ def load_assets_from_modules(
backfill_policy, "backfill_policy", BackfillPolicy
),
)
.assets_only
.get_objects(_asset_filter)
)


Expand All @@ -258,7 +268,8 @@ def load_assets_from_current_module(
automation_condition: Optional[AutomationCondition] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]:
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
"""Constructs a list of assets, source assets, and cacheable assets from the module where
this function is called.
Expand Down Expand Up @@ -296,6 +307,7 @@ def load_assets_from_current_module(
),
backfill_policy=backfill_policy,
source_key_prefix=source_key_prefix,
include_specs=include_specs,
)


Expand All @@ -309,6 +321,7 @@ def load_assets_from_package_module(
automation_condition: Optional[AutomationCondition] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
include_specs: bool = False,
) -> Sequence[LoadableAssetTypes]:
"""Constructs a list of assets and source assets that includes all asset
definitions, source assets, and cacheable assets in all sub-modules of the given package module.
Expand Down Expand Up @@ -344,6 +357,7 @@ def load_assets_from_package_module(
automation_condition=automation_condition,
backfill_policy=backfill_policy,
source_key_prefix=source_key_prefix,
include_specs=include_specs,
)


Expand All @@ -356,7 +370,8 @@ def load_assets_from_package_name(
auto_materialize_policy: Optional[AutoMaterializePolicy] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]:
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
"""Constructs a list of assets, source assets, and cacheable assets that includes all asset
definitions and source assets in all sub-modules of the given package.
Expand Down Expand Up @@ -389,6 +404,7 @@ def load_assets_from_package_name(
auto_materialize_policy=auto_materialize_policy,
backfill_policy=backfill_policy,
source_key_prefix=source_key_prefix,
include_specs=include_specs,
)


Expand All @@ -410,11 +426,17 @@ def find_modules_in_package(package_module: ModuleType) -> Iterable[ModuleType]:


def replace_keys_in_asset(
asset: Union[AssetsDefinition, SourceAsset],
asset: Union[AssetsDefinition, AssetSpec, SourceAsset],
key_replacements: Mapping[AssetKey, AssetKey],
) -> Union[AssetsDefinition, SourceAsset]:
updated_object = (
asset.with_attributes(
) -> Union[AssetsDefinition, AssetSpec, SourceAsset]:
if isinstance(asset, SourceAsset):
return asset.with_attributes(key=key_replacements.get(asset.key, asset.key))
if isinstance(asset, AssetSpec):
return asset.replace_attributes(
key=key_replacements.get(asset.key, asset.key),
)
else:
updated_object = asset.with_attributes(
output_asset_key_replacements={
key: key_replacements.get(key, key)
for key in key_iterator(asset, included_targeted_keys=True)
Expand All @@ -423,34 +445,31 @@ def replace_keys_in_asset(
key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values()
},
)
if isinstance(asset, AssetsDefinition)
else asset.with_attributes(key=key_replacements.get(asset.key, asset.key))
)
if isinstance(asset, AssetChecksDefinition):
updated_object = cast(AssetsDefinition, updated_object)
updated_object = AssetChecksDefinition.create(
keys_by_input_name=updated_object.keys_by_input_name,
node_def=updated_object.op,
check_specs_by_output_name=updated_object.check_specs_by_output_name,
resource_defs=updated_object.resource_defs,
can_subset=updated_object.can_subset,
)
return updated_object
if isinstance(asset, AssetsDefinition) and has_only_asset_checks(asset):
updated_object = AssetChecksDefinition.create(
keys_by_input_name=updated_object.keys_by_input_name,
node_def=updated_object.op,
check_specs_by_output_name=updated_object.check_specs_by_output_name,
resource_defs=updated_object.resource_defs,
can_subset=updated_object.can_subset,
)
return updated_object


class ResolvedAssetObjectList:
def __init__(
self,
loaded_objects: Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]],
loaded_objects: Sequence[LoadableAssetTypes],
):
self.loaded_objects = loaded_objects

@cached_property
def assets_defs(self) -> Sequence[AssetsDefinition]:
def assets_defs_and_specs(self) -> Sequence[Union[AssetsDefinition, AssetSpec]]:
return [
dagster_object
for dagster_object in self.loaded_objects
if isinstance(dagster_object, AssetsDefinition) and dagster_object.keys
if (isinstance(dagster_object, AssetsDefinition) and dagster_object.keys)
or isinstance(dagster_object, AssetSpec)
]

@cached_property
Expand All @@ -462,8 +481,10 @@ def checks_defs(self) -> Sequence[AssetChecksDefinition]:
]

@cached_property
def assets_and_checks_defs(self) -> Sequence[Union[AssetsDefinition, AssetChecksDefinition]]:
return [*self.assets_defs, *self.checks_defs]
def assets_defs_specs_and_checks_defs(
self,
) -> Sequence[Union[AssetsDefinition, AssetSpec, AssetChecksDefinition]]:
return [*self.assets_defs_and_specs, *self.checks_defs]

@cached_property
def source_assets(self) -> Sequence[SourceAsset]:
Expand All @@ -475,9 +496,10 @@ def cacheable_assets(self) -> Sequence[CacheableAssetsDefinition]:
asset for asset in self.loaded_objects if isinstance(asset, CacheableAssetsDefinition)
]

@cached_property
def assets_only(self) -> Sequence[LoadableAssetTypes]:
return [*self.source_assets, *self.assets_defs, *self.cacheable_assets]
def get_objects(
self, filter_fn: Callable[[LoadableAssetTypes], bool]
) -> Sequence[LoadableAssetTypes]:
return [asset for asset in self.loaded_objects if filter_fn(asset)]

def assets_with_loadable_prefix(
self, key_prefix: CoercibleToAssetKeyPrefix
Expand All @@ -489,7 +511,7 @@ def assets_with_loadable_prefix(
result_list = []
all_asset_keys = {
key
for asset_object in self.assets_and_checks_defs
for asset_object in self.assets_defs_specs_and_checks_defs
for key in key_iterator(asset_object, included_targeted_keys=True)
}
key_replacements = {key: key.with_prefix(key_prefix) for key in all_asset_keys}
Expand Down Expand Up @@ -554,6 +576,10 @@ def with_attributes(
return_list.append(
asset.with_attributes(group_name=group_name if group_name else asset.group_name)
)
elif isinstance(asset, AssetSpec):
return_list.append(
_spec_mapper_disallow_group_override(group_name, automation_condition)(asset)
)
else:
return_list.append(
asset.with_attributes_for_all(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dagster import AssetKey, SourceAsset, asset
from dagster._core.definitions.asset_spec import AssetSpec


@asset
Expand Down Expand Up @@ -29,4 +30,7 @@ def make_list_of_source_assets():
return [buddy_holly, jerry_lee_lewis]


top_level_spec = AssetSpec("top_level_spec")


list_of_assets_and_source_assets = [*make_list_of_assets(), *make_list_of_source_assets()]
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dagster import AssetKey, SourceAsset, asset, graph_asset, op
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.metadata import (
CodeReferencesMetadataSet,
CodeReferencesMetadataValue,
Expand Down Expand Up @@ -41,3 +42,6 @@ def multiply_by_two(input_num):
@graph_asset
def graph_backed_asset():
return multiply_by_two(one())


my_spec = AssetSpec("my_asset_spec")
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
load_assets_from_package_module,
load_assets_from_package_name,
)
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition

Expand Down Expand Up @@ -101,6 +102,15 @@ def test_load_assets_from_package_name():

assert assets_1 == assets_2

assets_3 = load_assets_from_package_name(asset_package.__name__, include_specs=True)
assert len(assets_3) == 13

assert next(
iter(
a for a in assets_3 if isinstance(a, AssetSpec) and a.key == AssetKey("top_level_spec")
)
)


def test_load_assets_from_package_module():
from dagster_tests.asset_defs_tests import asset_package
Expand All @@ -117,6 +127,15 @@ def test_load_assets_from_package_module():

assert assets_1 == assets_2

assets_3 = load_assets_from_package_name(asset_package.__name__, include_specs=True)
assert len(assets_3) == 13

assert next(
iter(
a for a in assets_3 if isinstance(a, AssetSpec) and a.key == AssetKey("top_level_spec")
)
)


def test_load_assets_from_modules(monkeypatch):
from dagster_tests.asset_defs_tests import asset_package
Expand Down Expand Up @@ -145,6 +164,19 @@ def little_richard():
):
load_assets_from_modules([asset_package, module_with_assets])

# Create an AssetsDefinition with an identical spec to that in the module
with monkeypatch.context() as m:

@asset
def top_level_spec():
pass

m.setattr(asset_package, "top_level_spec_same_assets_def", top_level_spec, raising=False)
with pytest.raises(
DagsterInvalidDefinitionError,
):
load_assets_from_modules([asset_package, module_with_assets], include_specs=True)


@asset(group_name="my_group")
def asset_in_current_module():
Expand All @@ -153,12 +185,22 @@ def asset_in_current_module():

source_asset_in_current_module = SourceAsset(AssetKey("source_asset_in_current_module"))

spec_in_current_module = AssetSpec("spec_in_current_module")


def test_load_assets_from_current_module():
assets = load_assets_from_current_module()
assets = [get_unique_asset_identifier(asset) for asset in assets]
assert assets == ["asset_in_current_module", AssetKey("source_asset_in_current_module")]
assert set(assets) == {"asset_in_current_module", AssetKey("source_asset_in_current_module")}
assert len(assets) == 2
assets = load_assets_from_current_module(include_specs=True)
assets = [get_unique_asset_identifier(asset) for asset in assets]
assert len(assets) == 3
assert set(assets) == {
"asset_in_current_module",
AssetKey("source_asset_in_current_module"),
AssetKey("spec_in_current_module"),
}


def test_load_assets_from_modules_with_group_name():
Expand All @@ -176,7 +218,8 @@ def test_load_assets_from_modules_with_group_name():

def test_respect_existing_groups():
assets = load_assets_from_current_module()
assert assets[0].group_names_by_key.get(AssetKey("asset_in_current_module")) == "my_group" # pyright: ignore[reportAttributeAccessIssue]
assets_def = next(iter(a for a in assets if isinstance(a, AssetsDefinition)))
assert assets_def.group_names_by_key.get(AssetKey("asset_in_current_module")) == "my_group" # pyright: ignore[reportAttributeAccessIssue]

with pytest.raises(DagsterInvalidDefinitionError):
load_assets_from_current_module(group_name="yay")
Expand Down

0 comments on commit 2219a42

Please sign in to comment.