diff --git a/python_modules/dagster/dagster/_check/__init__.py b/python_modules/dagster/dagster/_check/__init__.py index 58653979abf72..919e5882c9007 100644 --- a/python_modules/dagster/dagster/_check/__init__.py +++ b/python_modules/dagster/dagster/_check/__init__.py @@ -1849,13 +1849,30 @@ class EvalContext(NamedTuple): local_ns: dict @staticmethod - def capture_from_frame(depth: int) -> "EvalContext": + def capture_from_frame( + depth: int, + *, + add_to_local_ns: Optional[Mapping[str, Any]] = None, + ) -> "EvalContext": + """Capture the global and local namespaces via the stack frame. + + Args: + depth: which stack frame to reference, with depth 0 being the callsite. + add_to_local_ns: A mapping of additional values to update the local namespace with. + + """ ctx_frame = sys._getframe(depth + 1) # noqa # surprisingly not costly + # copy to not mess up frame data + global_ns = ctx_frame.f_globals.copy() + local_ns = ctx_frame.f_locals.copy() + + if add_to_local_ns: + local_ns.update(add_to_local_ns) + return EvalContext( - # copy to not mess up frame data - ctx_frame.f_globals.copy(), - ctx_frame.f_locals.copy(), + global_ns=global_ns, + local_ns=local_ns, ) def update_from_frame(self, depth: int): diff --git a/python_modules/dagster/dagster/_model/decorator.py b/python_modules/dagster/dagster/_model/decorator.py index 70c25815316d0..c136edeaeb04d 100644 --- a/python_modules/dagster/dagster/_model/decorator.py +++ b/python_modules/dagster/dagster/_model/decorator.py @@ -7,6 +7,7 @@ Mapping, NamedTuple, Optional, + Tuple, Type, TypeVar, Union, @@ -26,7 +27,9 @@ _MODEL_MARKER_FIELD = ( "__checkrepublic__" # "I do want to release this as checkrepublic one day" - schrockn ) -_GENERATED_NEW = "__checked_new__" +_CHECKED_NEW = "__checked_new__" +_DEFAULTS_NEW = "__defaults_new__" +_INJECTED_DEFAULT_VALS_LOCAL_VAR = "__dm_defaults__" def _namedtuple_model_transform( @@ -41,21 +44,48 @@ def _namedtuple_model_transform( * bans tuple methods that don't make sense for a 
model object * creates a run time checked __new__ (optional). """ - base = NamedTuple(f"_{cls.__name__}", cls.__annotations__.items()) + field_set = getattr(cls, "__annotations__", {}) + defaults = {name: getattr(cls, name) for name in field_set.keys() if hasattr(cls, name)} + + base = NamedTuple(f"_{cls.__name__}", field_set.items()) nt_new = base.__new__ if checked: + eval_ctx = EvalContext.capture_from_frame( + 1 + decorator_frames, + # inject default values in to the local namespace for reference in generated __new__ + add_to_local_ns={_INJECTED_DEFAULT_VALS_LOCAL_VAR: defaults}, + ) jit_checked_new = JitCheckedNew( - cls.__annotations__, - EvalContext.capture_from_frame(1 + decorator_frames), - 1 if with_new else 0, + field_set, + defaults, base, + eval_ctx, + 1 if with_new else 0, ) base.__new__ = jit_checked_new # type: ignore + elif defaults: + # allow arbitrary ordering of default values by generating a kwarg only __new__ impl + eval_ctx = EvalContext( + global_ns={}, + # inject default values in to the local namespace for reference in generated __new__ + local_ns={_INJECTED_DEFAULT_VALS_LOCAL_VAR: defaults}, + ) + defaults_new = eval_ctx.compile_fn( + _build_defaults_new(field_set, defaults), + _DEFAULTS_NEW, + ) + base.__new__ = defaults_new + if with_new and cls.__new__ is object.__new__: # verify the alignment since it impacts frame capture check.failed(f"Expected __new__ on {cls}, add it or switch from the _with_new decorator.") + # clear default values + for name in field_set.keys(): + if hasattr(cls, name): + delattr(cls, name) + new_type = type( cls.__name__, (cls, base), @@ -66,6 +96,7 @@ def _namedtuple_model_transform( _MODEL_MARKER_FIELD: _MODEL_MARKER_VALUE, "__annotations__": cls.__annotations__, "__nt_new__": nt_new, + "__bool__": _true, }, ) @@ -189,7 +220,7 @@ def is_dagster_model(obj) -> bool: def has_generated_new(obj) -> bool: - return obj.__new__.__name__ == _GENERATED_NEW + return obj.__new__.__name__ in (_DEFAULTS_NEW, _CHECKED_NEW) 
def as_dict(obj) -> Mapping[str, Any]: @@ -228,25 +259,27 @@ def _asdict(self): class JitCheckedNew: - """Object that allows us to just-in-time compile a __checked_new__ implementation on first use. + """Object that allows us to just-in-time compile a checked __new__ implementation on first use. This has two benefits: 1. Defer processing ForwardRefs until their definitions are in scope. 2. Avoid up-front cost for unused objects. """ - __name__ = _GENERATED_NEW + __name__ = _CHECKED_NEW def __init__( self, - field_set: dict, + field_set: Mapping[str, Type], + defaults: Mapping[str, Any], + nt_base: Type, eval_ctx: EvalContext, new_frames: int, - nt_base: Type, ): self._field_set = field_set + self._defaults = defaults + self._nt_base = nt_base self._eval_ctx = eval_ctx self._new_frames = new_frames # how many frames of __new__ there are - self._nt_base = nt_base def __call__(self, cls, **kwargs): # update the context with callsite locals/globals to resolve @@ -260,13 +293,14 @@ def __call__(self, cls, **kwargs): # jit that shit self._nt_base.__new__ = self._eval_ctx.compile_fn( self._build_checked_new_str(), - _GENERATED_NEW, + _CHECKED_NEW, ) return self._nt_base.__new__(cls, **kwargs) def _build_checked_new_str(self) -> str: - kw_args = ", ".join(self._field_set.keys()) + kw_args_str, set_calls_str = build_args_and_assignment_strs(self._field_set, self._defaults) + check_calls = [] for name, ttype in self._field_set.items(): call_str = build_check_call_str( @@ -276,18 +310,76 @@ def _build_checked_new_str(self) -> str: ) check_calls.append(f"{name}={call_str}") - check_call_block = " ,\n".join(check_calls) + check_call_block = ",\n ".join(check_calls) return f""" -def __checked_new__(cls, {kw_args}): +def __checked_new__(cls{kw_args_str}): + {set_calls_str} return cls.__nt_new__( - cls,{check_call_block} + cls, + {check_call_block} ) """ +def _build_defaults_new(field_set: Mapping[str, Type], defaults: Mapping[str, Any]) -> str: + """Build a __new__ implementation 
that handles default values.""" + kw_args_str, set_calls_str = build_args_and_assignment_strs(field_set, defaults) + assign_str = ",\n ".join([f"{name}={name}" for name in field_set.keys()]) + return f""" +def __defaults_new__(cls{kw_args_str}): + {set_calls_str} + return cls.__nt_new__( + cls, + {assign_str} + ) + """ + + +def build_args_and_assignment_strs( + field_set: Mapping[str, Type], + defaults: Mapping[str, Any], +) -> Tuple[str, str]: + """Utility function shared between _defaults_new and _checked_new to create the arguments to + the function as well as any assignment calls that need to happen. + """ + kw_args = [] + set_calls = [] + for arg in field_set.keys(): + if arg in defaults: + default = defaults[arg] + if default is None: + kw_args.append(f"{arg} = None") + # dont share class instance of default empty containers + elif default == []: + kw_args.append(f"{arg} = None") + set_calls.append(f"{arg} = {arg} if {arg} is not None else []") + elif default == {}: + kw_args.append(f"{arg} = None") + set_calls.append(f"{arg} = {arg} if {arg} is not None else {'{}'}") + # fallback to direct reference if unknown + else: + kw_args.append(f"{arg} = {_INJECTED_DEFAULT_VALS_LOCAL_VAR}['{arg}']") + else: + kw_args.append(arg) + + kw_args_str = "" + if kw_args: + kw_args_str = f", *, {', '.join(kw_args)}" + + set_calls_str = "" + if set_calls: + set_calls_str = "\n ".join(set_calls) + + return kw_args_str, set_calls_str + + def _banned_iter(*args, **kwargs): raise Exception("Iteration is not allowed on `@dagster_model`s.") def _banned_idx(*args, **kwargs): raise Exception("Index access is not allowed on `@dagster_model`s.") + + +def _true(_): + return True diff --git a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py index f4b774341ae66..379eea540cd35 100644 --- a/python_modules/dagster/dagster_tests/general_tests/test_serdes.py +++ 
b/python_modules/dagster/dagster_tests/general_tests/test_serdes.py @@ -907,10 +907,21 @@ class MyModel: m_str = serialize_value(m, whitelist_map=test_env) assert m == deserialize_value(m_str, whitelist_map=test_env) + @_whitelist_for_serdes(test_env) + @dagster_model(checked=False) + class UncheckedModel: + nums: List[int] + optional: int = 130 + + m = UncheckedModel(nums=[1, 2, 3]) + m_str = serialize_value(m, whitelist_map=test_env) + assert m == deserialize_value(m_str, whitelist_map=test_env) + @_whitelist_for_serdes(test_env) @dagster_model class CachedModel: nums: List[int] + optional: int = 42 @cached_method def map(self) -> dict: diff --git a/python_modules/dagster/dagster_tests/model_tests/test_dagster_model.py b/python_modules/dagster/dagster_tests/model_tests/test_dagster_model.py index af3374fdf0cd0..1505834699776 100644 --- a/python_modules/dagster/dagster_tests/model_tests/test_dagster_model.py +++ b/python_modules/dagster/dagster_tests/model_tests/test_dagster_model.py @@ -1,9 +1,12 @@ -from typing import Optional +from typing import Any, Dict, List, Optional import pytest from dagster._check import CheckError -from dagster._model import DagsterModel, copy, dagster_model, dagster_model_custom -from dagster._model.decorator import IHaveNew +from dagster._model import DagsterModel, IHaveNew, copy, dagster_model, dagster_model_custom +from dagster._model.decorator import ( + _INJECTED_DEFAULT_VALS_LOCAL_VAR, + build_args_and_assignment_strs, +) from dagster._utils.cached_method import CACHED_METHOD_CACHE_FIELD, cached_method from pydantic import ValidationError @@ -288,3 +291,113 @@ class Failed: @dagster_model_custom class FailedAgain: local: Optional[str] + + +def test_empty(): + @dagster_model + class Empty: ... 
+ + assert Empty() + + +def test_optional_arg() -> None: + @dagster_model + class Opt: + maybe: Optional[str] = None + always: Optional[str] + + assert Opt(always="set") + assert Opt(always="set", maybe="x").maybe == "x" + + @dagster_model(checked=False) + class Other: + maybe: Optional[str] = None + always: Optional[str] + + assert Other(always="set") + assert Other(always="set", maybe="x").maybe == "x" + + +def test_dont_share_containers() -> None: + @dagster_model + class Empties: + items: List[str] = [] + map: Dict[str, str] = {} + + e_1 = Empties() + e_2 = Empties() + assert e_1.items is not e_2.items + assert e_1.map is not e_2.map + + +def test_sentinel(): + _unset = object() + + @dagster_model + class Sample: + val: Optional[Any] = _unset + + assert Sample().val is _unset + assert Sample(val=None).val is None + + @dagster_model(checked=False) + class OtherSample: + val: Optional[Any] = _unset + + assert OtherSample().val is _unset + assert OtherSample(val=None).val is None + + +@pytest.mark.parametrize( + "fields, defaults, expected", + [ + ( + {"name": str}, + {}, + ( + ", *, name", + "", + ), + ), + # defaults dont need to be in certain order since we force kwargs + # None handled directly by arg default + ( + {"name": str, "age": int, "f": float}, + {"age": None}, + ( + ", *, name, age = None, f", + "", + ), + ), + # empty container defaults get fresh copies via assignments + ( + {"things": list}, + {"things": []}, + ( + ", *, things = None", + "things = things if things is not None else []", + ), + ), + ( + {"map": dict}, + {"map": {}}, + ( + ", *, map = None", + "map = map if map is not None else {}", + ), + ), + # base case - default values resolved by reference to injected local + ( + {"val": Any}, + {"val": object()}, + ( + f", *, val = {_INJECTED_DEFAULT_VALS_LOCAL_VAR}['val']", + "", + ), + ), + ], +) +def test_build_args_and_assign(fields, defaults, expected): + # tests / documents shared utility fn + # don't hesitate to delete this upon refactor 
+ assert build_args_and_assignment_strs(fields, defaults) == expected