Skip to content

Commit

Permalink
[dagster_model] support default values (#22374)
Browse files Browse the repository at this point in the history
* support empty models
* support setting default values in the model declaration
  * ensure things like [] and {} are not shared

## How I Tested These Changes

added tests
  • Loading branch information
alangenfeld authored Jun 11, 2024
1 parent 7bf3597 commit b478c5f
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 23 deletions.
25 changes: 21 additions & 4 deletions python_modules/dagster/dagster/_check/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1849,13 +1849,30 @@ class EvalContext(NamedTuple):
local_ns: dict

@staticmethod
def capture_from_frame(
    depth: int,
    *,
    add_to_local_ns: Optional[Mapping[str, Any]] = None,
) -> "EvalContext":
    """Capture the global and local namespaces via the stack frame.

    Args:
        depth: which stack frame to reference, with depth 0 being the callsite.
        add_to_local_ns: A mapping of additional values to update the local namespace with.

    Returns:
        An EvalContext holding copies of the selected frame's globals and locals,
        with ``add_to_local_ns`` (if given) merged into the locals copy.
    """
    # +1 so that depth 0 refers to the caller rather than this function's own frame
    ctx_frame = sys._getframe(depth + 1)  # noqa # surprisingly not costly

    # copy to not mess up frame data
    global_ns = ctx_frame.f_globals.copy()
    local_ns = ctx_frame.f_locals.copy()

    # extra names are applied to the copy only, never to the live frame
    if add_to_local_ns:
        local_ns.update(add_to_local_ns)

    return EvalContext(
        global_ns=global_ns,
        local_ns=local_ns,
    )

def update_from_frame(self, depth: int):
Expand Down
124 changes: 108 additions & 16 deletions python_modules/dagster/dagster/_model/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Mapping,
NamedTuple,
Optional,
Tuple,
Type,
TypeVar,
Union,
Expand All @@ -26,7 +27,9 @@
# Key placed in the namespace of every generated model type (paired with _MODEL_MARKER_VALUE).
_MODEL_MARKER_FIELD = (
"__checkrepublic__" # "I do want to release this as checkrepublic one day" - schrockn
)
# Earlier name for the checked __new__ marker; identical value to _CHECKED_NEW.
_GENERATED_NEW = "__checked_new__"
# __name__ carried by JitCheckedNew and by the runtime-checked __new__ it compiles.
_CHECKED_NEW = "__checked_new__"
# Name of the generated __new__ that only applies default values (used when checked=False).
_DEFAULTS_NEW = "__defaults_new__"
# Name under which the defaults mapping is injected into the namespace used to compile
# a generated __new__, so default-argument expressions in the generated source can look it up.
_INJECTED_DEFAULT_VALS_LOCAL_VAR = "__dm_defaults__"


def _namedtuple_model_transform(
Expand All @@ -41,21 +44,48 @@ def _namedtuple_model_transform(
* bans tuple methods that don't make sense for a model object
* creates a run time checked __new__ (optional).
"""
base = NamedTuple(f"_{cls.__name__}", cls.__annotations__.items())
field_set = getattr(cls, "__annotations__", {})
defaults = {name: getattr(cls, name) for name in field_set.keys() if hasattr(cls, name)}

base = NamedTuple(f"_{cls.__name__}", field_set.items())
nt_new = base.__new__
if checked:
eval_ctx = EvalContext.capture_from_frame(
1 + decorator_frames,
# inject default values into the local namespace for reference in generated __new__
add_to_local_ns={_INJECTED_DEFAULT_VALS_LOCAL_VAR: defaults},
)
jit_checked_new = JitCheckedNew(
cls.__annotations__,
EvalContext.capture_from_frame(1 + decorator_frames),
1 if with_new else 0,
field_set,
defaults,
base,
eval_ctx,
1 if with_new else 0,
)
base.__new__ = jit_checked_new # type: ignore

elif defaults:
# allow arbitrary ordering of default values by generating a kwarg only __new__ impl
eval_ctx = EvalContext(
global_ns={},
# inject default values into the local namespace for reference in generated __new__
local_ns={_INJECTED_DEFAULT_VALS_LOCAL_VAR: defaults},
)
defaults_new = eval_ctx.compile_fn(
_build_defaults_new(field_set, defaults),
_DEFAULTS_NEW,
)
base.__new__ = defaults_new

if with_new and cls.__new__ is object.__new__:
# verify the alignment since it impacts frame capture
check.failed(f"Expected __new__ on {cls}, add it or switch from the _with_new decorator.")

# clear default values
for name in field_set.keys():
if hasattr(cls, name):
delattr(cls, name)

new_type = type(
cls.__name__,
(cls, base),
Expand All @@ -66,6 +96,7 @@ def _namedtuple_model_transform(
_MODEL_MARKER_FIELD: _MODEL_MARKER_VALUE,
"__annotations__": cls.__annotations__,
"__nt_new__": nt_new,
"__bool__": _true,
},
)

Expand Down Expand Up @@ -189,7 +220,7 @@ def is_dagster_model(obj) -> bool:


def has_generated_new(obj) -> bool:
    """Return True if ``obj.__new__`` is one of the generated implementations.

    Both generated flavors are recognized by the ``__name__`` of ``__new__``:
    the runtime-checked one (``_CHECKED_NEW``) and the defaults-only one
    (``_DEFAULTS_NEW``).
    """
    return obj.__new__.__name__ in (_DEFAULTS_NEW, _CHECKED_NEW)


def as_dict(obj) -> Mapping[str, Any]:
Expand Down Expand Up @@ -228,25 +259,27 @@ def _asdict(self):


class JitCheckedNew:
"""Object that allows us to just-in-time compile a __checked_new__ implementation on first use.
"""Object that allows us to just-in-time compile a checked __new__ implementation on first use.
This has two benefits:
1. Defer processing ForwardRefs until their definitions are in scope.
2. Avoid up-front cost for unused objects.
"""

__name__ = _GENERATED_NEW
__name__ = _CHECKED_NEW

def __init__(
    self,
    field_set: Mapping[str, Type],
    defaults: Mapping[str, Any],
    nt_base: Type,
    eval_ctx: EvalContext,
    new_frames: int,
):
    """Store everything needed to compile the checked ``__new__`` on first call.

    Args:
        field_set: Mapping of field name to its annotation.
        defaults: Mapping of field name to default value, for fields that declare one.
        nt_base: The NamedTuple base type whose ``__new__`` gets replaced by the
            compiled implementation.
        eval_ctx: Evaluation context used to compile the generated source.
        new_frames: How many frames of ``__new__`` there are.
    """
    self._field_set = field_set
    self._defaults = defaults
    self._nt_base = nt_base
    self._eval_ctx = eval_ctx
    self._new_frames = new_frames  # how many frames of __new__ there are

def __call__(self, cls, **kwargs):
# update the context with callsite locals/globals to resolve
Expand All @@ -260,13 +293,14 @@ def __call__(self, cls, **kwargs):
# jit that shit
self._nt_base.__new__ = self._eval_ctx.compile_fn(
self._build_checked_new_str(),
_GENERATED_NEW,
_CHECKED_NEW,
)

return self._nt_base.__new__(cls, **kwargs)

def _build_checked_new_str(self) -> str:
kw_args = ", ".join(self._field_set.keys())
kw_args_str, set_calls_str = build_args_and_assignment_strs(self._field_set, self._defaults)

check_calls = []
for name, ttype in self._field_set.items():
call_str = build_check_call_str(
Expand All @@ -276,18 +310,76 @@ def _build_checked_new_str(self) -> str:
)
check_calls.append(f"{name}={call_str}")

check_call_block = " ,\n".join(check_calls)
check_call_block = ",\n ".join(check_calls)
return f"""
def __checked_new__(cls, {kw_args}):
def __checked_new__(cls{kw_args_str}):
{set_calls_str}
return cls.__nt_new__(
cls,{check_call_block}
cls,
{check_call_block}
)
"""


def _build_defaults_new(field_set: Mapping[str, Type], defaults: Mapping[str, Any]) -> str:
    """Build a __new__ implementation that handles default values.

    Args:
        field_set: Mapping of field name to its annotation; every field becomes a
            keyword-only parameter of the generated function.
        defaults: Mapping of field name to default value for fields that have one.

    Returns:
        Source code for a ``__defaults_new__(cls, *, ...)`` function that applies
        defaults and forwards every field by keyword to ``cls.__nt_new__``.
    """
    kw_args_str, set_calls_str = build_args_and_assignment_strs(field_set, defaults)

    # each assignment is a single-line statement; re-indent them to the body indent of
    # the generated function so the emitted source is valid regardless of incoming spacing
    set_calls_block = "\n    ".join(stmt.strip() for stmt in set_calls_str.splitlines())

    # one `name=name` per line, aligned with the arguments of the __nt_new__ call below
    assign_str = ",\n        ".join(f"{name}={name}" for name in field_set)

    return f"""
def __defaults_new__(cls{kw_args_str}):
    {set_calls_block}
    return cls.__nt_new__(
        cls,
        {assign_str}
    )
"""


def build_args_and_assignment_strs(
    field_set: Mapping[str, Type],
    defaults: Mapping[str, Any],
) -> Tuple[str, str]:
    """Utility function shared between _defaults_new and _checked_new to create the arguments to
    the function as well as any assignment calls that need to happen.

    Args:
        field_set: Mapping of field name to its annotation. Iteration order determines
            the order of the generated keyword-only parameters.
        defaults: Mapping of field name to default value for fields that have one.

    Returns:
        A ``(kw_args_str, set_calls_str)`` tuple. ``kw_args_str`` is either empty or
        ``", *, <params>"`` ready to be appended after ``cls`` in a signature.
        ``set_calls_str`` holds the assignment statements (joined by a newline plus a
        four-space indent) that materialize fresh empty containers inside the function body.
    """
    kw_args = []
    set_calls = []
    for arg in field_set:
        if arg not in defaults:
            kw_args.append(arg)
            continue

        default = defaults[arg]
        if default is None:
            kw_args.append(f"{arg} = None")
        # don't share class instance of default empty containers; use isinstance rather
        # than `==` so defaults with custom or array-like __eq__ are never compared
        elif isinstance(default, list) and not default:
            kw_args.append(f"{arg} = None")
            set_calls.append(f"{arg} = {arg} if {arg} is not None else []")
        elif isinstance(default, dict) and not default:
            kw_args.append(f"{arg} = None")
            set_calls.append(f"{arg} = {arg} if {arg} is not None else {{}}")
        # fallback to direct reference if unknown
        else:
            kw_args.append(f"{arg} = {_INJECTED_DEFAULT_VALS_LOCAL_VAR}['{arg}']")

    # everything is keyword-only, which is what allows defaults in any declaration order
    kw_args_str = f", *, {', '.join(kw_args)}" if kw_args else ""

    # separator matches the body indent of the generated function
    set_calls_str = "\n    ".join(set_calls)

    return kw_args_str, set_calls_str


def _banned_iter(*args, **kwargs):
raise Exception("Iteration is not allowed on `@dagster_model`s.")


def _banned_idx(*args, **kwargs):
raise Exception("Index access is not allowed on `@dagster_model`s.")


def _true(_):
return True
11 changes: 11 additions & 0 deletions python_modules/dagster/dagster_tests/general_tests/test_serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,10 +907,21 @@ class MyModel:
m_str = serialize_value(m, whitelist_map=test_env)
assert m == deserialize_value(m_str, whitelist_map=test_env)

@_whitelist_for_serdes(test_env)
@dagster_model(checked=False)
class UncheckedModel:
nums: List[int]
optional: int = 130

m = UncheckedModel(nums=[1, 2, 3])
m_str = serialize_value(m, whitelist_map=test_env)
assert m == deserialize_value(m_str, whitelist_map=test_env)

@_whitelist_for_serdes(test_env)
@dagster_model
class CachedModel:
nums: List[int]
optional: int = 42

@cached_method
def map(self) -> dict:
Expand Down
Loading

0 comments on commit b478c5f

Please sign in to comment.