Skip to content

don't treat dataclasses with decoding fns as dataclasses, fixes #266 #267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions simple_parsing/helpers/serialization/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def _wrapper(fn: C) -> C:
return _wrapper


def has_custom_decode_fn(some_type: type) -> bool:
"""Returns True if the given type has a custom decoding function registered."""
return some_type in _decoding_fns


@decoding_fn_for_type(int)
def _decode_int(v: str) -> int:
int_v = int(v)
Expand Down
15 changes: 10 additions & 5 deletions simple_parsing/wrappers/dataclass_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing_extensions import Literal

from .. import docstring, utils
from ..helpers.serialization import has_custom_decode_fn
from ..utils import Dataclass, DataclassT, is_dataclass_instance, is_dataclass_type
from .field_wrapper import FieldWrapper
from .wrapper import Wrapper
Expand Down Expand Up @@ -133,7 +134,9 @@ def __init__(

self.fields.append(field_wrapper)

elif dataclasses.is_dataclass(field_type) and field.default is not None:
elif dataclasses.is_dataclass(field_type) and\
field.default is not None and\
not has_custom_decode_fn(field_type):
# Non-optional dataclass field.
# handle a nested dataclass attribute
dataclass, name = field_type, field.name
Expand All @@ -150,9 +153,11 @@ def __init__(
)
self._children.append(child_wrapper)

elif utils.contains_dataclass_type_arg(field_type):
# Extract the dataclass type from the annotation of the field.
field_dataclass = utils.get_dataclass_type_arg(field_type)
# See if it's a generic with a dataclass arg
# if so, extract the dataclass type from the annotation of the field.
elif (field_dataclass := utils.get_dataclass_type_arg(field_type)) is not None and\
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't use the walrus operator unfortunately, we want python 3.7+ support.

field.default is not None and\
not has_custom_decode_fn(field_dataclass):
# todo: Figure out if this is still necessary, or if `field_default` can be handled
# the same way as above.
if field_default is dataclasses.MISSING:
Expand Down Expand Up @@ -288,7 +293,7 @@ def default(self) -> DataclassT | None:

def set_default(self, value: DataclassT | dict | None):
"""Sets the default values for the arguments of the fields of this dataclass."""
if value is not None and not isinstance(value, dict):
if value is not None and dataclasses.is_dataclass(value):
field_default_values = dataclasses.asdict(value)
else:
field_default_values = value
Expand Down
3 changes: 2 additions & 1 deletion simple_parsing/wrappers/field_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .. import docstring, utils
from ..helpers.custom_actions import BooleanOptionalAction
from ..helpers.serialization import has_custom_decode_fn
from ..utils import Dataclass
from .field_metavar import get_metavar
from .field_parsing import get_parsing_fn
Expand Down Expand Up @@ -517,7 +518,7 @@ def postprocess(self, raw_parsed_value: Any) -> Any:
# TODO: Make sure that this doesn't cause issues with NamedTuple types.
return tuple(raw_parsed_value)

elif self.type not in utils.builtin_types:
elif self.type not in utils.builtin_types and not has_custom_decode_fn(self.type):
# TODO: what if we actually got an auto-generated parsing function?
try:
# if the field has a weird type, we try to call it directly.
Expand Down
53 changes: 53 additions & 0 deletions test/test_issue_266.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import dataclasses
import tempfile

import simple_parsing
from simple_parsing.helpers.serialization import register_decoding_fn, get_decoding_fn


@dataclasses.dataclass
class Id:
value: str

def __post_init__(self):
assert isinstance(self.value, str)


class NonDCId:
def __init__(self, value: str):
assert isinstance(value, str)
self.value = value

def __eq__(self, other):
return self.value == other.value


register_decoding_fn(Id, (lambda x, drop_extra_fields: Id(value=x)))
register_decoding_fn(NonDCId, (lambda x: NonDCId(value=x)))


@dataclasses.dataclass
class Person:
name: str
other_id: NonDCId
id: Id


def test_parse_helper_uses_custom_decoding_fn():
config_str = """
name: bob
id: hi
other_id: hello
"""

# ok
assert get_decoding_fn(Person)({"name": "bob", "id": "hi", "other_id": "hello"}) == Person(
"bob", NonDCId("hello"), Id("hi")
) # type: ignore

with tempfile.NamedTemporaryFile("w", suffix=".yaml") as f:
f.write(config_str)
f.flush()

parsed = simple_parsing.parse(Person, f.name, args=[])
assert parsed == Person("bob", NonDCId("hello"), Id("hi")) # type: ignore