Skip to content

Commit

Permalink
[record] kw_only (#26600)
Browse files Browse the repository at this point in the history
add native support for posargs records to help with cases like
migrating public APIs

## How I Tested These Changes

added tests
ran the Buildkite (bk) suite across all supported Python versions
  • Loading branch information
alangenfeld authored Dec 19, 2024
1 parent f4ca960 commit 15bf894
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 75 deletions.
36 changes: 7 additions & 29 deletions python_modules/dagster/dagster/_core/definitions/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def repository_selector(self) -> "RepositorySelector":

@whitelist_for_serdes
@record
class RepositorySelector(IHaveNew):
class RepositorySelector:
location_name: str
repository_name: str

Expand Down Expand Up @@ -167,17 +167,10 @@ def from_graphql_input(graphql_data):
)


@record_custom
class CodeLocationSelector(IHaveNew):
@record(kw_only=False)
class CodeLocationSelector:
location_name: str

# allow posargs to avoid breaking change
def __new__(cls, location_name: str):
return super().__new__(
cls,
location_name=location_name,
)

def to_repository_selector(self) -> RepositorySelector:
return RepositorySelector(
location_name=self.location_name,
Expand Down Expand Up @@ -318,21 +311,13 @@ def to_graphql_input(self):
}


@record_custom
class PartitionRangeSelector(IHaveNew):
@record(kw_only=False)
class PartitionRangeSelector:
"""The information needed to resolve a partition range."""

start: str
end: str

# allow posargs
def __new__(cls, start: str, end: str):
return super().__new__(
cls,
start=start,
end=end,
)

def to_graphql_input(self):
return {
"start": self.start,
Expand All @@ -347,19 +332,12 @@ def from_graphql_input(graphql_data):
)


@record_custom
class PartitionsSelector(IHaveNew):
@record(kw_only=False)
class PartitionsSelector:
"""The information needed to define selection partitions."""

ranges: Sequence[PartitionRangeSelector]

# allow posargs
def __new__(cls, ranges: Sequence[PartitionRangeSelector]):
return super().__new__(
cls,
ranges=ranges,
)

def to_graphql_input(self):
return {"ranges": [partition_range.to_graphql_input() for partition_range in self.ranges]}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
RegisteredCodeLocationOrigin,
RemoteRepositoryOrigin,
)
from dagster._record import IHaveNew, record, record_custom
from dagster._record import record
from dagster._serdes.serdes import whitelist_for_serdes

if TYPE_CHECKING:
Expand Down Expand Up @@ -84,19 +84,11 @@ def for_test(
)


@record_custom
class JobHandle(IHaveNew):
@record(kw_only=False)
class JobHandle:
job_name: str
repository_handle: RepositoryHandle

# allow posargs
def __new__(cls, job_name: str, repository_handle: RepositoryHandle):
return super().__new__(
cls,
job_name=job_name,
repository_handle=repository_handle,
)

def to_string(self):
return f"{self.location_name}.{self.repository_name}.{self.job_name}"

Expand All @@ -123,19 +115,11 @@ def to_selector(self) -> JobSubsetSelector:
)


@record_custom
class InstigatorHandle(IHaveNew):
@record(kw_only=False)
class InstigatorHandle:
instigator_name: str
repository_handle: RepositoryHandle

# allow posargs
def __new__(cls, instigator_name: str, repository_handle: RepositoryHandle):
return super().__new__(
cls,
instigator_name=instigator_name,
repository_handle=repository_handle,
)

@property
def repository_name(self) -> str:
return self.repository_handle.repository_name
Expand Down
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/_core/snap/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def get_output_snap(self, name: str) -> OutputDefSnap:
},
)
@record
class NodeDefsSnapshot(IHaveNew):
class NodeDefsSnapshot:
op_def_snaps: Sequence[OpDefSnap]
graph_def_snaps: Sequence[GraphDefSnap]

Expand Down
79 changes: 56 additions & 23 deletions python_modules/dagster/dagster/_record/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_NAMED_TUPLE_BASE_NEW_FIELD = "__nt_new__"
_REMAPPING_FIELD = "__field_remap__"
_ORIGINAL_CLASS_FIELD = "__original_class__"
_KW_ONLY_FIELD = "__kw_only__"


_sample_nt = namedtuple("_canary", "x")
Expand All @@ -44,10 +45,12 @@

def _get_field_set_and_defaults(
cls: Type,
kw_only: bool,
) -> Tuple[Mapping[str, Any], Mapping[str, Any]]:
field_set = getattr(cls, "__annotations__", {})
defaults = {}

