Skip to content

Commit

Permalink
[module-loaders] Genericize object list functionality to take in all …
Browse files Browse the repository at this point in the history
…sensors, jobs, schedules (#26545)

## Summary & Motivation
Actually genericize `ModuleScopedDagsterObjects` to handle taking in all
types of Dagster objects - sensors, schedules, and jobs. I explicitly
left out resources and loggers, because I don't think it makes sense to
scoop those up at module load - since they need to be associated with a
key.

## How I Tested These Changes
Added a new test file `test_module_loaders` which scaffolds out a module
fake for a given test spec, and ensures that either an error is thrown
or the correct number of attributes are accessible on the underlying
list.
  • Loading branch information
dpeng817 authored Dec 19, 2024
1 parent ca38954 commit 620c297
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
)
from dagster._core.definitions.freshness_policy import FreshnessPolicy
from dagster._core.definitions.module_loaders.utils import (
JobDefinitionObject,
LoadableDagsterObject,
RuntimeAssetObjectTypes,
RuntimeDagsterObjectTypes,
RuntimeJobObjectTypes,
RuntimeKeyScopedAssetObjectTypes,
RuntimeScheduleObjectTypes,
ScheduleDefinitionObject,
find_objects_in_module_of_types,
key_iterator,
replace_keys_in_asset,
)
from dagster._core.definitions.sensor_definition import SensorDefinition
from dagster._core.definitions.source_asset import SourceAsset
from dagster._core.definitions.utils import DEFAULT_GROUP_NAME
from dagster._core.errors import DagsterInvalidDefinitionError
Expand All @@ -41,7 +47,7 @@ def from_modules(cls, modules: Iterable[ModuleType]) -> "ModuleScopedDagsterObje
module.__name__: list(
find_objects_in_module_of_types(
module,
(AssetsDefinition, SourceAsset, CacheableAssetsDefinition, AssetSpec),
RuntimeDagsterObjectTypes,
)
)
for module in modules
Expand Down Expand Up @@ -70,6 +76,20 @@ def assets_defs(self) -> Sequence[AssetsDefinition]:
def source_assets(self) -> Sequence[SourceAsset]:
return [asset for asset in self.deduped_objects if isinstance(asset, SourceAsset)]

@cached_property
def schedule_defs(self) -> Sequence[ScheduleDefinitionObject]:
return [
asset for asset in self.deduped_objects if isinstance(asset, RuntimeScheduleObjectTypes)
]

@cached_property
def job_objects(self) -> Sequence[JobDefinitionObject]:
return [asset for asset in self.deduped_objects if isinstance(asset, RuntimeJobObjectTypes)]

@cached_property
def sensor_defs(self) -> Sequence[SensorDefinition]:
return [asset for asset in self.deduped_objects if isinstance(asset, SensorDefinition)]

