Skip to content

Commit

Permalink
Include specs in asset module loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Dec 18, 2024
1 parent cd99152 commit 0291fcc
Show file tree
Hide file tree
Showing 11 changed files with 147 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@

@patch_cereal_requests
def test_cereal():
assets, source_assets, _ = load_assets_from_modules([cereal])
assert materialize([*assets, *source_assets])
assets = load_assets_from_modules([cereal])
assert materialize(assets)
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@

@patch_cereal_requests
def test_serial_asset_graph():
assets, source_assets, _ = load_assets_from_modules([serial_asset_graph])
assert materialize([*assets, *source_assets])
assets = load_assets_from_modules([serial_asset_graph])
assert materialize(assets)
3 changes: 1 addition & 2 deletions python_modules/dagster/dagster/_core/code_pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ def _load_target_from_module(module: ModuleType, fn_name: str, error_suffix: str
if fn_name == LOAD_ALL_ASSETS:
# LOAD_ALL_ASSETS is a special symbol that's returned when, instead of loading a particular
# attribute, we should load all the assets in the module.
module_assets, module_source_assets, _ = load_assets_from_modules([module])
return [*module_assets, *module_source_assets]
return load_assets_from_modules([module])
else:
if not hasattr(module, fn_name):
raise DagsterInvariantViolationError(f"{fn_name} not found {error_suffix}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def find_subclasses_in_module(
yield value


LoadableAssetTypes = Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]
KeyScopedAssetObjects = (AssetsDefinition, SourceAsset)
LoadableAssetTypes = Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]
KeyScopedAssetObjects = (AssetsDefinition, AssetSpec, SourceAsset)