last_defaulted_field = None
for name in field_set.keys():
if hasattr(cls, name):
attr_val = getattr(cls, name)
Expand All @@ -57,23 +60,35 @@ def _get_field_set_and_defaults(
f"Conflicting non-abstract @property for field {name} on record {cls.__name__}."
"Add the the @abstractmethod decorator to make it abstract.",
)
elif isinstance(attr_val, _tuple_getter_type):
# When doing record inheritance, filter out tuplegetters from parents.
# This workaround only seems needed for py3.8
continue
else:
# When doing record inheritance, filter out tuplegetters from parents.
# This workaround only seems needed for py3.9
elif not isinstance(attr_val, _tuple_getter_type):
check.invariant(
not inspect.isfunction(attr_val),
f"Conflicting function for field {name} on record {cls.__name__}. "
"If you are trying to set a function as a default value "
"you will have to override __new__.",
)
defaults[name] = attr_val
last_defaulted_field = name
continue

# fall through here means no default set
if last_defaulted_field and not kw_only:
check.failed(
"Fields without defaults cannot appear after fields with default values. "
f"Field {name} has no default after {last_defaulted_field} with default value."
)

for base in cls.__bases__:
if is_record(base):
original_base = getattr(base, _ORIGINAL_CLASS_FIELD)
base_field_set, base_defaults = _get_field_set_and_defaults(original_base)
base_kw_only = getattr(base, _KW_ONLY_FIELD)
check.invariant(
kw_only == base_kw_only,
"Can not inherit from a parent @record with different kw_only setting.",
)
base_field_set, base_defaults = _get_field_set_and_defaults(original_base, kw_only)
field_set = {**base_field_set, **field_set}
defaults = {**base_defaults, **defaults}

Expand All @@ -87,13 +102,14 @@ def _namedtuple_record_transform(
with_new: bool,
decorator_frames: int,
field_to_new_mapping: Optional[Mapping[str, str]],
kw_only: bool,
) -> TType:
"""Transforms the input class in to one that inherits a generated NamedTuple base class
and:
* bans tuple methods that don't make sense for a record object
* creates a run time checked __new__ (optional).
"""
field_set, defaults = _get_field_set_and_defaults(cls)
field_set, defaults = _get_field_set_and_defaults(cls, kw_only)

base = NamedTuple(f"_{cls.__name__}", field_set.items())
nt_new = base.__new__
Expand All @@ -109,7 +125,8 @@ def _namedtuple_record_transform(
field_set,
defaults,
eval_ctx,
1 if with_new else 0,
new_frames=1 if with_new else 0,
kw_only=kw_only,
)
elif defaults:
# allow arbitrary ordering of default values by generating a kwarg only __new__ impl
Expand All @@ -120,7 +137,7 @@ def _namedtuple_record_transform(
lazy_imports={},
)
generated_new = eval_ctx.compile_fn(
_build_defaults_new(field_set, defaults),
_build_defaults_new(field_set, defaults, kw_only),
_DEFAULTS_NEW,
)

