Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AssetChecksDefinitionProps to replace untyped dictionary shenanigans #16747

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 62 additions & 30 deletions python_modules/dagster/dagster/_core/definitions/asset_checks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from abc import ABC
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
Iterable,
Iterator,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
TypeVar,
)

from dagster import _check as check
Expand All @@ -29,15 +31,29 @@
if TYPE_CHECKING:
from dagster._core.definitions.assets import AssetsDefinition

TPropsType = TypeVar("TPropsType", bound=NamedTuple)
TDefType = TypeVar("TDefType")

@experimental
class AssetChecksDefinitionInputOutputProps(NamedTuple):

class IDefinitionPropsClassMatchesInit(Generic[TPropsType, TDefType], ABC):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: something in the direction of IDefinitionWithProps / PropsBackedDefinition; I think it makes more sense to target the fundamental pattern than a specific thing we enforce when participating in it

@staticmethod
def from_props(props: TPropsType) -> TDefType: ...
Comment on lines +39 to +40
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this appears unused - should we instead capture [1] with a method on this class?



class AssetChecksDefinitionProps(NamedTuple):
    """Typed bundle of the constructor arguments of ``AssetChecksDefinition``.

    The field names mirror the keyword parameters of
    ``AssetChecksDefinition.__init__``, so an instance can be expanded straight
    into that constructor with ``**props._asdict()``.
    """

    # The op or op graph that can be executed to check the assets.
    node_def: NodeDefinition

    # Resource definitions, keyed by string name.
    resource_defs: Mapping[str, ResourceDefinition]

    # The check specs handled by this definition.
    specs: Sequence[AssetCheckSpec]

    # Output name -> key of the asset check associated with that output.
    asset_check_keys_by_output_name: Mapping[str, AssetCheckKey]

    # Input name -> key of the asset associated with that input.
    asset_keys_by_input_name: Mapping[str, AssetKey]


