From c8a751e985ad381e729438f4ba38b46bd0da4fa3 Mon Sep 17 00:00:00 2001 From: Chris DeCarolis Date: Tue, 17 Dec 2024 08:51:23 -0800 Subject: [PATCH] Genericize object list functionality to take in all sensors, jobs, schedules --- .../definitions/module_loaders/object_list.py | 65 ++++++- .../_core/definitions/module_loaders/utils.py | 25 ++- .../test_module_loaders.py | 158 ++++++++++++++++++ 3 files changed, 244 insertions(+), 4 deletions(-) create mode 100644 python_modules/dagster/dagster_tests/definitions_tests/module_loader_tests/test_module_loaders.py diff --git a/python_modules/dagster/dagster/_core/definitions/module_loaders/object_list.py b/python_modules/dagster/dagster/_core/definitions/module_loaders/object_list.py index 597ddb625a56e..2252f0f8bb6d6 100644 --- a/python_modules/dagster/dagster/_core/definitions/module_loaders/object_list.py +++ b/python_modules/dagster/dagster/_core/definitions/module_loaders/object_list.py @@ -14,13 +14,19 @@ ) from dagster._core.definitions.freshness_policy import FreshnessPolicy from dagster._core.definitions.module_loaders.utils import ( + JobDefinitionObject, LoadableDagsterObject, RuntimeAssetObjectTypes, + RuntimeDagsterObjectTypes, + RuntimeJobObjectTypes, RuntimeKeyScopedAssetObjectTypes, + RuntimeScheduleObjectTypes, + ScheduleDefinitionObject, find_objects_in_module_of_types, key_iterator, replace_keys_in_asset, ) +from dagster._core.definitions.sensor_definition import SensorDefinition from dagster._core.definitions.source_asset import SourceAsset from dagster._core.definitions.utils import DEFAULT_GROUP_NAME from dagster._core.errors import DagsterInvalidDefinitionError @@ -41,7 +47,7 @@ def from_modules(cls, modules: Iterable[ModuleType]) -> "ModuleScopedDagsterObje module.__name__: list( find_objects_in_module_of_types( module, - (AssetsDefinition, SourceAsset, CacheableAssetsDefinition, AssetSpec), + RuntimeDagsterObjectTypes, ) ) for module in modules @@ -70,6 +76,20 @@ def assets_defs(self) -> Sequence[AssetsDefinition]: def source_assets(self) -> Sequence[SourceAsset]: return [asset for asset in self.deduped_objects if isinstance(asset, SourceAsset)] + @cached_property + def schedule_defs(self) -> Sequence[ScheduleDefinitionObject]: + return [ + asset for asset in self.deduped_objects if isinstance(asset, RuntimeScheduleObjectTypes) + ] + + @cached_property + def job_objects(self) -> Sequence[JobDefinitionObject]: + return [asset for asset in self.deduped_objects if isinstance(asset, RuntimeJobObjectTypes)] + + @cached_property + def sensor_defs(self) -> Sequence[SensorDefinition]: + return [asset for asset in self.deduped_objects if isinstance(asset, SensorDefinition)] + @cached_property def module_name_by_id(self) -> Dict[int, str]: return { @@ -91,6 +111,7 @@ def asset_objects_by_key( return objects_by_key def _do_collision_detection(self) -> None: + # Collision detection on module-scoped asset objects. This does not include CacheableAssetsDefinitions, which don't have their keys defined until runtime. for key, asset_objects in self.asset_objects_by_key.items(): # If there is more than one asset_object in the list for a given key, and the objects do not refer to the same asset_object in memory, we have a collision. num_distinct_objects_for_key = len( @@ -103,6 +124,44 @@ def _do_collision_detection(self) -> None: raise DagsterInvalidDefinitionError( f"Asset key {key.to_user_string()} is defined multiple times. Definitions found in modules: {asset_objects_str}." ) + # Collision detection on ScheduleDefinitions. + schedule_defs_by_name = defaultdict(list) + for schedule_def in self.schedule_defs: + schedule_defs_by_name[schedule_def.name].append(schedule_def) + for name, schedule_defs in schedule_defs_by_name.items(): + if len(schedule_defs) > 1: + schedule_defs_str = ", ".join( + set(self.module_name_by_id[id(schedule_def)] for schedule_def in schedule_defs) + ) + raise DagsterInvalidDefinitionError( + f"Schedule name {name} is defined multiple times. Definitions found in modules: {schedule_defs_str}." + ) + + # Collision detection on SensorDefinitions. + sensor_defs_by_name = defaultdict(list) + for sensor_def in self.sensor_defs: + sensor_defs_by_name[sensor_def.name].append(sensor_def) + for name, sensor_defs in sensor_defs_by_name.items(): + if len(sensor_defs) > 1: + sensor_defs_str = ", ".join( + set(self.module_name_by_id[id(sensor_def)] for sensor_def in sensor_defs) + ) + raise DagsterInvalidDefinitionError( + f"Sensor name {name} is defined multiple times. Definitions found in modules: {sensor_defs_str}." + ) + + # Collision detection on JobDefinitionObjects. + job_objects_by_name = defaultdict(list) + for job_object in self.job_objects: + job_objects_by_name[job_object.name].append(job_object) + for name, job_objects in job_objects_by_name.items(): + if len(job_objects) > 1: + job_objects_str = ", ".join( + set(self.module_name_by_id[id(job_object)] for job_object in job_objects) + ) + raise DagsterInvalidDefinitionError( + f"Job name {name} is defined multiple times. Definitions found in modules: {job_objects_str}." + ) def get_object_list(self) -> "DagsterObjectsList": return DagsterObjectsList(self.deduped_objects) @@ -231,7 +290,9 @@ def with_attributes( ) return_list = [] for asset in dagster_object_list.loaded_objects: - if isinstance(asset, AssetsDefinition): + if not isinstance(asset, RuntimeAssetObjectTypes): + return_list.append(asset) + elif isinstance(asset, AssetsDefinition): new_asset = asset.map_asset_specs( _spec_mapper_disallow_group_override(group_name, automation_condition) ).with_attributes( diff --git a/python_modules/dagster/dagster/_core/definitions/module_loaders/utils.py b/python_modules/dagster/dagster/_core/definitions/module_loaders/utils.py index bd5f40f8765d9..8a4926f5f8714 100644 --- a/python_modules/dagster/dagster/_core/definitions/module_loaders/utils.py +++ b/python_modules/dagster/dagster/_core/definitions/module_loaders/utils.py @@ -7,13 +7,34 @@ from dagster._core.definitions.asset_spec import AssetSpec from dagster._core.definitions.assets import AssetsDefinition from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition +from dagster._core.definitions.job_definition import JobDefinition +from dagster._core.definitions.partitioned_schedule import ( + UnresolvedPartitionedAssetScheduleDefinition, +) +from dagster._core.definitions.schedule_definition import ScheduleDefinition +from dagster._core.definitions.sensor_definition import SensorDefinition from dagster._core.definitions.source_asset import SourceAsset +from dagster._core.definitions.unresolved_asset_job_definition import UnresolvedAssetJobDefinition LoadableAssetObject = Union[AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition] -LoadableDagsterObject = LoadableAssetObject # For now +ScheduleDefinitionObject = Union[ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition] +JobDefinitionObject = Union[JobDefinition, UnresolvedAssetJobDefinition] +LoadableDagsterObject = Union[ + LoadableAssetObject, + SensorDefinition, + ScheduleDefinitionObject, + JobDefinitionObject, +] RuntimeKeyScopedAssetObjectTypes = (AssetsDefinition, AssetSpec, SourceAsset) RuntimeAssetObjectTypes = (AssetsDefinition, AssetSpec, SourceAsset, CacheableAssetsDefinition) -RuntimeDagsterObjectTypes = RuntimeAssetObjectTypes # For now +RuntimeScheduleObjectTypes = (ScheduleDefinition, UnresolvedPartitionedAssetScheduleDefinition) +RuntimeJobObjectTypes = (JobDefinition, UnresolvedAssetJobDefinition) +RuntimeDagsterObjectTypes = ( + *RuntimeAssetObjectTypes, + SensorDefinition, + *RuntimeScheduleObjectTypes, + *RuntimeJobObjectTypes, +) def find_objects_in_module_of_types( diff --git a/python_modules/dagster/dagster_tests/definitions_tests/module_loader_tests/test_module_loaders.py b/python_modules/dagster/dagster_tests/definitions_tests/module_loader_tests/test_module_loaders.py new file mode 100644 index 0000000000000..ccdbd5492a678 --- /dev/null +++ b/python_modules/dagster/dagster_tests/definitions_tests/module_loader_tests/test_module_loaders.py @@ -0,0 +1,158 @@ +from contextlib import contextmanager +from types import ModuleType +from typing import Any, Mapping, Sequence, Type + +import dagster as dg +import pytest +from dagster._core.definitions.module_loaders.object_list import ModuleScopedDagsterObjects +from dagster._record import record + + +def build_module_fake(name: str, objects: Mapping[str, Any]) -> ModuleType: + module = ModuleType(name) + for key, value in objects.items(): + setattr(module, key, value) + return module + + +def asset_with_key(key: str) -> dg.AssetsDefinition: + @dg.asset(key=key) + def my_asset(): ... + + return my_asset + + +def schedule_with_name(name: str) -> dg.ScheduleDefinition: + return dg.ScheduleDefinition(name=name, cron_schedule="* * * * *", target="*") + + +def sensor_with_name(name: str) -> dg.SensorDefinition: + @dg.sensor(job_name="blah") + def my_sensor(): + pass + + return my_sensor + + +def check_with_key(key: str, name: str) -> dg.AssetChecksDefinition: + @dg.asset_check(asset=key, name=name) + def my_check() -> dg.AssetCheckResult: + raise Exception("ooops") + + return my_check + + +@contextmanager +def optional_pytest_raise(error_expected: bool, exception_cls: Type[Exception]): + if error_expected: + with pytest.raises(exception_cls): + yield + else: + yield + + +@record +class ModuleScopeTestSpec: + objects: Mapping[str, Any] + error_expected: bool + id_: str + + @staticmethod + def as_parametrize_kwargs(seq: Sequence["ModuleScopeTestSpec"]) -> Mapping[str, Any]: + return { + "argnames": "objects,error_expected", + "argvalues": [(spec.objects, spec.error_expected) for spec in seq], + "ids": [spec.id_ for spec in seq], + } + + +some_schedule = schedule_with_name("foo") +some_sensor = sensor_with_name("foo") +some_asset = asset_with_key("foo") +some_job = dg.define_asset_job(name="foo") +some_check = check_with_key("foo_key", "some_name") + + +MODULE_TEST_SPECS = [ + ModuleScopeTestSpec( + objects={"foo": some_schedule}, error_expected=False, id_="single schedule" + ), + ModuleScopeTestSpec( + objects={"foo": some_schedule, "bar": schedule_with_name("foo")}, + error_expected=True, + id_="conflicting schedules", + ), + ModuleScopeTestSpec( + objects={"foo": some_schedule, "bar": some_schedule}, + error_expected=False, + id_="schedules multiple variables", + ), + ModuleScopeTestSpec(objects={"foo": some_sensor}, error_expected=False, id_="single sensor"), + ModuleScopeTestSpec( + objects={"foo": some_sensor, "bar": sensor_with_name("foo")}, + error_expected=True, + id_="conflicting sensors", + ), + ModuleScopeTestSpec( + objects={"foo": some_sensor, "bar": some_sensor}, + error_expected=False, + id_="sensors multiple variables", + ), + ModuleScopeTestSpec( + objects={"foo": some_asset}, + error_expected=False, + id_="asset single variable", + ), + ModuleScopeTestSpec( + objects={"foo": some_asset, "bar": asset_with_key("foo")}, + error_expected=True, + id_="conflicting assets", + ), + ModuleScopeTestSpec( + objects={"foo": some_asset, "bar": some_asset}, + error_expected=False, + id_="assets multiple variables", + ), + ModuleScopeTestSpec( + objects={"foo": some_job}, + error_expected=False, + id_="single job", + ), + ModuleScopeTestSpec( + objects={"foo": some_job, "bar": dg.define_asset_job("other_job")}, + error_expected=False, + id_="conflicting jobs", + ), + ModuleScopeTestSpec( + objects={"foo": some_job, "bar": some_job}, + error_expected=False, + id_="job multiple variables", + ), + ModuleScopeTestSpec( + objects={"foo": some_check}, + error_expected=False, + id_="single job", + ), + # Currently, we do not perform any collision detection on asset checks. This is the behavior currently public in load_asset_checks_from_module. + ModuleScopeTestSpec( + objects={"foo": some_check, "bar": check_with_key("foo_key", "some_name")}, + error_expected=False, + id_="conflicting checks", + ), + ModuleScopeTestSpec( + objects={"foo": some_check, "bar": some_check}, + error_expected=False, + id_="check multiple variables", + ), +] + + +@pytest.mark.parametrize(**ModuleScopeTestSpec.as_parametrize_kwargs(MODULE_TEST_SPECS)) +def test_collision_detection(objects: Mapping[str, Any], error_expected: bool) -> None: + module_fake = build_module_fake("fake", objects) + with optional_pytest_raise( + error_expected=error_expected, exception_cls=dg.DagsterInvalidDefinitionError + ): + obj_list = ModuleScopedDagsterObjects.from_modules([module_fake]).get_object_list() + obj_ids = {id(obj) for obj in objects.values()} + assert len(obj_list.loaded_objects) == len(obj_ids)