-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d587246
commit 8bdc2ce
Showing
4 changed files
with
203 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import sys | ||
from inspect import isfunction | ||
from typing import Any, Callable, TypeVar | ||
|
||
from beartype import beartype | ||
from beartype.roar import BeartypeCallHintParamViolation | ||
|
||
from cognite.client import global_config | ||
from cognite.client.exceptions import CogniteTypeError | ||
|
||
T_Callable = TypeVar("T_Callable", bound=Callable) | ||
T_Class = TypeVar("T_Class", bound=type) | ||
|
||
|
||
def runtime_type_checked_method(f: T_Callable) -> T_Callable: | ||
if (sys.version_info < (3, 10)) or not global_config.enable_runtime_type_checking: | ||
return f | ||
beartyped_f = beartype(f) | ||
|
||
def f_wrapped(*args: Any, **kwargs: Any) -> Any: | ||
try: | ||
return beartyped_f(*args, **kwargs) | ||
except BeartypeCallHintParamViolation as e: | ||
raise CogniteTypeError(e.args[0]) | ||
|
||
return f_wrapped | ||
|
||
|
||
def runtime_type_checked(c: T_Class) -> T_Class: | ||
for name in dir(c): | ||
if not name.startswith("_") or name == "__init__" and isfunction(getattr(c, name)): | ||
setattr(c, name, runtime_type_checked_method(getattr(c, name))) | ||
return c |
161 changes: 161 additions & 0 deletions
161
tests/tests_unit/test_utils/test_runtime_type_checking.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
from __future__ import annotations | ||
|
||
import re | ||
import sys | ||
from dataclasses import dataclass | ||
from typing import overload | ||
|
||
import pytest | ||
|
||
from cognite.client import global_config | ||
from cognite.client.exceptions import CogniteTypeError | ||
from cognite.client.utils._runtime_type_checking import runtime_type_checked | ||
|
||
pytestmark = [pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10")] | ||
|
||
|
||
global_config.enable_runtime_type_checking = True | ||
|
||
|
||
class Foo: ... | ||
|
||
|
||
class TestTypes: | ||
@runtime_type_checked | ||
class Types: | ||
def primitive(self, x: int) -> None: ... | ||
|
||
def list(self, x: list[str]) -> None: ... | ||
|
||
def custom_class(self, x: Foo) -> None: ... | ||
|
||
def test_primitive(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
self.Types().primitive("1") | ||
|
||
self.Types().primitive(1) | ||
|
||
def test_list(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' " | ||
"violates type hint list[str], as str '1' not instance of list." | ||
), | ||
): | ||
self.Types().list("1") | ||
|
||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] " | ||
"violates type hint list[str], as list index 0 item int 1 not instance of str." | ||
), | ||
): | ||
self.Types().list([1]) | ||
|
||
self.Types().list(["ok"]) | ||
|
||
def test_custom_type(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() " | ||
"parameter x='1' violates type hint " | ||
"<class 'tests.tests_unit.test_utils.test_runtime_type_checking.Foo'>, as str '1' not instance " | ||
'of <class "tests.tests_unit.test_utils.test_runtime_type_checking.Foo">' | ||
), | ||
): | ||
self.Types().custom_class("1") | ||
|
||
self.Types().custom_class(Foo()) | ||
|
||
@runtime_type_checked | ||
class ClassWithConstructor: | ||
def __init__(self, x: int, y: str) -> None: | ||
self.x = x | ||
self.y = y | ||
|
||
def test_constructor_for_class(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
self.ClassWithConstructor("1", "1") | ||
|
||
def test_constructor_for_subclass(self) -> None: | ||
class SubDataClassWithConstructor(self.ClassWithConstructor): | ||
pass | ||
|
||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
SubDataClassWithConstructor("1", "1") | ||
|
||
@runtime_type_checked | ||
@dataclass | ||
class DataClassWithConstructor: | ||
x: int | ||
y: int | ||
|
||
def test_constructor_for_dataclass(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
self.DataClassWithConstructor("1", "1") | ||
|
||
def test_constructor_for_dataclass_subclass(self) -> None: | ||
class SubDataClassWithConstructor(self.DataClassWithConstructor): | ||
pass | ||
|
||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
SubDataClassWithConstructor("1", "1") | ||
|
||
|
||
class TestOverloads: | ||
@runtime_type_checked | ||
class WithOverload: | ||
@overload | ||
def foo(self, x: int, y: int) -> str: ... | ||
|
||
@overload | ||
def foo(self, x: str, y: str) -> str: ... | ||
|
||
def foo(self, x: int | str, y: int | str) -> str: | ||
return f"{x}{y}" | ||
|
||
def test_overloads(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() " | ||
"parameter y=1.0 violates type hint int | str, as float 1.0 not str or int." | ||
), | ||
): | ||
self.WithOverload().foo(1, 1.0) | ||
|
||
# Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet | ||
self.WithOverload().foo(1, "1") |