@cached_property
def module_name_by_id(self) -> Dict[int, str]:
return {
Expand All @@ -91,6 +111,7 @@ def asset_objects_by_key(
return objects_by_key

def _do_collision_detection(self) -> None:
# Collision detection on module-scoped asset objects. This does not include CacheableAssetsDefinitions, which don't have their keys defined until runtime.
for key, asset_objects in self.asset_objects_by_key.items():
# If there is more than one asset_object in the list for a given key, and the objects do not refer to the same asset_object in memory, we have a collision.
num_distinct_objects_for_key = len(
Expand All @@ -103,6 +124,44 @@ def _do_collision_detection(self) -> None:
raise DagsterInvalidDefinitionError(
f"Asset key {key.to_user_string()} is defined multiple times. Definitions found in modules: {asset_objects_str}."
)
# Collision detection on ScheduleDefinitions.
schedule_defs_by_name = defaultdict(list)
for schedule_def in self.schedule_defs:
schedule_defs_by_name[schedule_def.name].append(schedule_def)
for name, schedule_defs in schedule_defs_by_name.items():
if len(schedule_defs) > 1:
schedule_defs_str = ", ".join(
set(self.module_name_by_id[id(schedule_def)] for schedule_def in schedule_defs)
)
raise DagsterInvalidDefinitionError(
f"Schedule name {name} is defined multiple times. Definitions found in modules: {schedule_defs_str}."
)

# Collision detection on SensorDefinitions.
sensor_defs_by_name = defaultdict(list)
for sensor_def in self.sensor_defs:
sensor_defs_by_name[sensor_def.name].append(sensor_def)
for name, sensor_defs in sensor_defs_by_name.items():
if len(sensor_defs) > 1:
sensor_defs_str = ", ".join(
set(self.module_name_by_id[id(sensor_def)] for sensor_def in sensor_defs)
)
raise DagsterInvalidDefinitionError(
f"Sensor name {name} is defined multiple times. Definitions found in modules: {sensor_defs_str}."
)

# Collision detection on JobDefinitionObjects.
job_objects_by_name = defaultdict(list)
for job_object in self.job_objects:
job_objects_by_name[job_object.name].append(job_object)
for name, job_objects in job_objects_by_name.items():
if len(job_objects) > 1:
job_objects_str = ", ".join(
set(self.module_name_by_id[id(job_object)] for job_object in job_objects)
)
raise DagsterInvalidDefinitionError(
f"Job name {name} is defined multiple times. Definitions found in modules: {job_objects_str}."
)

def get_object_list(self) -> "DagsterObjectsList":
return DagsterObjectsList(self.deduped_objects)
Expand Down Expand Up @@ -231,7 +290,9 @@ def with_attributes(
)
return_list = []
for asset in dagster_object_list.loaded_objects:
if isinstance(asset, AssetsDefinition):
if not isinstance(asset, RuntimeAssetObjectTypes):
return_list.append(asset)
elif isinstance(asset, AssetsDefinition):
new_asset = asset.map_asset_specs(
_spec_mapper_disallow_group_override(group_name, automation_condition)
).with_attributes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,34 @@
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.partitioned_schedule import (
UnresolvedPartitionedAssetScheduleDefinition,
)
from dagster._core.definitions.schedule_definition import ScheduleDefinition
from dagster._core.definitions.sensor_definition import SensorDefinition
from dagster._core.definitions.source_asset import SourceAsset
from dagster._core.definitions.unresolved_asset_job_definition import UnresolvedAssetJobDefinition

LoadableAssetObject = Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition]
LoadableDagsterObject = LoadableAssetObject # For now
ScheduleDefinitionObject = Union[ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition]
JobDefinitionObject = Union[JobDefinition, UnresolvedAssetJobDefinition]
LoadableDagsterObject = Union[
LoadableAssetObject,
SensorDefinition,
ScheduleDefinitionObject,
JobDefinitionObject,
]
RuntimeKeyScopedAssetObjectTypes = (AssetsDefinition, AssetSpec, SourceAsset)
RuntimeAssetObjectTypes = (AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition)
RuntimeDagsterObjectTypes = RuntimeAssetObjectTypes # For now
RuntimeScheduleObjectTypes = (ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition)
RuntimeJobObjectTypes = (JobDefinition, UnresolvedAssetJobDefinition)
RuntimeDagsterObjectTypes = (
*RuntimeAssetObjectTypes,
SensorDefinition,
*RuntimeScheduleObjectTypes,
*RuntimeJobObjectTypes,
)


def find_objects_in_module_of_types(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from contextlib import contextmanager
from types import ModuleType
from typing import Any, Mapping, Sequence, Type

import dagster as dg
import pytest
from dagster._core.definitions.module_loaders.object_list import ModuleScopedDagsterObjects
from dagster._record import record


def build_module_fake(name: str, objects: Mapping[str, Any]) -> ModuleType:
module = ModuleType(name)
for key, value in objects.items():
setattr(module, key, value)
return module


def asset_with_key(key: str) -> dg.AssetsDefinition:
@dg.asset(key=key)
def my_asset(): ...

return my_asset


def schedule_with_name(name: str) -> dg.ScheduleDefinition:
return dg.ScheduleDefinition(name=name, cron_schedule="* * * * *", target="*")


def sensor_with_name(name: str) -> dg.SensorDefinition:
@dg.sensor(job_name="blah")
def my_sensor():
pass

return my_sensor


def check_with_key(key: str, name: str) -> dg.AssetChecksDefinition:
@dg.asset_check(asset=key, name=name)
def my_check() -> dg.AssetCheckResult:
raise Exception("ooops")

return my_check


@contextmanager
def optional_pytest_raise(error_expected: bool, exception_cls: Type[Exception]):
if error_expected:
with pytest.raises(exception_cls):
yield
else:
yield


@record
class ModuleScopeTestSpec:
objects: Mapping[str, Any]
error_expected: bool
id_: str

@staticmethod
def as_parametrize_kwargs(seq: Sequence["ModuleScopeTestSpec"]) -> Mapping[str, Any]:
return {
"argnames": "objects,error_expected",
"argvalues": [(spec.objects, spec.error_expected) for spec in seq],
"ids": [spec.id_ for spec in seq],
}


some_schedule = schedule_with_name("foo")
some_sensor = sensor_with_name("foo")
some_asset = asset_with_key("foo")
some_job = dg.define_asset_job(name="foo")
some_check = check_with_key("foo_key", "some_name")


MODULE_TEST_SPECS = [
ModuleScopeTestSpec(
objects={"foo": some_schedule}, error_expected=False, id_="single schedule"
),
ModuleScopeTestSpec(
objects={"foo": some_schedule, "bar": schedule_with_name("foo")},
error_expected=True,
id_="conflicting schedules",
),
ModuleScopeTestSpec(
objects={"foo": some_schedule, "bar": some_schedule},
error_expected=False,
id_="schedules multiple variables",
),
ModuleScopeTestSpec(objects={"foo": some_sensor}, error_expected=False, id_="single sensor"),
ModuleScopeTestSpec(
objects={"foo": some_sensor, "bar": sensor_with_name("foo")},
error_expected=True,
id_="conflicting sensors",
),
ModuleScopeTestSpec(
objects={"foo": some_sensor, "bar": some_sensor},
error_expected=False,
id_="sensors multiple variables",
),
ModuleScopeTestSpec(
objects={"foo": some_asset},
error_expected=False,
id_="asset single variable",
),
ModuleScopeTestSpec(
objects={"foo": some_asset, "bar": asset_with_key("foo")},
error_expected=True,
id_="conflicting assets",
),
ModuleScopeTestSpec(
objects={"foo": some_asset, "bar": some_asset},
error_expected=False,
id_="assets multiple variables",
),
ModuleScopeTestSpec(
objects={"foo": some_job},
error_expected=False,
id_="single job",
),
ModuleScopeTestSpec(
objects={"foo": some_job, "bar": dg.define_asset_job("other_job")},
error_expected=False,
id_="conflicting jobs",
),
ModuleScopeTestSpec(
objects={"foo": some_job, "bar": some_job},
error_expected=False,
id_="job multiple variables",
),
ModuleScopeTestSpec(
objects={"foo": some_check},
error_expected=False,
id_="single job",
),
# Currently, we do not perform any collision detection on asset checks. This is the behavior currently public in load_asset_checks_from_module.
ModuleScopeTestSpec(
objects={"foo": some_check, "bar": check_with_key("foo_key", "some_name")},
error_expected=False,
id_="conflicting checks",
),
ModuleScopeTestSpec(
objects={"foo": some_check, "bar": some_check},
error_expected=False,
id_="check multiple variables",
),
]


@pytest.mark.parametrize(**ModuleScopeTestSpec.as_parametrize_kwargs(MODULE_TEST_SPECS))
def test_collision_detection(objects: Mapping[str, Any], error_expected: bool) -> None:
module_fake = build_module_fake("fake", objects)
with optional_pytest_raise(
error_expected=error_expected, exception_cls=dg.DagsterInvalidDefinitionError
):
obj_list = ModuleScopedDagsterObjects.from_modules([module_fake]).get_object_list()
obj_ids = {id(obj) for obj in objects.values()}
assert len(obj_list.loaded_objects) == len(obj_ids)

0 comments on commit 620c297

Please sign in to comment.