@experimental
class AssetChecksDefinition(ResourceAddable, RequiresResources):
class AssetChecksDefinition(
ResourceAddable,
RequiresResources,
IDefinitionPropsClassMatchesInit[AssetChecksDefinitionProps, "AssetChecksDefinition"],
):
"""Defines a set of checks that are produced by the same op or op graph.

AssetChecksDefinition are typically not instantiated directly, but rather produced using a
Expand All @@ -46,30 +62,52 @@ class AssetChecksDefinition(ResourceAddable, RequiresResources):

def __init__(
self,
*,
node_def: NodeDefinition,
resource_defs: Mapping[str, ResourceDefinition],
specs: Sequence[AssetCheckSpec],
input_output_props: AssetChecksDefinitionInputOutputProps
# if adding new fields, make sure to handle them in the get_attributes_dict method
asset_check_keys_by_output_name: Mapping[str, AssetCheckKey],
asset_keys_by_input_name: Mapping[str, AssetKey],
):
self._node_def = node_def
self._resource_defs = resource_defs
self._specs = check.sequence_param(specs, "specs", of_type=AssetCheckSpec)
self._input_output_props = check.inst_param(
input_output_props, "input_output_props", AssetChecksDefinitionInputOutputProps
self._props = AssetChecksDefinitionProps(
node_def=check.inst_param(node_def, "node_def", NodeDefinition),
resource_defs=check.dict_param(
resource_defs, "resource_defs", key_type=str, value_type=ResourceDefinition
),
specs=check.list_param(specs, "specs", of_type=AssetCheckSpec),
asset_check_keys_by_output_name=check.dict_param(
asset_check_keys_by_output_name,
"asset_check_keys_by_output_name",
key_type=str,
value_type=AssetCheckKey,
),
asset_keys_by_input_name=check.dict_param(
asset_keys_by_input_name,
"asset_keys_by_input_name",
key_type=str,
value_type=AssetKey,
),
)
self._specs_by_handle = {spec.key: spec for spec in specs}
self._specs_by_output_name = {
output_name: self._specs_by_handle[check_key]
for output_name, check_key in input_output_props.asset_check_keys_by_output_name.items()
for output_name, check_key in self._props.asset_check_keys_by_output_name.items()
}

@staticmethod
def from_props(props: AssetChecksDefinitionProps) -> "AssetChecksDefinition":
    """Build an ``AssetChecksDefinition`` from a props tuple.

    Args:
        props: The typed bundle of constructor arguments.

    Returns:
        A new ``AssetChecksDefinition`` constructed from the fields of ``props``.
    """
    # Spell out each field explicitly so the mapping from props to
    # constructor keyword arguments stays visible at the call site.
    init_kwargs = dict(
        node_def=props.node_def,
        resource_defs=props.resource_defs,
        specs=props.specs,
        asset_check_keys_by_output_name=props.asset_check_keys_by_output_name,
        asset_keys_by_input_name=props.asset_keys_by_input_name,
    )
    return AssetChecksDefinition(**init_kwargs)

@public
@property
def node_def(self) -> NodeDefinition:
"""The op or op graph that can be executed to check the assets."""
return self._node_def
return self._props.node_def

@public
@property
Expand Down Expand Up @@ -109,11 +147,11 @@ def specs_by_output_name(self) -> Mapping[str, AssetCheckSpec]:

@property
def asset_keys_by_input_name(self) -> Mapping[str, AssetKey]:
return self._input_output_props.asset_keys_by_input_name
return self._props.asset_keys_by_input_name

def get_resource_requirements(self) -> Iterator[ResourceRequirement]:
yield from self.node_def.get_resource_requirements() # type: ignore[attr-defined]
for source_key, resource_def in self._resource_defs.items():
for source_key, resource_def in self._props.resource_defs.items():
yield from resource_def.get_resource_requirements(outer_context=source_key)

def get_spec_for_check_key(self, asset_check_key: AssetCheckKey) -> AssetCheckSpec:
Expand All @@ -128,20 +166,14 @@ def required_resource_keys(self) -> Set[str]:
def with_resources(
self, resource_defs: Mapping[str, ResourceDefinition]
) -> "AssetChecksDefinition":
attributes_dict = self.get_attributes_dict()
attributes_dict["resource_defs"] = merge_resource_defs(
old_resource_defs=self._resource_defs,
resource_defs_to_merge_in=resource_defs,
requires_resources=self,
)
return self.__class__(**attributes_dict)

def get_attributes_dict(self) -> Dict[str, Any]:
return dict(
node_def=self._node_def,
resource_defs=self._resource_defs,
specs=self._specs,
input_output_props=self._input_output_props,
return self.__class__(
**self._props._replace(
resource_defs=merge_resource_defs(
old_resource_defs=self._props.resource_defs,
resource_defs_to_merge_in=resource_defs,
requires_resources=self,
)
)._asdict()
Comment on lines +169 to +176
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1]

)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
from dagster._core.definitions.asset_checks import (
AssetChecksDefinition,
AssetChecksDefinitionInputOutputProps,
)
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.events import AssetKey, CoercibleToAssetKey
Expand Down Expand Up @@ -170,12 +169,10 @@ def inner(fn: AssetCheckFunction) -> AssetChecksDefinition:
node_def=op_def,
resource_defs={},
specs=[spec],
input_output_props=AssetChecksDefinitionInputOutputProps(
asset_keys_by_input_name={
input_tuples_by_asset_key[resolved_asset_key][0]: resolved_asset_key
},
asset_check_keys_by_output_name={op_def.output_defs[0].name: spec.key},
),
asset_keys_by_input_name={
input_tuples_by_asset_key[resolved_asset_key][0]: resolved_asset_key
},
asset_check_keys_by_output_name={op_def.output_defs[0].name: spec.key},
)

return checks_def
Expand Down
15 changes: 12 additions & 3 deletions python_modules/dagster/dagster/_utils/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
from collections import defaultdict
from contextlib import contextmanager
from types import ModuleType
from typing import Any, Dict, List, Mapping, Optional, Type, cast

# top-level include is dangerous in terms of incurring circular deps
Expand Down Expand Up @@ -330,12 +331,20 @@ def check_concurrency_claim(
return claim_status.with_sleep_interval(float(self._sleep_interval))


def get_all_direct_subclasses_of_marker(marker_interface_cls: Type) -> List[Type]:
import dagster as dagster
def get_all_direct_subclasses_of_marker(
marker_interface_cls: Type, module: Optional[ModuleType] = None
) -> List[Type]:
"""Get all direct subclasses of a given marker interface class from a given module. If the
module is not specified, it defaults to the dagster module.
"""
if module is None:
import dagster as dagster

module = dagster
Comment on lines +335 to +343
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what motivated this change?


return [
symbol
for symbol in dagster.__dict__.values()
for symbol in module.__dict__.values()
if isinstance(symbol, type)
and issubclass(symbol, marker_interface_cls)
and marker_interface_cls
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from inspect import signature
from typing import List, Type

from dagster._core.definitions import asset_checks
from dagster._core.definitions.asset_checks import (
AssetChecksDefinition,
AssetChecksDefinitionProps,
IDefinitionPropsClassMatchesInit,
)
from dagster._utils.test import get_all_direct_subclasses_of_marker


def test_adherence_of_all_idef_props_class_matches_init() -> None:
def_types: List[Type] = get_all_direct_subclasses_of_marker(
IDefinitionPropsClassMatchesInit, module=asset_checks
)

hard_coded_list = [
(AssetChecksDefinition, AssetChecksDefinitionProps),
]

# keep hard_coded_list up-to-date with all subclasses of IDefinitionPropsClassMatchesInit
assert set([def_type.__name__ for def_type in def_types]) == set(
[entry[0].__name__ for entry in hard_coded_list]
), (
"You likely added a subclass of IDefinitionPropsClassMatchesInit without adding it to the"
" hard_coded_list"
)
Comment on lines +18 to +28
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just here because you couldn't figure out how to pluck TPropsType & TDefType from the class object?


for def_type, props_type in hard_coded_list:
    # The constructor's parameter names (ignoring ``self``) must be exactly
    # the NamedTuple's field names, so a props instance can be expanded
    # into the constructor without missing or unexpected keywords.
    init_param_names_minus_self = set(signature(def_type.__init__).parameters) - {"self"}
    named_tuple_fields = set(props_type._fields)
    assert init_param_names_minus_self == named_tuple_fields, (
        f"You have either added a field to {props_type.__name__} or removed a parameter"
        f" from {def_type.__name__}.__init__. They must remain in sync"
    )