diff --git a/cognite/client/config.py b/cognite/client/config.py index cc89bd2020..e2211febb7 100644 --- a/cognite/client/config.py +++ b/cognite/client/config.py @@ -58,6 +58,12 @@ def __init__(self) -> None: self.proxies: dict[str, str] | None = {} self.max_workers: int = 5 self.silence_feature_preview_warnings: bool = False + self.enable_runtime_type_checking: bool = False + if self.enable_runtime_type_checking: + FutureWarning( + "Experimental runtime type checking is enabled. This feature will only work for " + "Python 3.10 and above." + ) def apply_settings(self, settings: dict[str, Any] | str) -> None: """Apply settings to the global configuration object from a YAML/JSON string or dict. diff --git a/cognite/client/exceptions.py b/cognite/client/exceptions.py index d9dfff4d94..69d63196dc 100644 --- a/cognite/client/exceptions.py +++ b/cognite/client/exceptions.py @@ -16,6 +16,9 @@ class CogniteException(Exception): pass +class CogniteTypeError(CogniteException): ... + + @dataclass class GraphQLErrorSpec: message: str diff --git a/cognite/client/utils/_runtime_type_checking.py b/cognite/client/utils/_runtime_type_checking.py new file mode 100644 index 0000000000..1817bf5436 --- /dev/null +++ b/cognite/client/utils/_runtime_type_checking.py @@ -0,0 +1,33 @@ +import sys +from inspect import isfunction +from typing import Any, Callable, TypeVar + +from beartype import beartype +from beartype.roar import BeartypeCallHintParamViolation + +from cognite.client import global_config +from cognite.client.exceptions import CogniteTypeError + +T_Callable = TypeVar("T_Callable", bound=Callable) +T_Class = TypeVar("T_Class", bound=type) + + +def runtime_type_checked_method(f: T_Callable) -> T_Callable: + if (sys.version_info < (3, 10)) or not global_config.enable_runtime_type_checking: + return f + beartyped_f = beartype(f) + + def f_wrapped(*args: Any, **kwargs: Any) -> Any: + try: + return beartyped_f(*args, **kwargs) + except BeartypeCallHintParamViolation as e: + raise CogniteTypeError(e.args[0]) + + return f_wrapped + + +def runtime_type_checked(c: T_Class) -> T_Class: + for name in dir(c): + if not name.startswith("_") or name == "__init__" and isfunction(getattr(c, name)): + setattr(c, name, runtime_type_checked_method(getattr(c, name))) + return c diff --git a/tests/tests_unit/test_utils/test_runtime_type_checking.py b/tests/tests_unit/test_utils/test_runtime_type_checking.py new file mode 100644 index 0000000000..5242dc9a90 --- /dev/null +++ b/tests/tests_unit/test_utils/test_runtime_type_checking.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import re +import sys +from dataclasses import dataclass +from typing import overload + +import pytest + +from cognite.client import global_config +from cognite.client.exceptions import CogniteTypeError +from cognite.client.utils._runtime_type_checking import runtime_type_checked + +pytestmark = [pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10")] + + +global_config.enable_runtime_type_checking = True + + +class Foo: ... + + +class TestTypes: + @runtime_type_checked + class Types: + def primitive(self, x: int) -> None: ... + + def list(self, x: list[str]) -> None: ... + + def custom_class(self, x: Foo) -> None: ... + + def test_primitive(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + self.Types().primitive("1") + + self.Types().primitive(1) + + def test_list(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' " + "violates type hint list[str], as str '1' not instance of list." + ), + ): + self.Types().list("1") + + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] " + "violates type hint list[str], as list index 0 item int 1 not instance of str." + ), + ): + self.Types().list([1]) + + self.Types().list(["ok"]) + + def test_custom_type(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() " + "parameter x='1' violates type hint " + ", as str '1' not instance " + 'of ' + ), + ): + self.Types().custom_class("1") + + self.Types().custom_class(Foo()) + + @runtime_type_checked + class ClassWithConstructor: + def __init__(self, x: int, y: str) -> None: + self.x = x + self.y = y + + def test_constructor_for_class(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + self.ClassWithConstructor("1", "1") + + def test_constructor_for_subclass(self) -> None: + class SubDataClassWithConstructor(self.ClassWithConstructor): + pass + + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + SubDataClassWithConstructor("1", "1") + + @runtime_type_checked + @dataclass + class DataClassWithConstructor: + x: int + y: int + + def test_constructor_for_dataclass(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + self.DataClassWithConstructor("1", "1") + + def test_constructor_for_dataclass_subclass(self) -> None: + class SubDataClassWithConstructor(self.DataClassWithConstructor): + pass + + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + SubDataClassWithConstructor("1", "1") + + +class TestOverloads: + @runtime_type_checked + class WithOverload: + @overload + def foo(self, x: int, y: int) -> str: ... + + @overload + def foo(self, x: str, y: str) -> str: ... + + def foo(self, x: int | str, y: int | str) -> str: + return f"{x}{y}" + + def test_overloads(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() " + "parameter y=1.0 violates type hint int | str, as float 1.0 not str or int." + ), + ): + self.WithOverload().foo(1, 1.0) + + # Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet + self.WithOverload().foo(1, "1")