diff --git a/python_modules/dagster/dagster/_core/definitions/module_loaders/load_defs_from_module.py b/python_modules/dagster/dagster/_core/definitions/module_loaders/load_defs_from_module.py new file mode 100644 index 0000000000000..53b8344bbd0f8 --- /dev/null +++ b/python_modules/dagster/dagster/_core/definitions/module_loaders/load_defs_from_module.py @@ -0,0 +1,24 @@ +from types import ModuleType +from typing import Optional + +from dagster._core.definitions.definitions_class import Definitions +from dagster._core.definitions.module_loaders.object_list import ModuleScopedDagsterObjects +from dagster._core.definitions.module_loaders.utils import ( + ExecutorObject, + LoggerDefinitionKeyMapping, + ResourceDefinitionMapping, +) + + +def load_definitions_from_module( + module: ModuleType, + resources: Optional[ResourceDefinitionMapping] = None, + loggers: Optional[LoggerDefinitionKeyMapping] = None, + executor: Optional[ExecutorObject] = None, +) -> Definitions: + return Definitions( + **ModuleScopedDagsterObjects.from_modules([module]).get_object_list().to_definitions_args(), + resources=resources, + loggers=loggers, + executor=executor, + ) diff --git a/python_modules/dagster/dagster/_core/definitions/module_loaders/object_list.py b/python_modules/dagster/dagster/_core/definitions/module_loaders/object_list.py index 56caf8de99e68..d503128e7e603 100644 --- a/python_modules/dagster/dagster/_core/definitions/module_loaders/object_list.py +++ b/python_modules/dagster/dagster/_core/definitions/module_loaders/object_list.py @@ -1,7 +1,7 @@ from collections import defaultdict from functools import cached_property from types import ModuleType -from typing import Callable, Dict, Iterable, Mapping, Optional, Sequence, Union, cast +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union, cast from dagster._core.definitions.asset_checks import AssetChecksDefinition, has_only_asset_checks from dagster._core.definitions.asset_key import AssetKey, CoercibleToAssetKeyPrefix @@ -15,6 +15,7 @@ from dagster._core.definitions.freshness_policy import FreshnessPolicy from dagster._core.definitions.module_loaders.utils import ( JobDefinitionObject, + LoadableAssetObject, LoadableDagsterObject, RuntimeAssetObjectTypes, RuntimeDagsterObjectTypes, @@ -208,6 +209,38 @@ def cacheable_assets(self) -> Sequence[CacheableAssetsDefinition]: asset for asset in self.loaded_objects if isinstance(asset, CacheableAssetsDefinition) ] + @cached_property + def sensors(self) -> Sequence[SensorDefinition]: + return [ + dagster_object + for dagster_object in self.loaded_objects + if isinstance(dagster_object, SensorDefinition) + ] + + @cached_property + def schedules(self) -> Sequence[ScheduleDefinitionObject]: + return [ + dagster_object + for dagster_object in self.loaded_objects + if isinstance(dagster_object, RuntimeScheduleObjectTypes) + ] + + @cached_property + def jobs(self) -> Sequence[JobDefinitionObject]: + return [ + dagster_object + for dagster_object in self.loaded_objects + if isinstance(dagster_object, RuntimeJobObjectTypes) + ] + + @cached_property + def assets(self) -> Sequence[LoadableAssetObject]: + return [ + *self.assets_defs_and_specs, + *self.source_assets, + *self.cacheable_assets, + ] + def get_objects( self, filter_fn: Callable[[LoadableDagsterObject], bool] ) -> Sequence[LoadableDagsterObject]: @@ -304,3 +337,12 @@ def with_attributes( else updated_object ) return DagsterObjectsList(return_list) + + def to_definitions_args(self) -> Mapping[str, Any]: + return { + "assets": self.assets, + "asset_checks": self.checks_defs, + "sensors": self.sensors, + "schedules": self.schedules, + "jobs": self.jobs, + } diff --git a/python_modules/dagster/dagster_tests/definitions_tests/module_loader_tests/test_module_loaders.py b/python_modules/dagster/dagster_tests/definitions_tests/module_loader_tests/test_module_loaders.py index ccdbd5492a678..3fee6021fab99 100644 --- a/python_modules/dagster/dagster_tests/definitions_tests/module_loader_tests/test_module_loaders.py +++ b/python_modules/dagster/dagster_tests/definitions_tests/module_loader_tests/test_module_loaders.py @@ -1,10 +1,16 @@ +import logging from contextlib import contextmanager from types import ModuleType -from typing import Any, Mapping, Sequence, Type +from typing import Any, Mapping, Sequence, Type, cast import dagster as dg import pytest +from dagster._core.definitions.definitions_class import Definitions +from dagster._core.definitions.module_loaders.load_defs_from_module import ( + load_definitions_from_module, +) from dagster._core.definitions.module_loaders.object_list import ModuleScopedDagsterObjects +from dagster._core.definitions.module_loaders.utils import LoadableDagsterObject from dagster._record import record @@ -42,6 +48,16 @@ def my_check() -> dg.AssetCheckResult: return my_check +def all_loadable_objects_from_defs(defs: Definitions) -> Sequence[LoadableDagsterObject]: + return [ + *(defs.assets or []), + *(defs.sensors or []), + *(defs.schedules or []), + *(defs.asset_checks or []), + *(defs.jobs or []), + ] + + @contextmanager def optional_pytest_raise(error_expected: bool, exception_cls: Type[Exception]): if error_expected: @@ -156,3 +172,54 @@ def test_collision_detection(objects: Mapping[str, Any], error_expected: bool) - obj_list = ModuleScopedDagsterObjects.from_modules([module_fake]).get_object_list() obj_ids = {id(obj) for obj in objects.values()} assert len(obj_list.loaded_objects) == len(obj_ids) + + +@pytest.mark.parametrize(**ModuleScopeTestSpec.as_parametrize_kwargs(MODULE_TEST_SPECS)) +def test_load_from_definitions(objects: Mapping[str, Any], error_expected: bool) -> None: + module_fake = build_module_fake("fake", objects) + with optional_pytest_raise( + error_expected=error_expected, exception_cls=dg.DagsterInvalidDefinitionError + ): + defs = load_definitions_from_module(module_fake) + obj_ids = {id(obj) for obj in all_loadable_objects_from_defs(defs)} + expected_obj_ids = {id(obj) for obj in objects.values()} + assert len(obj_ids) == len(expected_obj_ids) + + +def test_load_with_resources() -> None: + @dg.resource + def my_resource(): ... + + module_fake = build_module_fake("foo", {"my_resource": my_resource}) + defs = load_definitions_from_module(module_fake) + assert len(all_loadable_objects_from_defs(defs)) == 0 + assert len(defs.resources or {}) == 0 + defs = load_definitions_from_module(module_fake, resources={"foo": my_resource}) + assert len(defs.resources or {}) == 1 + + +def test_load_with_logger_defs() -> None: + @dg.logger(config_schema={}) + def my_logger(init_context) -> logging.Logger: ... + + module_fake = build_module_fake("foo", {"my_logger": my_logger}) + defs = load_definitions_from_module(module_fake) + assert len(all_loadable_objects_from_defs(defs)) == 0 + assert len(defs.resources or {}) == 0 + defs = load_definitions_from_module(module_fake, resources={"foo": my_logger}) + assert len(defs.resources or {}) == 1 + + +def test_load_with_executor() -> None: + @dg.executor(name="my_executor") + def my_executor(init_context) -> dg.Executor: ... + + module_fake = build_module_fake("foo", {"my_executor": my_executor}) + defs = load_definitions_from_module(module_fake) + assert len(all_loadable_objects_from_defs(defs)) == 0 + assert defs.executor is None + defs = load_definitions_from_module(module_fake, executor=my_executor) + assert ( + defs.executor is not None + and cast(dg.ExecutorDefinition, defs.executor).name == "my_executor" + )