Skip to content

Commit

Permalink
Merge pull request #251 from specklesystems/gergo/unionTriples
Browse files Browse the repository at this point in the history
fix(type-validation): fix union types with more than 2 arguments
  • Loading branch information
gjedlicska authored Jan 10, 2023
2 parents a32822f + 3db8565 commit ac6ba87
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
11 changes: 6 additions & 5 deletions src/specklepy/objects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,14 @@ def _validate_type(t: Optional[type], value: Any) -> Tuple[bool, Any]:

# recursive validation for Unions on both types preferring the fist type
if origin is Union:
t_1, t_2 = t.__args__ # type: ignore
# below is what in nicer for >= py38
# t_1, t_2 = get_args(t)
t_1_success, t_1_value = _validate_type(t_1, value)
if t_1_success:
return True, t_1_value
return _validate_type(t_2, value)
args = t.__args__ # type: ignore
for arg_t in args:
t_success, t_value = _validate_type(arg_t, value)
if t_success:
return True, t_value
return False, value
if origin is dict:
if not isinstance(value, dict):
return False, value
Expand Down
22 changes: 21 additions & 1 deletion tests/unit/test_type_validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum, IntEnum
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import pytest

Expand All @@ -18,6 +18,16 @@ class FakeIntEnum(IntEnum):
one = 1


class FakeBase(Base):
foo: Optional[str]

def __init__(self, foo: str) -> None:
self.foo = foo


fake_bases = [FakeBase("foo"), FakeBase("bar")]


@pytest.mark.parametrize(
"input_type, value, is_valid, return_value",
[
Expand Down Expand Up @@ -78,6 +88,16 @@ class FakeIntEnum(IntEnum):
# given our current rules, this is the reality. Its just sad...
(Tuple[str, str, str], (1, "foo", "bar"), True, ("1", "foo", "bar")),
(Tuple[str, Optional[str], str], (1, None, "bar"), True, ("1", None, "bar")),
(Optional[Union[List[int], List[FakeBase]]], None, True, None),
(Optional[Union[List[int], List[FakeBase]]], "foo", False, "foo"),
(Union[List[int], List[FakeBase], None], "foo", False, "foo"),
(Optional[Union[List[int], List[FakeBase]]], [1, 2, 3], True, [1, 2, 3]),
(
Optional[Union[List[int], List[FakeBase]]],
fake_bases,
True,
fake_bases,
),
],
)
def test_validate_type(
Expand Down

0 comments on commit ac6ba87

Please sign in to comment.