-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d587246
commit 853611d
Showing
3 changed files
with
199 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import sys | ||
from inspect import isfunction | ||
from typing import Any, Callable, TypeVar | ||
|
||
from beartype import beartype | ||
from beartype.roar import BeartypeCallHintParamViolation | ||
|
||
from cognite.client.exceptions import CogniteTypeError | ||
|
||
T_Callable = TypeVar("T_Callable", bound=Callable) | ||
T_Class = TypeVar("T_Class", bound=type) | ||
|
||
|
||
class Settings: | ||
enable_runtime_type_checking: bool = False | ||
|
||
|
||
def runtime_type_checked_method(f: T_Callable) -> T_Callable: | ||
if (sys.version_info < (3, 10)) or not Settings.enable_runtime_type_checking: | ||
return f | ||
beartyped_f = beartype(f) | ||
|
||
def f_wrapped(*args: Any, **kwargs: Any) -> Any: | ||
try: | ||
return beartyped_f(*args, **kwargs) | ||
except BeartypeCallHintParamViolation as e: | ||
raise CogniteTypeError(e.args[0]) | ||
|
||
return f_wrapped | ||
|
||
|
||
def runtime_type_checked(c: T_Class) -> T_Class: | ||
for name in dir(c): | ||
if not name.startswith("_") or name == "__init__" and isfunction(getattr(c, name)): | ||
setattr(c, name, runtime_type_checked_method(getattr(c, name))) | ||
return c |
160 changes: 160 additions & 0 deletions
160
tests/tests_unit/test_utils/test_runtime_type_checking.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
from __future__ import annotations | ||
|
||
import re | ||
import sys | ||
from dataclasses import dataclass | ||
from typing import overload | ||
|
||
import pytest | ||
|
||
from cognite.client.exceptions import CogniteTypeError | ||
from cognite.client.utils._runtime_type_checking import Settings, runtime_type_checked | ||
|
||
pytestmark = [pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10")] | ||
|
||
|
||
Settings.enable_runtime_type_checking = True | ||
|
||
|
||
class Foo: ... | ||
|
||
|
||
class TestTypes: | ||
@runtime_type_checked | ||
class Types: | ||
def primitive(self, x: int) -> None: ... | ||
|
||
def list(self, x: list[str]) -> None: ... | ||
|
||
def custom_class(self, x: Foo) -> None: ... | ||
|
||
def test_primitive(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
self.Types().primitive("1") | ||
|
||
self.Types().primitive(1) | ||
|
||
def test_list(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' " | ||
"violates type hint list[str], as str '1' not instance of list." | ||
), | ||
): | ||
self.Types().list("1") | ||
|
||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] " | ||
"violates type hint list[str], as list index 0 item int 1 not instance of str." | ||
), | ||
): | ||
self.Types().list([1]) | ||
|
||
self.Types().list(["ok"]) | ||
|
||
def test_custom_type(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() " | ||
"parameter x='1' violates type hint " | ||
"<class 'tests.tests_unit.test_utils.test_runtime_type_checking.Foo'>, as str '1' not instance " | ||
'of <class "tests.tests_unit.test_utils.test_runtime_type_checking.Foo">' | ||
), | ||
): | ||
self.Types().custom_class("1") | ||
|
||
self.Types().custom_class(Foo()) | ||
|
||
@runtime_type_checked | ||
class ClassWithConstructor: | ||
def __init__(self, x: int, y: str) -> None: | ||
self.x = x | ||
self.y = y | ||
|
||
def test_constructor_for_class(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
self.ClassWithConstructor("1", "1") | ||
|
||
def test_constructor_for_subclass(self) -> None: | ||
class SubDataClassWithConstructor(self.ClassWithConstructor): | ||
pass | ||
|
||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
SubDataClassWithConstructor("1", "1") | ||
|
||
@runtime_type_checked | ||
@dataclass | ||
class DataClassWithConstructor: | ||
x: int | ||
y: int | ||
|
||
def test_constructor_for_dataclass(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
self.DataClassWithConstructor("1", "1") | ||
|
||
def test_constructor_for_dataclass_subclass(self) -> None: | ||
class SubDataClassWithConstructor(self.DataClassWithConstructor): | ||
pass | ||
|
||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() " | ||
"parameter x='1' violates type hint <class 'int'>, as str '1' not instance of int." | ||
), | ||
): | ||
SubDataClassWithConstructor("1", "1") | ||
|
||
|
||
class TestOverloads: | ||
@runtime_type_checked | ||
class WithOverload: | ||
@overload | ||
def foo(self, x: int, y: int) -> str: ... | ||
|
||
@overload | ||
def foo(self, x: str, y: str) -> str: ... | ||
|
||
def foo(self, x: int | str, y: int | str) -> str: | ||
return f"{x}{y}" | ||
|
||
def test_overloads(self) -> None: | ||
with pytest.raises( | ||
CogniteTypeError, | ||
match=re.escape( | ||
"Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() " | ||
"parameter y=1.0 violates type hint int | str, as float 1.0 not str or int." | ||
), | ||
): | ||
self.WithOverload().foo(1, 1.0) | ||
|
||
# Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet | ||
self.WithOverload().foo(1, "1") |