class LoadedAssetsList:
Expand All @@ -81,7 +81,8 @@ def from_modules(cls, modules: Iterable[ModuleType]) -> "LoadedAssetsList":
{
module.__name__: list(
find_objects_in_module_of_types(
module, (AssetsDefinition, SourceAsset, CacheableAssetsDefinition)
module,
(AssetsDefinition, SourceAsset, CacheableAssetsDefinition, AssetSpec),
)
)
for module in modules
Expand Down Expand Up @@ -175,7 +176,7 @@ def _inner(spec: AssetSpec) -> AssetSpec:


def key_iterator(
asset: Union[AssetsDefinition, SourceAsset], included_targeted_keys: bool = False
asset: Union[AssetsDefinition, SourceAsset, AssetSpec], included_targeted_keys: bool = False
) -> Iterator[AssetKey]:
return (
iter(
Expand Down Expand Up @@ -203,7 +204,8 @@ def load_assets_from_modules(
automation_condition: Optional[AutomationCondition] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]:
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
"""Constructs a list of assets and source assets from the given modules.
Args:
Expand All @@ -226,6 +228,15 @@ def load_assets_from_modules(
Sequence[Union[AssetsDefinition, SourceAsset]]:
A list containing assets and source assets defined in the given modules.
"""

def _asset_filter(asset: LoadableAssetTypes) -> bool:
if isinstance(asset, AssetsDefinition):
# We don't load asset checks with asset module loaders.
return not has_only_asset_checks(asset)
if isinstance(asset, AssetSpec):
return include_specs
return True

return (
LoadedAssetsList.from_modules(modules)
.to_post_load()
Expand All @@ -245,7 +256,7 @@ def load_assets_from_modules(
backfill_policy, "backfill_policy", BackfillPolicy
),
)
.assets_only
.get_objects(_asset_filter)
)


Expand All @@ -258,7 +269,8 @@ def load_assets_from_current_module(
automation_condition: Optional[AutomationCondition] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]:
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
"""Constructs a list of assets, source assets, and cacheable assets from the module where
this function is called.
Expand Down Expand Up @@ -296,6 +308,7 @@ def load_assets_from_current_module(
),
backfill_policy=backfill_policy,
source_key_prefix=source_key_prefix,
include_specs=include_specs,
)


Expand All @@ -309,6 +322,7 @@ def load_assets_from_package_module(
automation_condition: Optional[AutomationCondition] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
include_specs: bool = False,
) -> Sequence[LoadableAssetTypes]:
"""Constructs a list of assets and source assets that includes all asset
definitions, source assets, and cacheable assets in all sub-modules of the given package module.
Expand Down Expand Up @@ -344,6 +358,7 @@ def load_assets_from_package_module(
automation_condition=automation_condition,
backfill_policy=backfill_policy,
source_key_prefix=source_key_prefix,
include_specs=include_specs,
)


Expand All @@ -356,7 +371,8 @@ def load_assets_from_package_name(
auto_materialize_policy: Optional[AutoMaterializePolicy] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]:
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
"""Constructs a list of assets, source assets, and cacheable assets that includes all asset
definitions and source assets in all sub-modules of the given package.
Expand Down Expand Up @@ -389,6 +405,7 @@ def load_assets_from_package_name(
auto_materialize_policy=auto_materialize_policy,
backfill_policy=backfill_policy,
source_key_prefix=source_key_prefix,
include_specs=include_specs,
)


Expand All @@ -410,11 +427,17 @@ def find_modules_in_package(package_module: ModuleType) -> Iterable[ModuleType]:


def replace_keys_in_asset(
asset: Union[AssetsDefinition, SourceAsset],
asset: Union[AssetsDefinition, AssetSpec, SourceAsset],
key_replacements: Mapping[AssetKey, AssetKey],
) -> Union[AssetsDefinition, SourceAsset]:
updated_object = (
asset.with_attributes(
) -> Union[AssetsDefinition, AssetSpec, SourceAsset]:
if isinstance(asset, SourceAsset):
return asset.with_attributes(key=key_replacements.get(asset.key, asset.key))
if isinstance(asset, AssetSpec):
return asset.replace_attributes(
key=key_replacements.get(asset.key, asset.key),
)
else:
updated_object = asset.with_attributes(
output_asset_key_replacements={
key: key_replacements.get(key, key)
for key in key_iterator(asset, included_targeted_keys=True)
Expand All @@ -423,34 +446,31 @@ def replace_keys_in_asset(
key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values()
},
)
if isinstance(asset, AssetsDefinition)
else asset.with_attributes(key=key_replacements.get(asset.key, asset.key))
)
if isinstance(asset, AssetChecksDefinition):
updated_object = cast(AssetsDefinition, updated_object)
updated_object = AssetChecksDefinition.create(
keys_by_input_name=updated_object.keys_by_input_name,
node_def=updated_object.op,
check_specs_by_output_name=updated_object.check_specs_by_output_name,
resource_defs=updated_object.resource_defs,
can_subset=updated_object.can_subset,
)
return updated_object
if isinstance(asset, AssetsDefinition) and has_only_asset_checks(asset):
updated_object = AssetChecksDefinition.create(
keys_by_input_name=updated_object.keys_by_input_name,
node_def=updated_object.op,
check_specs_by_output_name=updated_object.check_specs_by_output_name,
resource_defs=updated_object.resource_defs,
can_subset=updated_object.can_subset,
)
return updated_object


class ResolvedAssetObjectList:
def __init__(
self,
loaded_objects: Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]],
loaded_objects: Sequence[LoadableAssetTypes],
):
self.loaded_objects = loaded_objects

@cached_property
def assets_defs(self) -> Sequence[AssetsDefinition]:
def assets_defs_and_specs(self) -> Sequence[Union[AssetsDefinition, AssetSpec]]:
return [
dagster_object
for dagster_object in self.loaded_objects
if isinstance(dagster_object, AssetsDefinition) and dagster_object.keys
if (isinstance(dagster_object, AssetsDefinition) and dagster_object.keys)
or isinstance(dagster_object, AssetSpec)
]

@cached_property
Expand All @@ -462,8 +482,10 @@ def checks_defs(self) -> Sequence[AssetChecksDefinition]:
]

@cached_property
def assets_and_checks_defs(self) -> Sequence[Union[AssetsDefinition, AssetChecksDefinition]]:
return [*self.assets_defs, *self.checks_defs]
def assets_defs_specs_and_checks_defs(
self,
) -> Sequence[Union[AssetsDefinition, AssetSpec, AssetChecksDefinition]]:
return [*self.assets_defs_and_specs, *self.checks_defs]

@cached_property
def source_assets(self) -> Sequence[SourceAsset]:
Expand All @@ -475,9 +497,10 @@ def cacheable_assets(self) -> Sequence[CacheableAssetsDefinition]:
asset for asset in self.loaded_objects if isinstance(asset, CacheableAssetsDefinition)
]

@cached_property
def assets_only(self) -> Sequence[LoadableAssetTypes]:
return [*self.source_assets, *self.assets_defs, *self.cacheable_assets]
def get_objects(
self, filter_fn: Callable[[LoadableAssetTypes], bool]
) -> Sequence[LoadableAssetTypes]:
return [asset for asset in self.loaded_objects if filter_fn(asset)]

def assets_with_loadable_prefix(
self, key_prefix: CoercibleToAssetKeyPrefix
Expand All @@ -489,7 +512,7 @@ def assets_with_loadable_prefix(
result_list = []
all_asset_keys = {
key
for asset_object in self.assets_and_checks_defs
for asset_object in self.assets_defs_specs_and_checks_defs
for key in key_iterator(asset_object, included_targeted_keys=True)
}
key_replacements = {key: key.with_prefix(key_prefix) for key in all_asset_keys}
Expand Down Expand Up @@ -554,6 +577,10 @@ def with_attributes(
return_list.append(
asset.with_attributes(group_name=group_name if group_name else asset.group_name)
)
elif isinstance(asset, AssetSpec):
return_list.append(
_spec_mapper_disallow_group_override(group_name, automation_condition)(asset)
)
else:
return_list.append(
asset.with_attributes_for_all(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def loadable_targets_from_loaded_module(module: ModuleType) -> Sequence[Loadable
)
)

module_assets, module_source_assets, _ = load_assets_from_modules([module])
if len(module_assets) > 0 or len(module_source_assets) > 0:
return [LoadableTarget(LOAD_ALL_ASSETS, [*module_assets, *module_source_assets])]
assets = load_assets_from_modules([module])
if len(assets) > 0:
return [LoadableTarget(LOAD_ALL_ASSETS, assets)]

raise DagsterInvariantViolationError(
"No Definitions, RepositoryDefinition, Job, Pipeline, Graph, or AssetsDefinition found in "
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dagster import AssetKey, SourceAsset, asset
from dagster._core.definitions.asset_spec import AssetSpec


@asset
Expand Down Expand Up @@ -29,4 +30,7 @@ def make_list_of_source_assets():
return [buddy_holly, jerry_lee_lewis]


top_level_spec = AssetSpec("top_level_spec")


list_of_assets_and_source_assets = [*make_list_of_assets(), *make_list_of_source_assets()]
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dagster import AssetKey, SourceAsset, asset, graph_asset, op
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.metadata import (
CodeReferencesMetadataSet,
CodeReferencesMetadataValue,
Expand Down Expand Up @@ -41,3 +42,6 @@ def multiply_by_two(input_num):
@graph_asset
def graph_backed_asset():
return multiply_by_two(one())


my_spec = AssetSpec("my_asset_spec")
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@

# {path to module}:{path to file relative to module root}:{line number}
EXPECTED_ORIGINS = {
"james_brown": DAGSTER_PACKAGE_PATH + PATH_IN_PACKAGE + "asset_package/__init__.py:12",
"james_brown": DAGSTER_PACKAGE_PATH + PATH_IN_PACKAGE + "asset_package/__init__.py:13",
"chuck_berry": (
DAGSTER_PACKAGE_PATH + PATH_IN_PACKAGE + "asset_package/module_with_assets.py:18"
DAGSTER_PACKAGE_PATH + PATH_IN_PACKAGE + "asset_package/module_with_assets.py:19"
),
"little_richard": (DAGSTER_PACKAGE_PATH + PATH_IN_PACKAGE + "asset_package/__init__.py:4"),
"fats_domino": DAGSTER_PACKAGE_PATH + PATH_IN_PACKAGE + "asset_package/__init__.py:16",
"little_richard": (DAGSTER_PACKAGE_PATH + PATH_IN_PACKAGE + "asset_package/__init__.py:5"),
"fats_domino": DAGSTER_PACKAGE_PATH + PATH_IN_PACKAGE + "asset_package/__init__.py:17",
"miles_davis": (
DAGSTER_PACKAGE_PATH
+ PATH_IN_PACKAGE
+ "asset_package/asset_subpackage/another_module_with_assets.py:6"
),
"graph_backed_asset": (
DAGSTER_PACKAGE_PATH + PATH_IN_PACKAGE + "asset_package/module_with_assets.py:41"
DAGSTER_PACKAGE_PATH + PATH_IN_PACKAGE + "asset_package/module_with_assets.py:42"
),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from dagster._core.definitions.decorators.asset_check_decorator import asset_check
from dagster._core.definitions.dependency import NodeHandle, NodeInvocation
from dagster._core.definitions.executor_definition import in_process_executor
from dagster._core.definitions.load_assets_from_modules import prefix_assets
from dagster._core.errors import DagsterInvalidSubsetError
from dagster._core.execution.api import execute_run_iterator
from dagster._core.snap import DependencyStructureIndex
Expand Down Expand Up @@ -2362,7 +2361,15 @@ def test_asset_group_build_subset_job(job_selection, expected_assets, use_multi,
all_assets = _get_assets_defs(use_multi=use_multi, allow_subset=use_multi)
# apply prefixes
for prefix in reversed(prefixes or []):
all_assets, _ = prefix_assets(all_assets, prefix, [], None)
all_assets = [
assets_def.with_attributes(
input_asset_key_replacements={
k: k.with_prefix(prefix) for k in assets_def.keys_by_input_name.values()
},
output_asset_key_replacements={k: k.with_prefix(prefix) for k in assets_def.keys},
)
for assets_def in all_assets
]

defs = Definitions(
# for these, if we have multi assets, we'll always allow them to be subset
Expand Down
Loading

0 comments on commit 0291fcc

Please sign in to comment.