diff --git a/cognite/client/exceptions.py b/cognite/client/exceptions.py index d9dfff4d94..69d63196dc 100644 --- a/cognite/client/exceptions.py +++ b/cognite/client/exceptions.py @@ -16,6 +16,9 @@ class CogniteException(Exception): pass +class CogniteTypeError(CogniteException): ... + + @dataclass class GraphQLErrorSpec: message: str diff --git a/cognite/client/utils/_runtime_type_checking.py b/cognite/client/utils/_runtime_type_checking.py new file mode 100644 index 0000000000..6a37484484 --- /dev/null +++ b/cognite/client/utils/_runtime_type_checking.py @@ -0,0 +1,36 @@ +import sys +from inspect import isfunction +from typing import Any, Callable, TypeVar + +from beartype import beartype +from beartype.roar import BeartypeCallHintParamViolation + +from cognite.client.exceptions import CogniteTypeError + +T_Callable = TypeVar("T_Callable", bound=Callable) +T_Class = TypeVar("T_Class", bound=type) + + +class Settings: + enable_runtime_type_checking: bool = False + + +def runtime_type_checked_method(f: T_Callable) -> T_Callable: + if (sys.version_info < (3, 10)) or not Settings.enable_runtime_type_checking: + return f + beartyped_f = beartype(f) + + def f_wrapped(*args: Any, **kwargs: Any) -> Any: + try: + return beartyped_f(*args, **kwargs) + except BeartypeCallHintParamViolation as e: + raise CogniteTypeError(e.args[0]) + + return f_wrapped + + +def runtime_type_checked(c: T_Class) -> T_Class: + for name in dir(c): + if not name.startswith("_") or name == "__init__" and isfunction(getattr(c, name)): + setattr(c, name, runtime_type_checked_method(getattr(c, name))) + return c diff --git a/poetry.lock b/poetry.lock index 274a3867fe..ad0934c2e8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -130,6 +130,24 @@ files = [ [package.extras] tzdata = ["tzdata"] +[[package]] +name = "beartype" +version = "0.18.5" +description = "Unbearably fast runtime type checking in pure Python." +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "beartype-0.18.5-py3-none-any.whl", hash = "sha256:5301a14f2a9a5540fe47ec6d34d758e9cd8331d36c4760fc7a5499ab86310089"}, + {file = "beartype-0.18.5.tar.gz", hash = "sha256:264ddc2f1da9ec94ff639141fbe33d22e12a9f75aa863b83b7046ffff1381927"}, +] + +[package.extras] +all = ["typing-extensions (>=3.10.0.0)"] +dev = ["autoapi (>=0.9.0)", "coverage (>=5.5)", "equinox", "mypy (>=0.800)", "numpy", "pandera", "pydata-sphinx-theme (<=0.7.2)", "pytest (>=4.0.0)", "sphinx", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)", "tox (>=3.20.1)", "typing-extensions (>=3.10.0.0)"] +doc-rtd = ["autoapi (>=0.9.0)", "pydata-sphinx-theme (<=0.7.2)", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)"] +test-tox = ["equinox", "mypy (>=0.800)", "numpy", "pandera", "pytest (>=4.0.0)", "sphinx", "typing-extensions (>=3.10.0.0)"] +test-tox-coverage = ["coverage (>=5.5)"] + [[package]] name = "certifi" version = "2024.8.30" @@ -603,19 +621,19 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth [[package]] name = "filelock" -version = "3.15.4" +version = "3.16.0" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, - {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, + {file = "filelock-3.16.0-py3-none-any.whl", hash = "sha256:f6ed4c963184f4c84dd5557ce8fece759a3724b37b80c6c4f20a2f63a4dc6609"}, + {file = "filelock-3.16.0.tar.gz", hash = "sha256:81de9eb8453c769b63369f87f11131a7ab04e367f8d97ad39dc230daa07e3bec"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] -typing = ["typing-extensions (>=4.8)"] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.1.1)", "pytest (>=8.3.2)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.3)"] +typing = ["typing-extensions (>=4.12.2)"] [[package]] name = "fiona" @@ -801,13 +819,13 @@ test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "p [[package]] name = "importlib-resources" -version = "6.4.4" +version = "6.4.5" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.4.4-py3-none-any.whl", hash = "sha256:dda242603d1c9cd836c3368b1174ed74cb4049ecd209e7a1a0104620c18c5c11"}, - {file = "importlib_resources-6.4.4.tar.gz", hash = "sha256:20600c8b7361938dc0bb2d5ec0297802e575df486f5a544fa414da65e13721f7"}, + {file = "importlib_resources-6.4.5-py3-none-any.whl", hash = "sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717"}, + {file = "importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065"}, ] [package.dependencies] @@ -1161,22 +1179,22 @@ tests = ["pytest (>=4.6)"] [[package]] name = "msal" -version = "1.30.0" +version = "1.31.0" description = "The Microsoft Authentication Library (MSAL) for Python library enables your app to access the Microsoft Cloud by supporting authentication of users with Microsoft Azure Active Directory accounts (AAD) and Microsoft Accounts (MSA) using industry standard OAuth2 and OpenID Connect." optional = false python-versions = ">=3.7" files = [ - {file = "msal-1.30.0-py3-none-any.whl", hash = "sha256:423872177410cb61683566dc3932db7a76f661a5d2f6f52f02a047f101e1c1de"}, - {file = "msal-1.30.0.tar.gz", hash = "sha256:b4bf00850092e465157d814efa24a18f788284c9a479491024d62903085ea2fb"}, + {file = "msal-1.31.0-py3-none-any.whl", hash = "sha256:96bc37cff82ebe4b160d5fc0f1196f6ca8b50e274ecd0ec5bf69c438514086e7"}, + {file = "msal-1.31.0.tar.gz", hash = "sha256:2c4f189cf9cc8f00c80045f66d39b7c0f3ed45873fd3d1f2af9f22db2e12ff4b"}, ] [package.dependencies] -cryptography = ">=2.5,<45" +cryptography = ">=2.5,<46" PyJWT = {version = ">=1.0.0,<3", extras = ["crypto"]} requests = ">=2.0.0,<3" [package.extras] -broker = ["pymsalruntime (>=0.13.2,<0.17)"] +broker = ["pymsalruntime (>=0.14,<0.18)", "pymsalruntime (>=0.17,<0.18)"] [[package]] name = "mypy" @@ -1598,19 +1616,19 @@ testing = ["pytest", "pytest-cov", "wheel"] [[package]] name = "platformdirs" -version = "4.2.2" +version = "4.3.2" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, - {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, + {file = "platformdirs-4.3.2-py3-none-any.whl", hash = "sha256:eb1c8582560b34ed4ba105009a4badf7f6f85768b30126f351328507b2beb617"}, + {file = "platformdirs-4.3.2.tar.gz", hash = "sha256:9e5e27a08aa095dd127b9f2e764d74254f482fef22b0970773bfba79d091ab8c"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] -type = ["mypy (>=1.8)"] +docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] +type = ["mypy (>=1.11.2)"] [[package]] name = "pluggy" @@ -2652,13 +2670,13 @@ files = [ [[package]] name = "types-requests" -version = "2.32.0.20240905" +version = "2.32.0.20240907" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" files = [ - {file = "types-requests-2.32.0.20240905.tar.gz", hash = "sha256:e97fd015a5ed982c9ddcd14cc4afba9d111e0e06b797c8f776d14602735e9bd6"}, - {file = "types_requests-2.32.0.20240905-py3-none-any.whl", hash = "sha256:f46ecb55f5e1a37a58be684cf3f013f166da27552732ef2469a0cc8e62a72881"}, + {file = "types-requests-2.32.0.20240907.tar.gz", hash = "sha256:ff33935f061b5e81ec87997e91050f7b4af4f82027a7a7a9d9aaea04a963fdf8"}, + {file = "types_requests-2.32.0.20240907-py3-none-any.whl", hash = "sha256:1d1e79faeaf9d42def77f3c304893dea17a97cae98168ac69f3cb465516ee8da"}, ] [package.dependencies] @@ -2716,13 +2734,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "virtualenv" -version = "20.26.3" +version = "20.26.4" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.26.3-py3-none-any.whl", hash = "sha256:8cc4a31139e796e9a7de2cd5cf2489de1217193116a8fd42328f1bd65f434589"}, - {file = "virtualenv-20.26.3.tar.gz", hash = "sha256:4c43a2a236279d9ea36a0d76f98d84bd6ca94ac4e0f4a3b9d46d05e10fea542a"}, + {file = "virtualenv-20.26.4-py3-none-any.whl", hash = "sha256:48f2695d9809277003f30776d155615ffc11328e6a0a8c1f0ec80188d7874a55"}, + {file = "virtualenv-20.26.4.tar.gz", hash = "sha256:c17f4e0f3e6036e9f26700446f85c76ab11df65ff6d8a9cbfad9f71aabfcf23c"}, ] [package.dependencies] @@ -2777,4 +2795,4 @@ yaml = ["PyYAML"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "9e456775447549841733cf93df8ec207a099cbffc1e752fccfaa5acdb98d2da9" +content-hash = "e0cdd289fc36b45cd74d1dbc02539b198336b3cde8550df73f47c6a8e8c5005e" diff --git a/pyproject.toml b/pyproject.toml index ac15cc6ad4..89345baa56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ shapely = { version = ">=1.7.0", optional = true } pyodide-http = { version = "^0.2.1", optional = true } graphlib-backport = { version = "^1.0.0", python = "<3.9" } PyYAML = { version = "^6.0", optional = true } +beartype = "^0" [tool.poetry.extras] pandas = ["pandas"] diff --git a/tests/tests_unit/test_utils/test_runtime_type_checking.py b/tests/tests_unit/test_utils/test_runtime_type_checking.py new file mode 100644 index 0000000000..2b86ea42de --- /dev/null +++ b/tests/tests_unit/test_utils/test_runtime_type_checking.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import re +import sys +from dataclasses import dataclass +from typing import overload + +import pytest + +from cognite.client.exceptions import CogniteTypeError +from cognite.client.utils._runtime_type_checking import Settings, runtime_type_checked + +pytestmark = [pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10")] + + +Settings.enable_runtime_type_checking = True + + +class Foo: ... + + +class TestTypes: + @runtime_type_checked + class Types: + def primitive(self, x: int) -> None: ... + + def list(self, x: list[str]) -> None: ... + + def custom_class(self, x: Foo) -> None: ... + + def test_primitive(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.primitive() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + self.Types().primitive("1") + + self.Types().primitive(1) + + def test_list(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x='1' " + "violates type hint list[str], as str '1' not instance of list." + ), + ): + self.Types().list("1") + + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.list() parameter x=[1] " + "violates type hint list[str], as list index 0 item int 1 not instance of str." + ), + ): + self.Types().list([1]) + + self.Types().list(["ok"]) + + def test_custom_type(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.Types.custom_class() " + "parameter x='1' violates type hint " + ", as str '1' not instance " + 'of ' + ), + ): + self.Types().custom_class("1") + + self.Types().custom_class(Foo()) + + @runtime_type_checked + class ClassWithConstructor: + def __init__(self, x: int, y: str) -> None: + self.x = x + self.y = y + + def test_constructor_for_class(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + self.ClassWithConstructor("1", "1") + + def test_constructor_for_subclass(self) -> None: + class SubDataClassWithConstructor(self.ClassWithConstructor): + pass + + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.ClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + SubDataClassWithConstructor("1", "1") + + @runtime_type_checked + @dataclass + class DataClassWithConstructor: + x: int + y: int + + def test_constructor_for_dataclass(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + self.DataClassWithConstructor("1", "1") + + def test_constructor_for_dataclass_subclass(self) -> None: + class SubDataClassWithConstructor(self.DataClassWithConstructor): + pass + + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestTypes.DataClassWithConstructor.__init__() " + "parameter x='1' violates type hint , as str '1' not instance of int." + ), + ): + SubDataClassWithConstructor("1", "1") + + +class TestOverloads: + @runtime_type_checked + class WithOverload: + @overload + def foo(self, x: int, y: int) -> str: ... + + @overload + def foo(self, x: str, y: str) -> str: ... + + def foo(self, x: int | str, y: int | str) -> str: + return f"{x}{y}" + + def test_overloads(self) -> None: + with pytest.raises( + CogniteTypeError, + match=re.escape( + "Method tests.tests_unit.test_utils.test_runtime_type_checking.TestOverloads.WithOverload.foo() " + "parameter y=1.0 violates type hint int | str, as float 1.0 not str or int." + ), + ): + self.WithOverload().foo(1, 1.0) + + # Technically should raise a CogniteTypeError, but beartype isn't very good with overloads yet + self.WithOverload().foo(1, "1")