Skip to content

Commit

Permalink
Include specs in asset module loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Dec 19, 2024
1 parent c517f5e commit bc06a99
Show file tree
Hide file tree
Showing 13 changed files with 171 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@

@patch_cereal_requests
def test_cereal():
assets, source_assets, _ = load_assets_from_modules([cereal])
assert materialize([*assets, *source_assets])
assets = load_assets_from_modules([cereal])
assert materialize(assets)
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@

@patch_cereal_requests
def test_serial_asset_graph():
assets, source_assets, _ = load_assets_from_modules([serial_asset_graph])
assert materialize([*assets, *source_assets])
assets = load_assets_from_modules([serial_asset_graph])
assert materialize(assets)
16 changes: 10 additions & 6 deletions python_modules/dagster-test/dagster_test/toys/repo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from typing import Sequence, cast

from dagster import ExperimentalWarning
from dagster._core.definitions.assets import AssetsDefinition
from dagster._time import get_current_timestamp

# squelch experimental warnings since we often include experimental things in toys for development
Expand Down Expand Up @@ -164,19 +166,21 @@ def column_schema_repository():
def table_metadata_repository():
from dagster_test.toys import table_metadata

return load_assets_from_modules([table_metadata])
return cast(Sequence[AssetsDefinition], load_assets_from_modules([table_metadata]))


@repository
def long_asset_keys_repository():
from dagster_test.toys import long_asset_keys

return load_assets_from_modules([long_asset_keys])
return cast(Sequence[AssetsDefinition], load_assets_from_modules([long_asset_keys]))


@repository # pyright: ignore[reportArgumentType]
@repository
def big_honkin_assets_repository():
return [load_assets_from_modules([big_honkin_asset_graph_module])]
return cast(
Sequence[AssetsDefinition], [load_assets_from_modules([big_honkin_asset_graph_module])]
)


@repository
Expand Down Expand Up @@ -208,11 +212,11 @@ def assets_with_sensors_repository():
def conditional_assets_repository():
from dagster_test.toys import conditional_assets

return load_assets_from_modules([conditional_assets])
return cast(Sequence[AssetsDefinition], load_assets_from_modules([conditional_assets]))


@repository
def data_versions_repository():
from dagster_test.toys import data_versions

return load_assets_from_modules([data_versions])
return cast(Sequence[AssetsDefinition], load_assets_from_modules([data_versions]))
3 changes: 1 addition & 2 deletions python_modules/dagster/dagster/_core/code_pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ def _load_target_from_module(module: ModuleType, fn_name: str, error_suffix: str
if fn_name == LOAD_ALL_ASSETS:
# LOAD_ALL_ASSETS is a special symbol that's returned when, instead of loading a particular
# attribute, we should load all the assets in the module.
module_assets, module_source_assets, _ = load_assets_from_modules([module])
return [*module_assets, *module_source_assets]
return load_assets_from_modules([module])
else:
if not hasattr(module, fn_name):
raise DagsterInvariantViolationError(f"{fn_name} not found {error_suffix}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def find_subclasses_in_module(
yield value


LoadableAssetTypes = Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]
KeyScopedAssetObjects = (AssetsDefinition, SourceAsset)
LoadableAssetTypes = Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]
KeyScopedAssetObjects = (AssetsDefinition, AssetSpec, SourceAsset)


class LoadedAssetsList:
Expand All @@ -81,7 +81,8 @@ def from_modules(cls, modules: Iterable[ModuleType]) -> "LoadedAssetsList":
{
module.__name__: list(
find_objects_in_module_of_types(
module, (AssetsDefinition, SourceAsset, CacheableAssetsDefinition)
module,
(AssetsDefinition, SourceAsset, CacheableAssetsDefinition, AssetSpec),
)
)
for module in modules
Expand Down Expand Up @@ -175,7 +176,7 @@ def _inner(spec: AssetSpec) -> AssetSpec:


def key_iterator(
asset: Union[AssetsDefinition, SourceAsset], included_targeted_keys: bool = False
asset: Union[AssetsDefinition, SourceAsset, AssetSpec], included_targeted_keys: bool = False
) -> Iterator[AssetKey]:
return (
iter(
Expand Down Expand Up @@ -203,7 +204,8 @@ def load_assets_from_modules(
automation_condition: Optional[AutomationCondition] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]:
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
"""Constructs a list of assets and source assets from the given modules.
Args:
Expand All @@ -226,6 +228,15 @@ def load_assets_from_modules(
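Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
A list containing the assets, source assets, and cacheable assets defined in the given
modules, plus any asset specs when ``include_specs`` is True.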
"""

def _asset_filter(asset: LoadableAssetTypes) -> bool:
    """Decide whether a loaded object belongs in the loader's result.

    Objects that are neither an ``AssetsDefinition`` nor an ``AssetSpec`` are always
    kept. An ``AssetsDefinition`` is kept unless it contains only asset checks, and
    an ``AssetSpec`` is kept only when the caller passed ``include_specs=True``.
    """
    if not isinstance(asset, (AssetsDefinition, AssetSpec)):
        return True
    if isinstance(asset, AssetsDefinition):
        # We don't load asset checks with asset module loaders.
        return not has_only_asset_checks(asset)
    # Only AssetSpec remains at this point; specs are opt-in.
    return include_specs

return (
LoadedAssetsList.from_modules(modules)
.to_post_load()
Expand All @@ -245,7 +256,7 @@ def load_assets_from_modules(
backfill_policy, "backfill_policy", BackfillPolicy
),
)
.assets_only
.get_objects(_asset_filter)
)


Expand All @@ -258,7 +269,8 @@ def load_assets_from_current_module(
automation_condition: Optional[AutomationCondition] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]:
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
"""Constructs a list of assets, source assets, and cacheable assets from the module where
this function is called.
Expand Down Expand Up @@ -296,6 +308,7 @@ def load_assets_from_current_module(
),
backfill_policy=backfill_policy,
source_key_prefix=source_key_prefix,
include_specs=include_specs,
)


Expand All @@ -309,6 +322,7 @@ def load_assets_from_package_module(
automation_condition: Optional[AutomationCondition] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
include_specs: bool = False,
) -> Sequence[LoadableAssetTypes]:
"""Constructs a list of assets and source assets that includes all asset
definitions, source assets, and cacheable assets in all sub-modules of the given package module.
Expand Down Expand Up @@ -344,6 +358,7 @@ def load_assets_from_package_module(
automation_condition=automation_condition,
backfill_policy=backfill_policy,
source_key_prefix=source_key_prefix,
include_specs=include_specs,
)


Expand All @@ -356,7 +371,8 @@ def load_assets_from_package_name(
auto_materialize_policy: Optional[AutoMaterializePolicy] = None,
backfill_policy: Optional[BackfillPolicy] = None,
source_key_prefix: Optional[CoercibleToAssetKeyPrefix] = None,
) -> Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]]:
include_specs: bool = False,
) -> Sequence[Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]]:
"""Constructs a list of assets, source assets, and cacheable assets that includes all asset
definitions and source assets in all sub-modules of the given package.
Expand Down Expand Up @@ -389,6 +405,7 @@ def load_assets_from_package_name(
auto_materialize_policy=auto_materialize_policy,
backfill_policy=backfill_policy,
source_key_prefix=source_key_prefix,
include_specs=include_specs,
)


Expand All @@ -410,11 +427,17 @@ def find_modules_in_package(package_module: ModuleType) -> Iterable[ModuleType]:


def replace_keys_in_asset(
asset: Union[AssetsDefinition, SourceAsset],
asset: Union[AssetsDefinition, AssetSpec, SourceAsset],
key_replacements: Mapping[AssetKey, AssetKey],
) -> Union[AssetsDefinition, SourceAsset]:
updated_object = (
asset.with_attributes(
) -> Union[AssetsDefinition, AssetSpec, SourceAsset]:
if isinstance(asset, SourceAsset):
return asset.with_attributes(key=key_replacements.get(asset.key, asset.key))
if isinstance(asset, AssetSpec):
return asset.replace_attributes(
key=key_replacements.get(asset.key, asset.key),
)
else:
updated_object = asset.with_attributes(
output_asset_key_replacements={
key: key_replacements.get(key, key)
for key in key_iterator(asset, included_targeted_keys=True)
Expand All @@ -423,34 +446,31 @@ def replace_keys_in_asset(
key: key_replacements.get(key, key) for key in asset.keys_by_input_name.values()
},
)
if isinstance(asset, AssetsDefinition)
else asset.with_attributes(key=key_replacements.get(asset.key, asset.key))
)
if isinstance(asset, AssetChecksDefinition):
updated_object = cast(AssetsDefinition, updated_object)
updated_object = AssetChecksDefinition.create(
keys_by_input_name=updated_object.keys_by_input_name,
node_def=updated_object.op,
check_specs_by_output_name=updated_object.check_specs_by_output_name,
resource_defs=updated_object.resource_defs,
can_subset=updated_object.can_subset,
)
return updated_object
if isinstance(asset, AssetsDefinition) and has_only_asset_checks(asset):
updated_object = AssetChecksDefinition.create(
keys_by_input_name=updated_object.keys_by_input_name,
node_def=updated_object.op,
check_specs_by_output_name=updated_object.check_specs_by_output_name,
resource_defs=updated_object.resource_defs,
can_subset=updated_object.can_subset,
)
return updated_object


class ResolvedAssetObjectList:
def __init__(
self,
loaded_objects: Sequence[Union[AssetsDefinition, SourceAsset, CacheableAssetsDefinition]],
loaded_objects: Sequence[LoadableAssetTypes],
):
self.loaded_objects = loaded_objects

@cached_property
def assets_defs(self) -> Sequence[AssetsDefinition]:
def assets_defs_and_specs(self) -> Sequence[Union[AssetsDefinition, AssetSpec]]:
return [
dagster_object
for dagster_object in self.loaded_objects
if isinstance(dagster_object, AssetsDefinition) and dagster_object.keys
if (isinstance(dagster_object, AssetsDefinition) and dagster_object.keys)
or isinstance(dagster_object, AssetSpec)
]

@cached_property
Expand All @@ -462,8 +482,10 @@ def checks_defs(self) -> Sequence[AssetChecksDefinition]:
]

@cached_property
def assets_and_checks_defs(self) -> Sequence[Union[AssetsDefinition, AssetChecksDefinition]]:
return [*self.assets_defs, *self.checks_defs]
def assets_defs_specs_and_checks_defs(
self,
) -> Sequence[Union[AssetsDefinition, AssetSpec, AssetChecksDefinition]]:
return [*self.assets_defs_and_specs, *self.checks_defs]

@cached_property
def source_assets(self) -> Sequence[SourceAsset]:
Expand All @@ -475,9 +497,10 @@ def cacheable_assets(self) -> Sequence[CacheableAssetsDefinition]:
asset for asset in self.loaded_objects if isinstance(asset, CacheableAssetsDefinition)
]

@cached_property
def assets_only(self) -> Sequence[LoadableAssetTypes]:
return [*self.source_assets, *self.assets_defs, *self.cacheable_assets]
def get_objects(
    self, filter_fn: Callable[[LoadableAssetTypes], bool]
) -> Sequence[LoadableAssetTypes]:
    """Return the loaded objects for which ``filter_fn`` is truthy, in their original order."""
    return list(filter(filter_fn, self.loaded_objects))

def assets_with_loadable_prefix(
self, key_prefix: CoercibleToAssetKeyPrefix
Expand All @@ -489,7 +512,7 @@ def assets_with_loadable_prefix(
result_list = []
all_asset_keys = {
key
for asset_object in self.assets_and_checks_defs
for asset_object in self.assets_defs_specs_and_checks_defs
for key in key_iterator(asset_object, included_targeted_keys=True)
}
key_replacements = {key: key.with_prefix(key_prefix) for key in all_asset_keys}
Expand Down Expand Up @@ -554,6 +577,10 @@ def with_attributes(
return_list.append(
asset.with_attributes(group_name=group_name if group_name else asset.group_name)
)
elif isinstance(asset, AssetSpec):
return_list.append(
_spec_mapper_disallow_group_override(group_name, automation_condition)(asset)
)
else:
return_list.append(
asset.with_attributes_for_all(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dagster._serdes import whitelist_for_serdes

if TYPE_CHECKING:
from dagster._core.definitions.assets import AssetsDefinition, SourceAsset
from dagster._core.definitions.assets import AssetsDefinition, AssetSpec, SourceAsset
from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition

DEFAULT_SOURCE_FILE_KEY = "asset_definition"
Expand Down Expand Up @@ -86,11 +86,11 @@ def namespace(cls) -> str:


def _with_code_source_single_definition(
assets_def: Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition"],
) -> Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition"]:
assets_def: Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition", "AssetSpec"],
) -> Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition", "AssetSpec"]:
from dagster._core.definitions.assets import AssetsDefinition

# SourceAsset doesn't have an op definition to point to - cacheable assets
# SourceAsset and AssetSpec don't have an op definition to point to - cacheable assets
# will be supported eventually but are a bit trickier
if not isinstance(assets_def, AssetsDefinition):
return assets_def
Expand Down Expand Up @@ -242,8 +242,8 @@ def convert_local_path_to_git_path(
def _convert_local_path_to_git_path_single_definition(
base_git_url: str,
file_path_mapping: FilePathMapping,
assets_def: Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition"],
) -> Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition"]:
assets_def: Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition", "AssetSpec"],
) -> Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition", "AssetSpec"]:
from dagster._core.definitions.assets import AssetsDefinition

# SourceAsset and AssetSpec don't have an op definition to point to - cacheable assets
Expand Down Expand Up @@ -293,11 +293,13 @@ def _build_gitlab_url(url: str, branch: str) -> str:

@experimental
def link_code_references_to_git(
assets_defs: Sequence[Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition"]],
assets_defs: Sequence[
Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition", "AssetSpec"]
],
git_url: str,
git_branch: str,
file_path_mapping: FilePathMapping,
) -> Sequence[Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition"]]:
) -> Sequence[Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition", "AssetSpec"]]:
"""Wrapper function which converts local file path code references to source control URLs
based on the provided source control URL and branch.
Expand Down Expand Up @@ -353,8 +355,10 @@ def link_code_references_to_git(

@experimental
def with_source_code_references(
assets_defs: Sequence[Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition"]],
) -> Sequence[Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition"]]:
assets_defs: Sequence[
Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition", "AssetSpec"]
],
) -> Sequence[Union["AssetsDefinition", "SourceAsset", "CacheableAssetsDefinition", "AssetSpec"]]:
"""Wrapper function which attaches local code reference metadata to the provided asset definitions.
This points to the filepath and line number where the asset body is defined.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def loadable_targets_from_loaded_module(module: ModuleType) -> Sequence[Loadable
)
)

module_assets, module_source_assets, _ = load_assets_from_modules([module])
if len(module_assets) > 0 or len(module_source_assets) > 0:
return [LoadableTarget(LOAD_ALL_ASSETS, [*module_assets, *module_source_assets])]
assets = load_assets_from_modules([module])
if len(assets) > 0:
return [LoadableTarget(LOAD_ALL_ASSETS, assets)]

raise DagsterInvariantViolationError(
"No Definitions, RepositoryDefinition, Job, Pipeline, Graph, or AssetsDefinition found in "
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dagster import AssetKey, SourceAsset, asset
from dagster._core.definitions.asset_spec import AssetSpec


@asset
Expand Down Expand Up @@ -29,4 +30,7 @@ def make_list_of_source_assets():
return [buddy_holly, jerry_lee_lewis]


top_level_spec = AssetSpec("top_level_spec")


list_of_assets_and_source_assets = [*make_list_of_assets(), *make_list_of_source_assets()]
Loading

0 comments on commit bc06a99

Please sign in to comment.