diff --git a/python_modules/dagster/dagster/_core/definitions/selector.py b/python_modules/dagster/dagster/_core/definitions/selector.py index 5bff31e3a57cd..7f3ce30d3b36c 100644 --- a/python_modules/dagster/dagster/_core/definitions/selector.py +++ b/python_modules/dagster/dagster/_core/definitions/selector.py @@ -130,7 +130,7 @@ def repository_selector(self) -> "RepositorySelector": @whitelist_for_serdes @record -class RepositorySelector(IHaveNew): +class RepositorySelector: location_name: str repository_name: str @@ -167,17 +167,10 @@ def from_graphql_input(graphql_data): ) -@record_custom -class CodeLocationSelector(IHaveNew): +@record(kw_only=False) +class CodeLocationSelector: location_name: str - # allow posargs to avoid breaking change - def __new__(cls, location_name: str): - return super().__new__( - cls, - location_name=location_name, - ) - def to_repository_selector(self) -> RepositorySelector: return RepositorySelector( location_name=self.location_name, @@ -318,21 +311,13 @@ def to_graphql_input(self): } -@record_custom -class PartitionRangeSelector(IHaveNew): +@record(kw_only=False) +class PartitionRangeSelector: """The information needed to resolve a partition range.""" start: str end: str - # allow posargs - def __new__(cls, start: str, end: str): - return super().__new__( - cls, - start=start, - end=end, - ) - def to_graphql_input(self): return { "start": self.start, @@ -347,19 +332,12 @@ def from_graphql_input(graphql_data): ) -@record_custom -class PartitionsSelector(IHaveNew): +@record(kw_only=False) +class PartitionsSelector: """The information needed to define selection partitions.""" ranges: Sequence[PartitionRangeSelector] - # allow posargs - def __new__(cls, ranges: Sequence[PartitionRangeSelector]): - return super().__new__( - cls, - ranges=ranges, - ) - def to_graphql_input(self): return {"ranges": [partition_range.to_graphql_input() for partition_range in self.ranges]} diff --git a/python_modules/dagster/dagster/_core/remote_representation/handle.py b/python_modules/dagster/dagster/_core/remote_representation/handle.py index cc59a21ef5bc2..4807570492c8b 100644 --- a/python_modules/dagster/dagster/_core/remote_representation/handle.py +++ b/python_modules/dagster/dagster/_core/remote_representation/handle.py @@ -11,7 +11,7 @@ RegisteredCodeLocationOrigin, RemoteRepositoryOrigin, ) -from dagster._record import IHaveNew, record, record_custom +from dagster._record import record from dagster._serdes.serdes import whitelist_for_serdes if TYPE_CHECKING: @@ -84,19 +84,11 @@ def for_test( ) -@record_custom -class JobHandle(IHaveNew): +@record(kw_only=False) +class JobHandle: job_name: str repository_handle: RepositoryHandle - # allow posargs - def __new__(cls, job_name: str, repository_handle: RepositoryHandle): - return super().__new__( - cls, - job_name=job_name, - repository_handle=repository_handle, - ) - def to_string(self): return f"{self.location_name}.{self.repository_name}.{self.job_name}" @@ -123,19 +115,11 @@ def to_selector(self) -> JobSubsetSelector: ) -@record_custom -class InstigatorHandle(IHaveNew): +@record(kw_only=False) +class InstigatorHandle: instigator_name: str repository_handle: RepositoryHandle - # allow posargs - def __new__(cls, instigator_name: str, repository_handle: RepositoryHandle): - return super().__new__( - cls, - instigator_name=instigator_name, - repository_handle=repository_handle, - ) - @property def repository_name(self) -> str: return self.repository_handle.repository_name diff --git a/python_modules/dagster/dagster/_core/snap/node.py b/python_modules/dagster/dagster/_core/snap/node.py index e56fe51fc32dd..2da0681b7a665 100644 --- a/python_modules/dagster/dagster/_core/snap/node.py +++ b/python_modules/dagster/dagster/_core/snap/node.py @@ -208,7 +208,7 @@ def get_output_snap(self, name: str) -> OutputDefSnap: }, ) @record -class NodeDefsSnapshot(IHaveNew): +class NodeDefsSnapshot: op_def_snaps: Sequence[OpDefSnap] graph_def_snaps: Sequence[GraphDefSnap] diff --git a/python_modules/dagster/dagster/_record/__init__.py b/python_modules/dagster/dagster/_record/__init__.py index 58e1c5c1e07ed..66d819823a76c 100644 --- a/python_modules/dagster/dagster/_record/__init__.py +++ b/python_modules/dagster/dagster/_record/__init__.py @@ -35,6 +35,7 @@ _NAMED_TUPLE_BASE_NEW_FIELD = "__nt_new__" _REMAPPING_FIELD = "__field_remap__" _ORIGINAL_CLASS_FIELD = "__original_class__" +_KW_ONLY_FIELD = "__kw_only__" _sample_nt = namedtuple("_canary", "x") @@ -44,10 +45,12 @@ def _get_field_set_and_defaults( cls: Type, + kw_only: bool, ) -> Tuple[Mapping[str, Any], Mapping[str, Any]]: field_set = getattr(cls, "__annotations__", {}) defaults = {} + last_defaulted_field = None for name in field_set.keys(): if hasattr(cls, name): attr_val = getattr(cls, name) @@ -57,11 +60,9 @@ def _get_field_set_and_defaults( f"Conflicting non-abstract @property for field {name} on record {cls.__name__}." "Add the the @abstractmethod decorator to make it abstract.", ) - elif isinstance(attr_val, _tuple_getter_type): - # When doing record inheritance, filter out tuplegetters from parents. - # This workaround only seems needed for py3.8 - continue - else: + # When doing record inheritance, filter out tuplegetters from parents. + # This workaround only seems needed for py3.9 + elif not isinstance(attr_val, _tuple_getter_type): check.invariant( not inspect.isfunction(attr_val), f"Conflicting function for field {name} on record {cls.__name__}. " @@ -69,11 +70,25 @@ def _get_field_set_and_defaults( "you will have to override __new__.", ) defaults[name] = attr_val + last_defaulted_field = name + continue + + # fall through here means no default set + if last_defaulted_field and not kw_only: + check.failed( + "Fields without defaults cannot appear after fields with default values. " + f"Field {name} has no default after {last_defaulted_field} with default value." + ) for base in cls.__bases__: if is_record(base): original_base = getattr(base, _ORIGINAL_CLASS_FIELD) - base_field_set, base_defaults = _get_field_set_and_defaults(original_base) + base_kw_only = getattr(base, _KW_ONLY_FIELD) + check.invariant( + kw_only == base_kw_only, + "Can not inherit from a parent @record with different kw_only setting.", + ) + base_field_set, base_defaults = _get_field_set_and_defaults(original_base, kw_only) field_set = {**base_field_set, **field_set} defaults = {**base_defaults, **defaults} @@ -87,13 +102,14 @@ def _namedtuple_record_transform( with_new: bool, decorator_frames: int, field_to_new_mapping: Optional[Mapping[str, str]], + kw_only: bool, ) -> TType: """Transforms the input class in to one that inherits a generated NamedTuple base class and: * bans tuple methods that don't make sense for a record object * creates a run time checked __new__ (optional). """ - field_set, defaults = _get_field_set_and_defaults(cls) + field_set, defaults = _get_field_set_and_defaults(cls, kw_only) base = NamedTuple(f"_{cls.__name__}", field_set.items()) nt_new = base.__new__ @@ -109,7 +125,8 @@ def _namedtuple_record_transform( field_set, defaults, eval_ctx, - 1 if with_new else 0, + new_frames=1 if with_new else 0, + kw_only=kw_only, ) elif defaults: # allow arbitrary ordering of default values by generating a kwarg only __new__ impl @@ -120,7 +137,7 @@ def _namedtuple_record_transform( lazy_imports={}, ) generated_new = eval_ctx.compile_fn( - _build_defaults_new(field_set, defaults), + _build_defaults_new(field_set, defaults, kw_only), _DEFAULTS_NEW, ) @@ -145,6 +162,7 @@ def _namedtuple_record_transform( _NAMED_TUPLE_BASE_NEW_FIELD: nt_new, _REMAPPING_FIELD: field_to_new_mapping or {}, _ORIGINAL_CLASS_FIELD: cls, + _KW_ONLY_FIELD: kw_only, "__reduce__": _reduce, # functools doesn't work, so manually update_wrapper "__module__": cls.__module__, @@ -219,6 +237,7 @@ def record( def record( *, checked: bool = True, + kw_only: bool = True, ) -> Callable[[TType], TType]: ... # Overload for using decorator used with args. @@ -230,11 +249,13 @@ def record( cls: Optional[TType] = None, *, checked: bool = True, + kw_only: bool = True, ) -> Union[TType, Callable[[TType], TType]]: """A class decorator that will create an immutable record class based on the defined fields. Args: - checked: Whether or not to generate runtime type checked construction. + checked: Whether or not to generate runtime type checked construction (default True). + kw_only: Whether or not the generated __new__ is kwargs only (default True). """ if cls: return _namedtuple_record_transform( @@ -243,6 +264,7 @@ def record( with_new=False, decorator_frames=1, field_to_new_mapping=None, + kw_only=kw_only, ) else: return partial( @@ -251,6 +273,7 @@ def record( with_new=False, decorator_frames=0, field_to_new_mapping=None, + kw_only=kw_only, ) @@ -303,6 +326,7 @@ def __new__(cls, name: Optional[str] = None) with_new=True, decorator_frames=1, field_to_new_mapping=field_to_new_mapping, + kw_only=True, ) else: return partial( @@ -311,6 +335,7 @@ def __new__(cls, name: Optional[str] = None) with_new=True, decorator_frames=0, field_to_new_mapping=field_to_new_mapping, + kw_only=True, ) @@ -429,12 +454,14 @@ def __init__( defaults: Mapping[str, Any], eval_ctx: EvalContext, new_frames: int, + kw_only: bool, ): self._field_set = field_set self._defaults = defaults self._eval_ctx = eval_ctx self._new_frames = new_frames # how many frames of __new__ there are self._compiled = False + self._kw_only = kw_only def __call__(self, cls, *args, **kwargs): if _do_defensive_checks(): @@ -470,7 +497,11 @@ def __call__(self, cls, *args, **kwargs): return compiled_fn(cls, *args, **kwargs) def _build_checked_new_str(self) -> str: - kw_args_str, set_calls_str = build_args_and_assignment_strs(self._field_set, self._defaults) + args_str, set_calls_str = build_args_and_assignment_strs( + self._field_set, + self._defaults, + self._kw_only, + ) check_calls = [] for name, ttype in self._field_set.items(): call_str = build_check_call_str( @@ -487,7 +518,7 @@ def _build_checked_new_str(self) -> str: ) checked_new_str = f""" -def __checked_new__(cls{kw_args_str}): +def __checked_new__(cls{args_str}): {lazy_imports_str} {set_calls_str} return cls.{_NAMED_TUPLE_BASE_NEW_FIELD}( @@ -501,9 +532,10 @@ def __checked_new__(cls{kw_args_str}): def _build_defaults_new( field_set: Mapping[str, Type], defaults: Mapping[str, Any], + kw_only: bool, ) -> str: """Build a __new__ implementation that handles default values.""" - kw_args_str, set_calls_str = build_args_and_assignment_strs(field_set, defaults) + kw_args_str, set_calls_str = build_args_and_assignment_strs(field_set, defaults, kw_only) assign_str = ",\n ".join([f"{name}={name}" for name in field_set.keys()]) return f""" def __defaults_new__(cls{kw_args_str}): @@ -518,39 +550,40 @@ def __defaults_new__(cls{kw_args_str}): def build_args_and_assignment_strs( field_set: Mapping[str, Type], defaults: Mapping[str, Any], + kw_only: bool, ) -> Tuple[str, str]: """Utility funciton shared between _defaults_new and _checked_new to create the arguments to the function as well as any assignment calls that need to happen. """ - kw_args = [] + args = [] set_calls = [] for arg in field_set.keys(): if arg in defaults: default = defaults[arg] if default is None: - kw_args.append(f"{arg} = None") + args.append(f"{arg} = None") # dont share class instance of default empty containers elif default == []: - kw_args.append(f"{arg} = None") + args.append(f"{arg} = None") set_calls.append(f"{arg} = {arg} if {arg} is not None else []") elif default == {}: - kw_args.append(f"{arg} = None") + args.append(f"{arg} = None") set_calls.append(f"{arg} = {arg} if {arg} is not None else {'{}'}") # fallback to direct reference if unknown else: - kw_args.append(f"{arg} = {_INJECTED_DEFAULT_VALS_LOCAL_VAR}['{arg}']") + args.append(f"{arg} = {_INJECTED_DEFAULT_VALS_LOCAL_VAR}['{arg}']") else: - kw_args.append(arg) + args.append(arg) - kw_args_str = "" - if kw_args: - kw_args_str = f", *, {', '.join(kw_args)}" + args_str = "" + if args: + args_str = f", {'*,' if kw_only else ''} {', '.join(args)}" set_calls_str = "" if set_calls: set_calls_str = "\n ".join(set_calls) - return kw_args_str, set_calls_str + return args_str, set_calls_str def _banned_iter(*args, **kwargs): diff --git a/python_modules/dagster/dagster_tests/general_tests/test_record.py b/python_modules/dagster/dagster_tests/general_tests/test_record.py index 08af55d1fc1dc..8d15cfa20ca4a 100644 --- a/python_modules/dagster/dagster_tests/general_tests/test_record.py +++ b/python_modules/dagster/dagster_tests/general_tests/test_record.py @@ -364,7 +364,7 @@ class OtherSample: def test_build_args_and_assign(fields, defaults, expected): # tests / documents shared utility fn # don't hesitate to delete this upon refactor - assert build_args_and_assignment_strs(fields, defaults) == expected + assert build_args_and_assignment_strs(fields, defaults, kw_only=True) == expected @record @@ -830,3 +830,58 @@ def boop(self): def test_defensive_checks_running(): # make sure we have enabled defensive checks in test, ideally as broadly as possible assert os.getenv("DAGSTER_RECORD_DEFENSIVE_CHECKS") == "true" + + +def test_allow_posargs(): + @record(kw_only=False) + class Foo: + a: int + + assert Foo(2) + + @record(kw_only=False) + class Bar: + a: int + b: int + c: int = 4 + + assert Bar(1, 2) + + with pytest.raises(CheckError): + + @record(kw_only=False) + class Baz: + a: int = 4 + b: int # type: ignore # good job type checker + + +def test_posargs_inherit(): + @record(kw_only=False) + class Parent: + name: str + + @record(kw_only=False) + class Child(Parent): + parent: Parent + + p = Parent("Alex") + assert p + c = Child("Lyra", p) + assert c + + # test kw_only not being aligned + with pytest.raises(CheckError): + + @record + class Bad(Parent): + other: str + + with pytest.raises(CheckError): + + @record + class A: + a: int + + @record(kw_only=False) + class B(A): + b: int