Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[module-loaders] [rfc] Definitions from module loader #26546

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from types import ModuleType
from typing import Optional

from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.module_loaders.object_list import ModuleScopedDagsterObjects
from dagster._core.definitions.module_loaders.utils import (
ExecutorObject,
LoggerDefinitionKeyMapping,
ResourceDefinitionMapping,
)


def load_definitions_from_module(
module: ModuleType,
resources: Optional[ResourceDefinitionMapping] = None,
loggers: Optional[LoggerDefinitionKeyMapping] = None,
executor: Optional[ExecutorObject] = None,
) -> Definitions:
return Definitions(
**ModuleScopedDagsterObjects.from_modules([module]).get_object_list().to_definitions_args(),
resources=resources,
loggers=loggers,
executor=executor,
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from functools import cached_property
from types import ModuleType
from typing import Callable, Dict, Iterable, Mapping, Optional, Sequence, Union, cast
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union, cast

from dagster._core.definitions.asset_checks import AssetChecksDefinition, has_only_asset_checks
from dagster._core.definitions.asset_key import AssetKey, CoercibleToAssetKeyPrefix
Expand All @@ -15,6 +15,7 @@
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.module_loaders.utils import (
JobDefinitionObject,
LoadableAssetObject,
LoadableDagsterObject,
RuntimeAssetObjectTypes,
RuntimeDagsterObjectTypes,
Expand Down Expand Up @@ -210,6 +211,38 @@ def cacheable_assets(self) -> Sequence[CacheableAssetsDefinition]:
asset for asset in self.loaded_objects if isinstance(asset, CacheableAssetsDefinition)
]

@cached_property
def sensors(self) -> Sequence[SensorDefinition]:
return [
dagster_object
for dagster_object in self.loaded_objects
if isinstance(dagster_object, SensorDefinition)
]

@cached_property
def schedules(self) -> Sequence[ScheduleDefinitionObject]:
return [
dagster_object
for dagster_object in self.loaded_objects
if isinstance(dagster_object, RuntimeScheduleObjectTypes)
]

@cached_property
def jobs(self) -> Sequence[JobDefinitionObject]:
return [
dagster_object
for dagster_object in self.loaded_objects
if isinstance(dagster_object, RuntimeJobObjectTypes)
]

@cached_property
def assets(self) -> Sequence[LoadableAssetObject]:
return [
*self.assets_defs_and_specs,
*self.source_assets,
*self.cacheable_assets,
]

def get_objects(
self, filter_fn: Callable[[LoadableDagsterObject], bool]
) -> Sequence[LoadableDagsterObject]:
Expand Down Expand Up @@ -324,6 +357,15 @@ def with_attributes(
)
return DagsterObjectsList(return_list)

def to_definitions_args(self) -> Mapping[str, Any]:
return {
"assets": self.assets,
"asset_checks": self.checks_defs,
"sensors": self.sensors,
"schedules": self.schedules,
"jobs": self.jobs,
}


def _spec_mapper_disallow_group_override(
group_name: Optional[str], automation_condition: Optional[AutomationCondition]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import logging
from contextlib import contextmanager
from types import ModuleType
from typing import Any, Mapping, Sequence, Type
from typing import Any, Mapping, Sequence, Type, cast

import dagster as dg
import pytest
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.module_loaders.load_defs_from_module import (
load_definitions_from_module,
)
from dagster._core.definitions.module_loaders.object_list import ModuleScopedDagsterObjects
from dagster._core.definitions.module_loaders.utils import LoadableDagsterObject
from dagster._record import record


Expand Down Expand Up @@ -42,6 +48,16 @@ def my_check() -> dg.AssetCheckResult:
return my_check


def all_loadable_objects_from_defs(defs: Definitions) -> Sequence[LoadableDagsterObject]:
return [
*(defs.assets or []),
*(defs.sensors or []),
*(defs.schedules or []),
*(defs.asset_checks or []),
*(defs.jobs or []),
]


@contextmanager
def optional_pytest_raise(error_expected: bool, exception_cls: Type[Exception]):
if error_expected:
Expand Down Expand Up @@ -156,3 +172,54 @@ def test_collision_detection(objects: Mapping[str, Any], error_expected: bool) -
obj_list = ModuleScopedDagsterObjects.from_modules([module_fake]).get_object_list()
obj_ids = {id(obj) for obj in objects.values()}
assert len(obj_list.loaded_objects) == len(obj_ids)


@pytest.mark.parametrize(**ModuleScopeTestSpec.as_parametrize_kwargs(MODULE_TEST_SPECS))
def test_load_from_definitions(objects: Mapping[str, Any], error_expected: bool) -> None:
module_fake = build_module_fake("fake", objects)
with optional_pytest_raise(
error_expected=error_expected, exception_cls=dg.DagsterInvalidDefinitionError
):
defs = load_definitions_from_module(module_fake)
obj_ids = {id(obj) for obj in all_loadable_objects_from_defs(defs)}
expected_obj_ids = {id(obj) for obj in objects.values()}
assert len(obj_ids) == len(expected_obj_ids)


def test_load_with_resources() -> None:
@dg.resource
def my_resource(): ...

module_fake = build_module_fake("foo", {"my_resource": my_resource})
defs = load_definitions_from_module(module_fake)
assert len(all_loadable_objects_from_defs(defs)) == 0
assert len(defs.resources or {}) == 0
defs = load_definitions_from_module(module_fake, resources={"foo": my_resource})
assert len(defs.resources or {}) == 1


def test_load_with_logger_defs() -> None:
@dg.logger(config_schema={})
def my_logger(init_context) -> logging.Logger: ...

module_fake = build_module_fake("foo", {"my_logger": my_logger})
defs = load_definitions_from_module(module_fake)
assert len(all_loadable_objects_from_defs(defs)) == 0
assert len(defs.resources or {}) == 0
defs = load_definitions_from_module(module_fake, resources={"foo": my_logger})
assert len(defs.resources or {}) == 1


def test_load_with_executor() -> None:
@dg.executor(name="my_executor")
def my_executor(init_context) -> dg.Executor: ...

module_fake = build_module_fake("foo", {"my_executor": my_executor})
defs = load_definitions_from_module(module_fake)
assert len(all_loadable_objects_from_defs(defs)) == 0
assert defs.executor is None
defs = load_definitions_from_module(module_fake, executor=my_executor)
assert (
defs.executor is not None
and cast(dg.ExecutorDefinition, defs.executor).name == "my_executor"
)