Skip to content

Commit

Permalink
Definitions from module loader
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Dec 19, 2024
1 parent c8a751e commit 7cf84ab
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from types import ModuleType
from typing import Optional

from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.module_loaders.object_list import ModuleScopedDagsterObjects
from dagster._core.definitions.module_loaders.utils import (
ExecutorObject,
LoggerDefinitionKeyMapping,
ResourceDefinitionMapping,
)


def load_definitions_from_module(
module: ModuleType,
resources: Optional[ResourceDefinitionMapping] = None,
loggers: Optional[LoggerDefinitionKeyMapping] = None,
executor: Optional[ExecutorObject] = None,
) -> Definitions:
return Definitions(
**ModuleScopedDagsterObjects.from_modules([module]).get_object_list().to_definitions_args(),
resources=resources,
loggers=loggers,
executor=executor,
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from functools import cached_property
from types import ModuleType
from typing import Callable, Dict, Iterable, Mapping, Optional, Sequence, Union, cast
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union, cast

from dagster._core.definitions.asset_checks import AssetChecksDefinition, has_only_asset_checks
from dagster._core.definitions.asset_key import AssetKey, CoercibleToAssetKeyPrefix
Expand All @@ -15,6 +15,7 @@
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.module_loaders.utils import (
JobDefinitionObject,
LoadableAssetObject,
LoadableDagsterObject,
RuntimeAssetObjectTypes,
RuntimeDagsterObjectTypes,
Expand Down Expand Up @@ -210,6 +211,38 @@ def cacheable_assets(self) -> Sequence[CacheableAssetsDefinition]:
asset for asset in self.loaded_objects if isinstance(asset, CacheableAssetsDefinition)
]

@cached_property
def sensors(self) -> Sequence[SensorDefinition]:
return [
dagster_object
for dagster_object in self.loaded_objects
if isinstance(dagster_object, SensorDefinition)
]

@cached_property
def schedules(self) -> Sequence[ScheduleDefinitionObject]:
return [
dagster_object
for dagster_object in self.loaded_objects
if isinstance(dagster_object, RuntimeScheduleObjectTypes)
]

@cached_property
def jobs(self) -> Sequence[JobDefinitionObject]:
return [
dagster_object
for dagster_object in self.loaded_objects
if isinstance(dagster_object, RuntimeJobObjectTypes)
]

@cached_property
def assets(self) -> Sequence[LoadableAssetObject]:
return [
*self.assets_defs_and_specs,
*self.source_assets,
*self.cacheable_assets,
]

def get_objects(
self, filter_fn: Callable[[LoadableDagsterObject], bool]
) -> Sequence[LoadableDagsterObject]:
Expand Down Expand Up @@ -324,6 +357,15 @@ def with_attributes(
)
return DagsterObjectsList(return_list)

def to_definitions_args(self) -> Mapping[str, Any]:
return {
"assets": self.assets,
"asset_checks": self.checks_defs,
"sensors": self.sensors,
"schedules": self.schedules,
"jobs": self.jobs,
}


def _spec_mapper_disallow_group_override(
group_name: Optional[str], automation_condition: Optional[AutomationCondition]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import logging
from contextlib import contextmanager
from types import ModuleType
from typing import Any, Mapping, Sequence, Type
from typing import Any, Mapping, Sequence, Type, cast

import dagster as dg
import pytest
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.module_loaders.load_defs_from_module import (
load_definitions_from_module,
)
from dagster._core.definitions.module_loaders.object_list import ModuleScopedDagsterObjects
from dagster._core.definitions.module_loaders.utils import LoadableDagsterObject
from dagster._record import record


Expand Down Expand Up @@ -42,6 +48,16 @@ def my_check() -> dg.AssetCheckResult:
return my_check


def all_loadable_objects_from_defs(defs: Definitions) -> Sequence[LoadableDagsterObject]:
return [
*(defs.assets or []),
*(defs.sensors or []),
*(defs.schedules or []),
*(defs.asset_checks or []),
*(defs.jobs or []),
]


@contextmanager
def optional_pytest_raise(error_expected: bool, exception_cls: Type[Exception]):
if error_expected:
Expand Down Expand Up @@ -156,3 +172,54 @@ def test_collision_detection(objects: Mapping[str, Any], error_expected: bool) -
obj_list = ModuleScopedDagsterObjects.from_modules([module_fake]).get_object_list()
obj_ids = {id(obj) for obj in objects.values()}
assert len(obj_list.loaded_objects) == len(obj_ids)


@pytest.mark.parametrize(**ModuleScopeTestSpec.as_parametrize_kwargs(MODULE_TEST_SPECS))
def test_load_from_definitions(objects: Mapping[str, Any], error_expected: bool) -> None:
module_fake = build_module_fake("fake", objects)
with optional_pytest_raise(
error_expected=error_expected, exception_cls=dg.DagsterInvalidDefinitionError
):
defs = load_definitions_from_module(module_fake)
obj_ids = {id(obj) for obj in all_loadable_objects_from_defs(defs)}
expected_obj_ids = {id(obj) for obj in objects.values()}
assert len(obj_ids) == len(expected_obj_ids)


def test_load_with_resources() -> None:
@dg.resource
def my_resource(): ...

module_fake = build_module_fake("foo", {"my_resource": my_resource})
defs = load_definitions_from_module(module_fake)
assert len(all_loadable_objects_from_defs(defs)) == 0
assert len(defs.resources or {}) == 0
defs = load_definitions_from_module(module_fake, resources={"foo": my_resource})
assert len(defs.resources or {}) == 1


def test_load_with_logger_defs() -> None:
@dg.logger(config_schema={})
def my_logger(init_context) -> logging.Logger: ...

module_fake = build_module_fake("foo", {"my_logger": my_logger})
defs = load_definitions_from_module(module_fake)
assert len(all_loadable_objects_from_defs(defs)) == 0
assert len(defs.resources or {}) == 0
defs = load_definitions_from_module(module_fake, resources={"foo": my_logger})
assert len(defs.resources or {}) == 1


def test_load_with_executor() -> None:
@dg.executor(name="my_executor")
def my_executor(init_context) -> dg.Executor: ...

module_fake = build_module_fake("foo", {"my_executor": my_executor})
defs = load_definitions_from_module(module_fake)
assert len(all_loadable_objects_from_defs(defs)) == 0
assert defs.executor is None
defs = load_definitions_from_module(module_fake, executor=my_executor)
assert (
defs.executor is not None
and cast(dg.ExecutorDefinition, defs.executor).name == "my_executor"
)

0 comments on commit 7cf84ab

Please sign in to comment.