From ba9e06827ce29caacd12ad70ad3954b5825c8a75 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 21 Jun 2023 16:14:51 -0700 Subject: [PATCH] don't treat dataclasses with decoding fns as dataclasses, fixes #266 --- .../helpers/serialization/decoding.py | 5 ++ simple_parsing/wrappers/dataclass_wrapper.py | 15 ++++-- simple_parsing/wrappers/field_wrapper.py | 3 +- test/test_issue_266.py | 53 +++++++++++++++++++ 4 files changed, 70 insertions(+), 6 deletions(-) create mode 100644 test/test_issue_266.py diff --git a/simple_parsing/helpers/serialization/decoding.py b/simple_parsing/helpers/serialization/decoding.py index 2d3aab73..734a05eb 100644 --- a/simple_parsing/helpers/serialization/decoding.py +++ b/simple_parsing/helpers/serialization/decoding.py @@ -77,6 +77,11 @@ def _wrapper(fn: C) -> C: return _wrapper +def has_custom_decode_fn(some_type: type) -> bool: + """Returns True if the given type has a custom decoding function registered.""" + return some_type in _decoding_fns + + @decoding_fn_for_type(int) def _decode_int(v: str) -> int: int_v = int(v) diff --git a/simple_parsing/wrappers/dataclass_wrapper.py b/simple_parsing/wrappers/dataclass_wrapper.py index 2c00c98d..a9f9fe89 100644 --- a/simple_parsing/wrappers/dataclass_wrapper.py +++ b/simple_parsing/wrappers/dataclass_wrapper.py @@ -14,6 +14,7 @@ from typing_extensions import Literal from .. import docstring, utils +from ..helpers.serialization import has_custom_decode_fn from ..utils import Dataclass, DataclassT, is_dataclass_instance, is_dataclass_type from .field_wrapper import FieldWrapper from .wrapper import Wrapper @@ -133,7 +134,9 @@ def __init__( self.fields.append(field_wrapper) - elif dataclasses.is_dataclass(field_type) and field.default is not None: + elif dataclasses.is_dataclass(field_type) and\ + field.default is not None and\ + not has_custom_decode_fn(field_type): # Non-optional dataclass field. # handle a nested dataclass attribute dataclass, name = field_type, field.name @@ -150,9 +153,11 @@ def __init__( ) self._children.append(child_wrapper) - elif utils.contains_dataclass_type_arg(field_type): - # Extract the dataclass type from the annotation of the field. - field_dataclass = utils.get_dataclass_type_arg(field_type) + # See if it's a generic with a dataclass arg + # if so, extract the dataclass type from the annotation of the field. + elif (field_dataclass := utils.get_dataclass_type_arg(field_type)) is not None and\ + field.default is not None and\ + not has_custom_decode_fn(field_dataclass): # todo: Figure out if this is still necessary, or if `field_default` can be handled # the same way as above. if field_default is dataclasses.MISSING: @@ -288,7 +293,7 @@ def default(self) -> DataclassT | None: def set_default(self, value: DataclassT | dict | None): """Sets the default values for the arguments of the fields of this dataclass.""" - if value is not None and not isinstance(value, dict): + if value is not None and dataclasses.is_dataclass(value): field_default_values = dataclasses.asdict(value) else: field_default_values = value diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 0becac2b..0b6f0daa 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -15,6 +15,7 @@ from .. import docstring, utils from ..helpers.custom_actions import BooleanOptionalAction +from ..helpers.serialization import has_custom_decode_fn from ..utils import Dataclass from .field_metavar import get_metavar from .field_parsing import get_parsing_fn @@ -517,7 +518,7 @@ def postprocess(self, raw_parsed_value: Any) -> Any: # TODO: Make sure that this doesn't cause issues with NamedTuple types. return tuple(raw_parsed_value) - elif self.type not in utils.builtin_types: + elif self.type not in utils.builtin_types and not has_custom_decode_fn(self.type): # TODO: what if we actually got an auto-generated parsing function? try: # if the field has a weird type, we try to call it directly. diff --git a/test/test_issue_266.py b/test/test_issue_266.py new file mode 100644 index 00000000..7c0d56db --- /dev/null +++ b/test/test_issue_266.py @@ -0,0 +1,53 @@ +import dataclasses +import tempfile + +import simple_parsing +from simple_parsing.helpers.serialization import register_decoding_fn, get_decoding_fn + + +@dataclasses.dataclass +class Id: + value: str + + def __post_init__(self): + assert isinstance(self.value, str) + + +class NonDCId: + def __init__(self, value: str): + assert isinstance(value, str) + self.value = value + + def __eq__(self, other): + return self.value == other.value + + +register_decoding_fn(Id, (lambda x, drop_extra_fields: Id(value=x))) +register_decoding_fn(NonDCId, (lambda x: NonDCId(value=x))) + + +@dataclasses.dataclass +class Person: + name: str + other_id: NonDCId + id: Id + + +def test_parse_helper_uses_custom_decoding_fn(): + config_str = """ + name: bob + id: hi + other_id: hello + """ + + # ok + assert get_decoding_fn(Person)({"name": "bob", "id": "hi", "other_id": "hello"}) == Person( + "bob", NonDCId("hello"), Id("hi") + ) # type: ignore + + with tempfile.NamedTemporaryFile("w", suffix=".yaml") as f: + f.write(config_str) + f.flush() + + parsed = simple_parsing.parse(Person, f.name, args=[]) + assert parsed == Person("bob", NonDCId("hello"), Id("hi")) # type: ignore