Expand All @@ -145,6 +162,7 @@ def _namedtuple_record_transform(
_NAMED_TUPLE_BASE_NEW_FIELD: nt_new,
_REMAPPING_FIELD: field_to_new_mapping or {},
_ORIGINAL_CLASS_FIELD: cls,
_KW_ONLY_FIELD: kw_only,
"__reduce__": _reduce,
# functools doesn't work, so manually update_wrapper
"__module__": cls.__module__,
Expand Down Expand Up @@ -219,6 +237,7 @@ def record(
def record(
*,
checked: bool = True,
kw_only: bool = True,
) -> Callable[[TType], TType]: ... # Overload for using decorator used with args.


Expand All @@ -230,11 +249,13 @@ def record(
cls: Optional[TType] = None,
*,
checked: bool = True,
kw_only: bool = True,
) -> Union[TType, Callable[[TType], TType]]:
"""A class decorator that will create an immutable record class based on the defined fields.
Args:
checked: Whether or not to generate runtime type checked construction.
checked: Whether or not to generate runtime type checked construction (default True).
kw_only: Whether or not the generated __new__ is kwargs only (default True).
"""
if cls:
return _namedtuple_record_transform(
Expand All @@ -243,6 +264,7 @@ def record(
with_new=False,
decorator_frames=1,
field_to_new_mapping=None,
kw_only=kw_only,
)
else:
return partial(
Expand All @@ -251,6 +273,7 @@ def record(
with_new=False,
decorator_frames=0,
field_to_new_mapping=None,
kw_only=kw_only,
)


Expand Down Expand Up @@ -303,6 +326,7 @@ def __new__(cls, name: Optional[str] = None)
with_new=True,
decorator_frames=1,
field_to_new_mapping=field_to_new_mapping,
kw_only=True,
)
else:
return partial(
Expand All @@ -311,6 +335,7 @@ def __new__(cls, name: Optional[str] = None)
with_new=True,
decorator_frames=0,
field_to_new_mapping=field_to_new_mapping,
kw_only=True,
)


Expand Down Expand Up @@ -429,12 +454,14 @@ def __init__(
defaults: Mapping[str, Any],
eval_ctx: EvalContext,
new_frames: int,
kw_only: bool,
):
self._field_set = field_set
self._defaults = defaults
self._eval_ctx = eval_ctx
self._new_frames = new_frames # how many frames of __new__ there are
self._compiled = False
self._kw_only = kw_only

def __call__(self, cls, *args, **kwargs):
if _do_defensive_checks():
Expand Down Expand Up @@ -470,7 +497,11 @@ def __call__(self, cls, *args, **kwargs):
return compiled_fn(cls, *args, **kwargs)

def _build_checked_new_str(self) -> str:
kw_args_str, set_calls_str = build_args_and_assignment_strs(self._field_set, self._defaults)
args_str, set_calls_str = build_args_and_assignment_strs(
self._field_set,
self._defaults,
self._kw_only,
)
check_calls = []
for name, ttype in self._field_set.items():
call_str = build_check_call_str(
Expand All @@ -487,7 +518,7 @@ def _build_checked_new_str(self) -> str:
)

checked_new_str = f"""
def __checked_new__(cls{kw_args_str}):
def __checked_new__(cls{args_str}):
{lazy_imports_str}
{set_calls_str}
return cls.{_NAMED_TUPLE_BASE_NEW_FIELD}(
Expand All @@ -501,9 +532,10 @@ def __checked_new__(cls{kw_args_str}):
def _build_defaults_new(
field_set: Mapping[str, Type],
defaults: Mapping[str, Any],
kw_only: bool,
) -> str:
"""Build a __new__ implementation that handles default values."""
kw_args_str, set_calls_str = build_args_and_assignment_strs(field_set, defaults)
kw_args_str, set_calls_str = build_args_and_assignment_strs(field_set, defaults, kw_only)
assign_str = ",\n ".join([f"{name}={name}" for name in field_set.keys()])
return f"""
def __defaults_new__(cls{kw_args_str}):
Expand All @@ -518,39 +550,40 @@ def __defaults_new__(cls{kw_args_str}):
def build_args_and_assignment_strs(
field_set: Mapping[str, Type],
defaults: Mapping[str, Any],
kw_only: bool,
) -> Tuple[str, str]:
"""Utility funciton shared between _defaults_new and _checked_new to create the arguments to
the function as well as any assignment calls that need to happen.
"""
kw_args = []
args = []
set_calls = []
for arg in field_set.keys():
if arg in defaults:
default = defaults[arg]
if default is None:
kw_args.append(f"{arg} = None")
args.append(f"{arg} = None")
# dont share class instance of default empty containers
elif default == []:
kw_args.append(f"{arg} = None")
args.append(f"{arg} = None")
set_calls.append(f"{arg} = {arg} if {arg} is not None else []")
elif default == {}:
kw_args.append(f"{arg} = None")
args.append(f"{arg} = None")
set_calls.append(f"{arg} = {arg} if {arg} is not None else {'{}'}")
# fallback to direct reference if unknown
else:
kw_args.append(f"{arg} = {_INJECTED_DEFAULT_VALS_LOCAL_VAR}['{arg}']")
args.append(f"{arg} = {_INJECTED_DEFAULT_VALS_LOCAL_VAR}['{arg}']")
else:
kw_args.append(arg)
args.append(arg)

kw_args_str = ""
if kw_args:
kw_args_str = f", *, {', '.join(kw_args)}"
args_str = ""
if args:
args_str = f", {'*,' if kw_only else ''} {', '.join(args)}"

set_calls_str = ""
if set_calls:
set_calls_str = "\n ".join(set_calls)

return kw_args_str, set_calls_str
return args_str, set_calls_str


def _banned_iter(*args, **kwargs):
Expand Down
Loading

0 comments on commit 15bf894

Please sign in to comment.