From 38c9c1882f6351dc884b574977a675d3ec48ae3e Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 10:30:17 +0000 Subject: [PATCH 01/38] Pull backend_agnostic class --- src/causalprog/_abc/backend_agnostic.py | 48 +++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 src/causalprog/_abc/backend_agnostic.py diff --git a/src/causalprog/_abc/backend_agnostic.py b/src/causalprog/_abc/backend_agnostic.py new file mode 100644 index 0000000..8b03fb5 --- /dev/null +++ b/src/causalprog/_abc/backend_agnostic.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar + +Backend = TypeVar("Backend") + + +class BackendAgnostic(ABC, Generic[Backend]): + """A frontend object that must be backend-agnostic.""" + + __slots__ = ("_backend_obj",) + _backend_obj: Backend + + def __getattr__(self, name: str) -> Any: # noqa: ANN401 + """Fallback on the backend object a frontend method isn't found.""" + if name in self._frontend_provides and hasattr(self._backend_obj, name): + return getattr(self._backend_obj, name) + msg = f"{self} has no attribute {name}." + raise AttributeError(msg) + + def __init__(self, *, backend: Backend) -> None: + self._backend_obj = backend + + @property + @abstractmethod + def _frontend_provides(self) -> tuple[str, ...]: + """Methods that an instance of this class must provide.""" + + @property + def _missing_methods(self) -> set[str]: + """Return the names of frontend methods that are missing.""" + return {attr for attr in self._frontend_provides if not hasattr(self, attr)} + + def get_backend(self) -> Backend: + """Access to the backend object.""" + return self._backend_obj + + def validate(self) -> None: + """ + Determine if all expected frontend methods are provided. + + Raises: + AttributeError: If frontend methods are not present. + + """ + if len(self._missing_methods) != 0: + raise AttributeError( + "Missing frontend methods: " + ", ".join(self._missing_methods) + ) From 61eb3312858169aa764f2f6b97976677354f1675 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 10:31:07 +0000 Subject: [PATCH 02/38] Pull backend_agnostic tests --- tests/test__abc/test_backend_agnostic.py | 70 ++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/test__abc/test_backend_agnostic.py diff --git a/tests/test__abc/test_backend_agnostic.py b/tests/test__abc/test_backend_agnostic.py new file mode 100644 index 0000000..6b8586c --- /dev/null +++ b/tests/test__abc/test_backend_agnostic.py @@ -0,0 +1,70 @@ +import re + +import pytest + +from causalprog._abc.backend_agnostic import BackendAgnostic + + +class OneMethodBackend: + def method1(self) -> None: + return + + +class TwoMethodBackend(OneMethodBackend): + def method2(self) -> None: + return + + +class ThreeMethodBackend(TwoMethodBackend): + def method3(self) -> None: + return + + +class BA(BackendAgnostic): + """ + Designed to test the abstract ``BackendAgnostic`` class. + + Instances take ``*methods`` as an argument, which has the effect of setting + ``self.method`` to be a function that returns ``True`` for each ``method`` in + ``*methods``. + """ + + @property + def _frontend_provides(self) -> tuple[str, ...]: + return ( + "method1", + "method2", + ) + + +@pytest.mark.parametrize( + ("backend", "expected_missing"), + [ + pytest.param( + TwoMethodBackend(), + set(), + id="All methods defined.", + ), + pytest.param( + ThreeMethodBackend(), + set(), + id="Additional methods defined.", + ), + pytest.param( + OneMethodBackend(), + {"method2"}, + id="Missing required method.", + ), + ], +) +def test_method_discovery(backend: object, expected_missing: set[str]) -> None: + obj = BA(backend=backend) + assert obj.get_backend() is backend + + assert obj._missing_methods == expected_missing # noqa: SLF001 + if len(expected_missing) != 0: + with pytest.raises( + AttributeError, + match=re.escape("Missing frontend methods: " + ", ".join(expected_missing)), + ): + obj.validate() From 59f6af3c3a738084ac8a5d0857e65f9a349289b5 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 10:34:18 +0000 Subject: [PATCH 03/38] Test positive and negative case --- tests/test__abc/test_backend_agnostic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test__abc/test_backend_agnostic.py b/tests/test__abc/test_backend_agnostic.py index 6b8586c..ff67f85 100644 --- a/tests/test__abc/test_backend_agnostic.py +++ b/tests/test__abc/test_backend_agnostic.py @@ -68,3 +68,5 @@ def test_method_discovery(backend: object, expected_missing: set[str]) -> None: match=re.escape("Missing frontend methods: " + ", ".join(expected_missing)), ): obj.validate() + else: + obj.validate() From 2e91b2a4e13ddab40d53daf9d3217692cb77cf93 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 10:46:57 +0000 Subject: [PATCH 04/38] Tidy docstrings and methods --- src/causalprog/_abc/backend_agnostic.py | 26 ++++++++++++++++-------- tests/test__abc/test_backend_agnostic.py | 2 +- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/causalprog/_abc/backend_agnostic.py b/src/causalprog/_abc/backend_agnostic.py index 8b03fb5..08aa40d 100644 --- a/src/causalprog/_abc/backend_agnostic.py +++ b/src/causalprog/_abc/backend_agnostic.py @@ -5,13 +5,23 @@ class BackendAgnostic(ABC, Generic[Backend]): - """A frontend object that must be backend-agnostic.""" + """ + A frontend object that must be backend-agnostic. + + ``BackendAgnostic`` is a means of ensuring that an object provides the functionality + and interface that our package expects, irrespective of how this functionality is + actually carried out. An instance of a ``BackendAgnostic`` class stores a reference + to its ``_backend_obj``, and falls back on this object's methods and attributes if + the instance itself does not possess the required attributes. Methods that the + ``BackendAgnostic`` object can also be explicitly defined in the class, and make + calls to the ``_backend_obj`` as necessary. + """ __slots__ = ("_backend_obj",) _backend_obj: Backend def __getattr__(self, name: str) -> Any: # noqa: ANN401 - """Fallback on the backend object a frontend method isn't found.""" + """Fallback on the ``_backend_obj`` a frontend attribute isn't found.""" if name in self._frontend_provides and hasattr(self._backend_obj, name): return getattr(self._backend_obj, name) msg = f"{self} has no attribute {name}." @@ -23,11 +33,11 @@ def __init__(self, *, backend: Backend) -> None: @property @abstractmethod def _frontend_provides(self) -> tuple[str, ...]: - """Methods that an instance of this class must provide.""" + """Names of attributes that an instance of this class must provide.""" @property - def _missing_methods(self) -> set[str]: - """Return the names of frontend methods that are missing.""" + def _missing_attrs(self) -> set[str]: + """Return the names of frontend attributes that are missing.""" return {attr for attr in self._frontend_provides if not hasattr(self, attr)} def get_backend(self) -> Backend: @@ -36,13 +46,13 @@ def get_backend(self) -> Backend: def validate(self) -> None: """ - Determine if all expected frontend methods are provided. + Determine if all expected frontend attributes are provided. Raises: AttributeError: If frontend methods are not present. """ - if len(self._missing_methods) != 0: + if len(self._missing_attrs) != 0: raise AttributeError( - "Missing frontend methods: " + ", ".join(self._missing_methods) + "Missing frontend methods: " + ", ".join(self._missing_attrs) ) diff --git a/tests/test__abc/test_backend_agnostic.py b/tests/test__abc/test_backend_agnostic.py index ff67f85..12d0eb3 100644 --- a/tests/test__abc/test_backend_agnostic.py +++ b/tests/test__abc/test_backend_agnostic.py @@ -61,7 +61,7 @@ def test_method_discovery(backend: object, expected_missing: set[str]) -> None: obj = BA(backend=backend) assert obj.get_backend() is backend - assert obj._missing_methods == expected_missing # noqa: SLF001 + assert obj._missing_attrs == expected_missing # noqa: SLF001 if len(expected_missing) != 0: with pytest.raises( AttributeError, From d2f40d6e5ba4aba0e6ed212aebdcd6ba2c4a9ca4 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:06:00 +0000 Subject: [PATCH 05/38] Pull signature conversions from branch --- src/causalprog/backend/__init__.py | 1 + src/causalprog/backend/_convert_signature.py | 270 +++++++++++++++++++ src/causalprog/backend/_typing.py | 5 + 3 files changed, 276 insertions(+) create mode 100644 src/causalprog/backend/__init__.py create mode 100644 src/causalprog/backend/_convert_signature.py create mode 100644 src/causalprog/backend/_typing.py diff --git a/src/causalprog/backend/__init__.py b/src/causalprog/backend/__init__.py new file mode 100644 index 0000000..e028405 --- /dev/null +++ b/src/causalprog/backend/__init__.py @@ -0,0 +1 @@ +"""Helper functionality for incorporating different backends.""" diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py new file mode 100644 index 0000000..7f87063 --- /dev/null +++ b/src/causalprog/backend/_convert_signature.py @@ -0,0 +1,270 @@ +"""Convert a function signature to a different signature.""" + +import inspect +from collections.abc import Callable +from inspect import Parameter, Signature +from typing import Any + +from ._typing import ParamNameMap, ReturnType, StaticValues + +_VARLENGTH_PARAM_TYPES = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) + + +def _validate_variable_length_parameters( + sig: Signature, +) -> dict[inspect._ParameterKind, str | None]: + """ + Check signature contains at most one variable-length parameter of each kind. + + ``Signature`` objects can contain more than one variable-length parameter, despite + the fact that in practice such a signature cannot exist and be valid Python syntax. + This function checks for such cases, and raises an appropriate error, should they + arise. + + Args: + sig (Signature): Function signature to check for variable-length parameters. + + Returns: + dict[inspect._ParameterKind, str | None]: Mapping of variable-length parameter + kinds to the corresponding parameter name in ``sig``, or to ``None`` if no + parameter of that type exists in the signature. + + """ + named_args: dict[inspect._ParameterKind, str | None] = { + kind: None for kind in _VARLENGTH_PARAM_TYPES + } + for kind in _VARLENGTH_PARAM_TYPES: + possible_parameters = [ + p_name for p_name, p in sig.parameters.items() if p.kind == kind + ] + if len(possible_parameters) > 1: + msg = f"New signature takes more than 1 {kind} argument." + raise ValueError(msg) + if len(possible_parameters) > 0: + named_args[kind] = possible_parameters[0] + return named_args + + +def _signature_can_be_cast( + signature_to_convert: Signature, + new_signature: Signature, + param_name_map: ParamNameMap, + give_static_value: StaticValues, +) -> tuple[ParamNameMap, StaticValues]: + """ + Prepare a signature for conversion to another signature. + + In order to map ``signature_to_convert`` to that of ``new_signature``, the following + assurances are needed: + + - Variable-length parameters in the two signatures are assumed to match (up to name + changes) or be provided explicit defaults. The function will attempt to match + variable-length parameters that are not explicitly matched in the + ``param_name_map``. Note that a signature can have, at most, only one + variable-length positional parameter and one variable-length keyword parameter. + - All parameters WITHOUT DEFAULT VALUES in ``signature_to_convert`` correspond to a + parameter in ``new_signature`` (that may or may not have a default value) OR are + given static values to use, via the ``give_static_value`` argument. + - If ``new_signature`` takes variable-keyword-argument (``**kwargs``), these + arguments are expanded to allow for possible matches to parameters of + ``signature_to_convert``, before passing any remaining parameters after this + unpacked to the variable-keyword-argument of ``signature_to_convert``. + + Args: + signature_to_convert (Signature): Function signature that will be cast to + ``new_signature``. + new_signature (Signature): See the homonymous argument to ``convert_signature``. + param_name_map (ParamNameMap): See the homonymous argument to + ``convert_signature``. + give_static_value (StaticValues): See the homonymous argument to + ``convert_signature``. + + Raises: + ValueError: If the two signatures cannot be cast, even given + the additional information. + + Returns: + ParamNameMap: Mapping of parameter names in the ``signature_to_convert`` to + parameter names in ``new_signature``. Implicit mappings as per function + behaviour are explicitly included in the returned mapping. + StaticValues: Mapping of parameter names in the ``signature_to_convert`` to + static values to assign to these parameters, indicating omission from the + ``new_signature``. Implicit adoption of static values as per function + behaviour are explicitly included in the returned mapping. + + """ + _validate_variable_length_parameters(signature_to_convert) + new_varlength_params = _validate_variable_length_parameters(new_signature) + + param_name_map = dict(param_name_map) + give_static_value = dict(give_static_value) + + new_parameters_accounted_for = set() + + # Check mapping of parameters in old signature to new signature + for p_name, param in signature_to_convert.parameters.items(): + is_explicitly_mapped = p_name in param_name_map + name_is_unchanged = ( + p_name not in param_name_map + and p_name not in param_name_map.values() + and p_name in new_signature.parameters + ) + is_given_static = p_name in give_static_value + can_take_default = param.default is not param.empty + is_varlength_param = param.kind in _VARLENGTH_PARAM_TYPES + mapped_to = None + + if is_explicitly_mapped: + # This parameter is explicitly mapped to another parameter + mapped_to = param_name_map[p_name] + elif name_is_unchanged: + # Parameter is inferred not to change name, having been omitted from the + # explicit mapping. + mapped_to = p_name + param_name_map[p_name] = mapped_to + elif ( + is_varlength_param + and new_varlength_params[param.kind] is not None + and str(new_varlength_params[param.kind]) not in param_name_map.values() + ): + # Automatically map VAR_* parameters to their counterpart, if possible. + mapped_to = str(new_varlength_params[param.kind]) + param_name_map[p_name] = mapped_to + elif is_given_static: + # This parameter is given a static value to use. + continue + elif can_take_default: + # This parameter has a default value in the old signature. + # Since it is not explicitly mapped to another parameter, nor given an + # explicit static value, infer that the default value should be set as the + # static value. + give_static_value[p_name] = param.default + else: + msg = ( + f"Parameter '{p_name}' has no counterpart in new_signature, " + "and does not take a static value." + ) + raise ValueError(msg) + + # Record that any parameter mapped_to in the new_signature is now accounted for, + # to avoid many -> one mappings. + if mapped_to: + if mapped_to in new_parameters_accounted_for: + msg = f"Parameter '{mapped_to}' is mapped to by multiple parameters." + raise ValueError(msg) + # Confirm that variable-length parameters are mapped to variable-length + # parameters (of the same type). + if ( + is_varlength_param + and new_signature.parameters[mapped_to].kind != param.kind + ): + msg = ( + "Variable-length positional/keyword parameters must map to each " + f"other ('{p_name}' is type {param.kind}, but '{mapped_to}' is " + f"type {new_signature.parameters[mapped_to].kind})." + ) + raise ValueError(msg) + + new_parameters_accounted_for.add(param_name_map[p_name]) + + # Confirm all items in new_signature are also accounted for. + unaccounted_new_parameters = ( + set(new_signature.parameters) - new_parameters_accounted_for + ) + if unaccounted_new_parameters: + msg = "Some parameters in new_signature are not used: " + ", ".join( + unaccounted_new_parameters + ) + raise ValueError(msg) + + return param_name_map, give_static_value + + +def convert_signature( + fn: Callable[..., ReturnType], + new_signature: Signature, + old_to_new_names: ParamNameMap, + give_static_value: StaticValues, +) -> Callable[..., ReturnType]: + """ + Convert the call signature of a function ``fn`` to that of ``new_signature``. + + Args: + fn (Callable): Callable object to change the signature of. + new_signature (inspect.Signature): New signature to give to ``fn``. + old_to_new_names (dict[str, str]): Maps the names of parameters in ``fn``s + signature to the corresponding parameter names in the new signature. + Parameter names that do not change can be omitted. Note that parameters that + are to be dropped should be supplied to ``give_static_value`` instead. + give_static_value (dict[str, Any]): Maps names of parameters of ``fn`` to + default values that should be assigned to them. This means that not all + compulsory parameters of ``fn`` have to have a corresponding parameter in + ``new_signature`` - such parameters will use the value assigned to them in + ``give_static_value`` if they are lacking a counterpart parameter in + ``new_signature``. Parameters to ``fn`` that lack a counterpart in + ``new_signature``, and that have default values in ``fn``, will be added + automatically. + + Returns: + Callable: Callable representing ``fn`` with ``new_signature``. + + See Also: + _signature_can_be_cast: Validation method used to check casting is possible. + + """ + fn_signature = inspect.signature(fn) + old_to_new_names, give_static_value = _signature_can_be_cast( + fn_signature, new_signature, old_to_new_names, give_static_value + ) + new_to_old_names = {value: key for key, value in old_to_new_names.items()} + + fn_varlength_params = _validate_variable_length_parameters(fn_signature) + fn_vargs_param = fn_varlength_params[Parameter.VAR_POSITIONAL] + fn_kwargs_param = fn_varlength_params[Parameter.VAR_KEYWORD] + + new_varlength_params = _validate_variable_length_parameters(new_signature) + new_kwargs_param = new_varlength_params[Parameter.VAR_KEYWORD] + + fn_posix_args = [ + p_name + for p_name, param in fn_signature.parameters.items() + if param.kind <= param.POSITIONAL_OR_KEYWORD + ] + + # If fn's VAR_KEYWORD parameter is dropped from the new_signature, + # it must have been given a default value to use. We need to expand + # these values now so that they get passed correctly as keyword arguments. + if fn_kwargs_param and fn_kwargs_param in give_static_value: + static_kwargs = give_static_value.pop(fn_kwargs_param) + give_static_value = dict(give_static_value, **static_kwargs) + + def fn_with_new_signature(*args: tuple, **kwargs: dict[str, Any]) -> ReturnType: + bound = new_signature.bind(*args, **kwargs) + bound.apply_defaults() + + all_args_received = bound.arguments + kwargs_to_pass_on = ( + all_args_received.pop(new_kwargs_param, {}) if new_kwargs_param else {} + ) + # Maps the name of a parameter to fn to the value that should be supplied, + # as obtained from the arguments provided to this function. + # Calling dict with give_static_value FIRST is important, as defaults will get + # overwritten by any passed arguments! + fn_kwargs = dict( + give_static_value, + **{ + new_to_old_names[key]: value for key, value in all_args_received.items() + }, + **kwargs_to_pass_on, + ) + # We can supply all arguments EXCEPT the variable-positional and positional-only + # arguments as keyword args. + # Positional-only arguments have to come first, followed by the + # variable-positional parameters. + fn_args = [fn_kwargs.pop(p_name) for p_name in fn_posix_args] + if fn_vargs_param: + fn_args.extend(fn_kwargs.pop(fn_vargs_param, [])) + # Now we can call fn + return fn(*fn_args, **fn_kwargs) + + return fn_with_new_signature diff --git a/src/causalprog/backend/_typing.py b/src/causalprog/backend/_typing.py new file mode 100644 index 0000000..e7ab7fb --- /dev/null +++ b/src/causalprog/backend/_typing.py @@ -0,0 +1,5 @@ +from typing import Any, TypeAlias, TypeVar + +ReturnType = TypeVar("ReturnType") +ParamNameMap: TypeAlias = dict[str, str] +StaticValues: TypeAlias = dict[str, Any] From 9b3728a81e81a3b485b4b5eabd363cbf8f936ad4 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:06:45 +0000 Subject: [PATCH 06/38] Pull tests for converting signatures --- tests/test_backend/test_convert_signature.py | 375 +++++++++++++++++++ 1 file changed, 375 insertions(+) create mode 100644 tests/test_backend/test_convert_signature.py diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py new file mode 100644 index 0000000..c418545 --- /dev/null +++ b/tests/test_backend/test_convert_signature.py @@ -0,0 +1,375 @@ +import re +from collections.abc import Iterable +from inspect import Parameter, Signature, signature +from typing import Any + +import pytest + +from causalprog.backend._convert_signature import ( + _signature_can_be_cast, + _validate_variable_length_parameters, + convert_signature, +) +from causalprog.backend._typing import ParamNameMap, StaticValues + + +def general_function( + posix, /, posix_def="posix_def", *vargs, kwo, kwo_def="kwo_def", **kwargs +): + return posix, posix_def, vargs, kwo, kwo_def, kwargs + + +@pytest.mark.parametrize( + ("signature", "expected"), + [ + pytest.param( + Signature( + ( + Parameter("vargs1", Parameter.VAR_POSITIONAL), + Parameter("vargs2", Parameter.VAR_POSITIONAL), + ) + ), + ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), + id="Two variable-length positional arguments.", + ), + pytest.param( + Signature( + ( + Parameter("kwargs1", Parameter.VAR_KEYWORD), + Parameter("kwargs2", Parameter.VAR_KEYWORD), + ) + ), + ValueError("New signature takes more than 1 VAR_KEYWORD argument."), + id="Two variable-length keyword arguments.", + ), + pytest.param( + signature(general_function), + {Parameter.VAR_POSITIONAL: "vargs", Parameter.VAR_KEYWORD: "kwargs"}, + id="Valid, but complex, signature.", + ), + pytest.param( + Signature( + ( + Parameter("arg1", Parameter.POSITIONAL_OR_KEYWORD), + Parameter("arg2", Parameter.POSITIONAL_OR_KEYWORD, default=1), + Parameter("vargs1", Parameter.VAR_POSITIONAL), + Parameter("vargs2", Parameter.VAR_POSITIONAL), + Parameter("kwargs1", Parameter.VAR_KEYWORD), + ) + ), + ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), + id="Two variable-length positional arguments, mixed with others.", + ), + ], +) +def test_validate_variable_length_parameters( + signature: Signature, expected: Exception | dict +): + if isinstance(expected, Exception): + with pytest.raises(type(expected), match=re.escape(str(expected))): + _validate_variable_length_parameters(signature) + else: + returned_names = _validate_variable_length_parameters(signature) + + assert returned_names == expected + + +@pytest.mark.parametrize( + ( + "signature_to_convert", + "new_signature", + "param_name_map", + "give_static_value", + "expected_output", + ), + [ + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {}, + {}, + ValueError( + "Parameter 'b' has no counterpart in new_signature, " + "and does not take a static value." + ), + id="Parameter not matched.", + ), + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {"a": "a", "b": "a"}, + {}, + ValueError("Parameter 'a' is mapped to by multiple parameters."), + id="Two arguments mapped to a single parameter.", + ), + pytest.param( + Signature( + [ + Parameter("vargs", Parameter.VAR_POSITIONAL), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {"vargs": "a"}, + {}, + ValueError( + "Variable-length positional/keyword parameters must map to each other " + "('vargs' is type VAR_POSITIONAL, but 'a' is type POSITIONAL_ONLY)." + ), + id="Map *args to positional argument.", + ), + pytest.param( + Signature( + [ + Parameter("vargs", Parameter.VAR_POSITIONAL), + ] + ), + Signature( + [ + Parameter("kwarg", Parameter.VAR_KEYWORD), + ] + ), + {"vargs": "kwarg"}, + {}, + ValueError( + "Variable-length positional/keyword parameters must map to each other " + "('vargs' is type VAR_POSITIONAL, but 'kwarg' is type VAR_KEYWORD)." + ), + id="Map *args to **kwargs.", + ), + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + {}, + {}, + ValueError("Some parameters in new_signature are not used: b"), + id="new_signature contains extra parameters.", + ), + pytest.param( + signature(general_function), + signature(general_function), + {}, + {}, + ({key: key for key in signature(general_function).parameters}, {}), + id="Can cast to yourself.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), + Signature([Parameter("a", Parameter.KEYWORD_ONLY)]), + {}, + {}, + ({"a": "a"}, {}), + id="Infer identically named parameter (even with type change)", + ), + pytest.param( + Signature([Parameter("args", Parameter.VAR_POSITIONAL)]), + Signature([Parameter("new_args", Parameter.VAR_POSITIONAL)]), + {}, + {}, + ({"args": "new_args"}, {}), + id="Infer VAR_POSITIONAL matching.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), + Signature([]), + {}, + {"a": 10}, + ({}, {"a": 10}), + id="Assign static value to argument without default.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY, default=10)]), + Signature([]), + {}, + {}, + ({}, {"a": 10}), + id="Infer static value from argument default.", + ), + ], +) +def test_signature_can_be_cast( + signature_to_convert: Signature, + new_signature: Signature, + param_name_map: ParamNameMap, + give_static_value: StaticValues, + expected_output: Exception | tuple[str | None, ParamNameMap, StaticValues], +) -> None: + if isinstance(expected_output, Exception): + with pytest.raises( + type(expected_output), match=re.escape(str(expected_output)) + ): + _signature_can_be_cast( + signature_to_convert, + new_signature, + param_name_map, + give_static_value, + ) + else: + computed_output = _signature_can_be_cast( + signature_to_convert, + new_signature, + param_name_map, + give_static_value, + ) + + assert computed_output == expected_output + + +_kwargs_static_value = {"some": "keyword-arguments"} + + +@pytest.mark.parametrize( + ( + "posix_for_new_call", + "keyword_for_new_call", + "expected_assignments", + ), + [ + pytest.param( + [1, 2], + {"kwo_n": 3, "kwo_def_n": 4}, + { + "posix": 3, + "posix_def": 4, + "vargs": (), + "kwo": 1, + "kwo_def": 2, + "kwargs": _kwargs_static_value, + }, + id="No vargs supplied.", + ), + pytest.param( + [1, 2, 10, 11, 12], + {"kwo_n": 3, "kwo_def_n": 4}, + { + "posix": 3, + "posix_def": 4, + "vargs": (10, 11, 12), + "kwo": 1, + "kwo_def": 2, + "kwargs": _kwargs_static_value, + }, + id="Supply vargs.", + ), + pytest.param( + [1], + {"kwo_n": 3}, + { + "posix": 3, + "posix_def": "default_for_kwo_def_n", + "vargs": (), + "kwo": 1, + "kwo_def": "default_for_posix_def_n", + "kwargs": _kwargs_static_value, + }, + id="New default values respected.", + ), + pytest.param( + [1], + {"kwo_n": 3, "extra_kwarg": "not allowed"}, + TypeError("got an unexpected keyword argument 'extra_kwarg'"), + id="kwargs not allowed in new signature.", + ), + pytest.param( + [1, 2], + {"kwo_n": 3, "posix_def_n": 2}, + TypeError("multiple values for argument 'posix_def_n'"), + id="Multiple values for new parameter.", + ), + ], +) +def test_convert_signature( + posix_for_new_call: Iterable[Any], + keyword_for_new_call: dict[str, Any], + expected_assignments: dict[str, Any] | Exception, +) -> None: + """ + To ease the burden of setting up and parametrising this test, + we will always use the general_function signature as the target and source + signature. + + However, the target signature will swap the roles of the positional and keyword + parameters, essentially mapping: + + ``posix, posix_def, *vargs, kwo, kwo_def, **kwargs`` + + to + + ``kwo_n, kwo_def_n, *vargs_n, posix_n, posix_def_n``. + + ``give_static_value`` will give kwargs a default value. + + We can then make calls to this new signature, and since ``general_function`` returns + the arguments it received, we can validate that correct passing of arguments occurs. + """ + param_name_map = { + "posix": "kwo_n", + "posix_def": "kwo_def_n", + "kwo": "posix_n", + "kwo_def": "posix_def_n", + } + give_static_value = {"kwargs": _kwargs_static_value} + new_signature = Signature( + [ + Parameter("posix_n", Parameter.POSITIONAL_ONLY), + Parameter( + "posix_def_n", + Parameter.POSITIONAL_OR_KEYWORD, + default="default_for_posix_def_n", + ), + Parameter("vargs_n", Parameter.VAR_POSITIONAL), + Parameter("kwo_n", Parameter.KEYWORD_ONLY), + Parameter( + "kwo_def_n", Parameter.KEYWORD_ONLY, default="default_for_kwo_def_n" + ), + ] + ) + new_function = convert_signature( + general_function, new_signature, param_name_map, give_static_value + ) + + if isinstance(expected_assignments, Exception): + with pytest.raises( + type(expected_assignments), match=re.escape(str(expected_assignments)) + ): + new_function(*posix_for_new_call, **keyword_for_new_call) + else: + posix, posix_def, vargs, kwo, kwo_def, kwargs = new_function( + *posix_for_new_call, **keyword_for_new_call + ) + assert posix == expected_assignments["posix"] + assert posix_def == expected_assignments["posix_def"] + assert vargs == expected_assignments["vargs"] + assert kwo == expected_assignments["kwo"] + assert kwo_def == expected_assignments["kwo_def"] + assert kwargs == expected_assignments["kwargs"] From 1958d98cb65dd01a0b43611326174d03c98d2b6a Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:09:03 +0000 Subject: [PATCH 07/38] Hide hidden variable import behind typehint --- src/causalprog/backend/_convert_signature.py | 8 ++++---- src/causalprog/backend/_typing.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 7f87063..d0a2616 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -5,14 +5,14 @@ from inspect import Parameter, Signature from typing import Any -from ._typing import ParamNameMap, ReturnType, StaticValues +from ._typing import ParamKind, ParamNameMap, ReturnType, StaticValues _VARLENGTH_PARAM_TYPES = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) def _validate_variable_length_parameters( sig: Signature, -) -> dict[inspect._ParameterKind, str | None]: +) -> dict[ParamKind, str | None]: """ Check signature contains at most one variable-length parameter of each kind. @@ -25,12 +25,12 @@ def _validate_variable_length_parameters( sig (Signature): Function signature to check for variable-length parameters. Returns: - dict[inspect._ParameterKind, str | None]: Mapping of variable-length parameter + dict[ParamKind, str | None]: Mapping of variable-length parameter kinds to the corresponding parameter name in ``sig``, or to ``None`` if no parameter of that type exists in the signature. """ - named_args: dict[inspect._ParameterKind, str | None] = { + named_args: dict[ParamKind, str | None] = { kind: None for kind in _VARLENGTH_PARAM_TYPES } for kind in _VARLENGTH_PARAM_TYPES: diff --git a/src/causalprog/backend/_typing.py b/src/causalprog/backend/_typing.py index e7ab7fb..c214967 100644 --- a/src/causalprog/backend/_typing.py +++ b/src/causalprog/backend/_typing.py @@ -1,5 +1,7 @@ +from inspect import _ParameterKind from typing import Any, TypeAlias, TypeVar ReturnType = TypeVar("ReturnType") ParamNameMap: TypeAlias = dict[str, str] +ParamKind: TypeAlias = _ParameterKind StaticValues: TypeAlias = dict[str, Any] From 89d08ab7cd194b3c33e9dfa42b54d56ff8278c0b Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:12:16 +0000 Subject: [PATCH 08/38] Tidy vargs and kwargs checker function --- src/causalprog/backend/_convert_signature.py | 25 ++++++++++++-------- tests/test_backend/test_convert_signature.py | 6 ++--- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index d0a2616..4c76e9e 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -10,16 +10,21 @@ _VARLENGTH_PARAM_TYPES = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) -def _validate_variable_length_parameters( +def _check_variable_length_params( sig: Signature, ) -> dict[ParamKind, str | None]: """ - Check signature contains at most one variable-length parameter of each kind. + Return the names of variable-length parameters in a signature. - ``Signature`` objects can contain more than one variable-length parameter, despite - the fact that in practice such a signature cannot exist and be valid Python syntax. - This function checks for such cases, and raises an appropriate error, should they - arise. + There are two types of variable-length parameters; positional (VAR_POSITIONAL) which + are typically denoted ``*args`` or ``*vargs``, and keyword (VAR_KEYWORD) which are + typically denoted ``**kwargs``. + + ``Signature`` objects can contain more than one variable-length parameter of each + kind, despite the fact that in practice such a signature cannot exist and be valid + Python syntax. This function checks for such cases, and raises an appropriate error, + should they arise. Otherwise, it simply identifies the parameters in ``sig`` which + correspond to these two variable-length parameter kinds. Args: sig (Signature): Function signature to check for variable-length parameters. @@ -93,8 +98,8 @@ def _signature_can_be_cast( behaviour are explicitly included in the returned mapping. """ - _validate_variable_length_parameters(signature_to_convert) - new_varlength_params = _validate_variable_length_parameters(new_signature) + _check_variable_length_params(signature_to_convert) + new_varlength_params = _check_variable_length_params(new_signature) param_name_map = dict(param_name_map) give_static_value = dict(give_static_value) @@ -218,11 +223,11 @@ def convert_signature( ) new_to_old_names = {value: key for key, value in old_to_new_names.items()} - fn_varlength_params = _validate_variable_length_parameters(fn_signature) + fn_varlength_params = _check_variable_length_params(fn_signature) fn_vargs_param = fn_varlength_params[Parameter.VAR_POSITIONAL] fn_kwargs_param = fn_varlength_params[Parameter.VAR_KEYWORD] - new_varlength_params = _validate_variable_length_parameters(new_signature) + new_varlength_params = _check_variable_length_params(new_signature) new_kwargs_param = new_varlength_params[Parameter.VAR_KEYWORD] fn_posix_args = [ diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index c418545..f4463e5 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -6,8 +6,8 @@ import pytest from causalprog.backend._convert_signature import ( + _check_variable_length_params, _signature_can_be_cast, - _validate_variable_length_parameters, convert_signature, ) from causalprog.backend._typing import ParamNameMap, StaticValues @@ -67,9 +67,9 @@ def test_validate_variable_length_parameters( ): if isinstance(expected, Exception): with pytest.raises(type(expected), match=re.escape(str(expected))): - _validate_variable_length_parameters(signature) + _check_variable_length_params(signature) else: - returned_names = _validate_variable_length_parameters(signature) + returned_names = _check_variable_length_params(signature) assert returned_names == expected From aa7d06fddab996bbddc30c7862633620643e1030 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:18:52 +0000 Subject: [PATCH 09/38] Refactor _check tests --- tests/test_backend/conftest.py | 21 ++++++ .../test_check_variable_length_parameters.py | 64 +++++++++++++++++++ tests/test_backend/test_convert_signature.py | 56 ---------------- 3 files changed, 85 insertions(+), 56 deletions(-) create mode 100644 tests/test_backend/conftest.py create mode 100644 tests/test_backend/test_check_variable_length_parameters.py diff --git a/tests/test_backend/conftest.py b/tests/test_backend/conftest.py new file mode 100644 index 0000000..00734e4 --- /dev/null +++ b/tests/test_backend/conftest.py @@ -0,0 +1,21 @@ +from collections.abc import Callable +from inspect import Signature, signature + +import pytest + + +@pytest.fixture +def general_function() -> Callable: + def _general_function( + posix, /, posix_def="posix_def", *vargs, kwo, kwo_def="kwo_def", **kwargs + ): + """Return the provided arguments.""" + return posix, posix_def, vargs, kwo, kwo_def, kwargs + + return _general_function + + +@pytest.fixture +def general_function_signature(general_function: Callable) -> Signature: + """Signature of the ``general_function`` callable.""" + return signature(general_function) diff --git a/tests/test_backend/test_check_variable_length_parameters.py b/tests/test_backend/test_check_variable_length_parameters.py new file mode 100644 index 0000000..cf7093f --- /dev/null +++ b/tests/test_backend/test_check_variable_length_parameters.py @@ -0,0 +1,64 @@ +import re +from inspect import Parameter, Signature + +import pytest + +from causalprog.backend._convert_signature import _check_variable_length_params + + +@pytest.mark.parametrize( + ("signature", "expected"), + [ + pytest.param( + Signature( + ( + Parameter("vargs1", Parameter.VAR_POSITIONAL), + Parameter("vargs2", Parameter.VAR_POSITIONAL), + ) + ), + ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), + id="Two variable-length positional arguments.", + ), + pytest.param( + Signature( + ( + Parameter("kwargs1", Parameter.VAR_KEYWORD), + Parameter("kwargs2", Parameter.VAR_KEYWORD), + ) + ), + ValueError("New signature takes more than 1 VAR_KEYWORD argument."), + id="Two variable-length keyword arguments.", + ), + pytest.param( + "general_function_signature", + {Parameter.VAR_POSITIONAL: "vargs", Parameter.VAR_KEYWORD: "kwargs"}, + id="Valid, but complex, signature.", + ), + pytest.param( + Signature( + ( + Parameter("arg1", Parameter.POSITIONAL_OR_KEYWORD), + Parameter("arg2", Parameter.POSITIONAL_OR_KEYWORD, default=1), + Parameter("vargs1", Parameter.VAR_POSITIONAL), + Parameter("vargs2", Parameter.VAR_POSITIONAL), + Parameter("kwargs1", Parameter.VAR_KEYWORD), + ) + ), + ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), + id="Two variable-length positional arguments, mixed with others.", + ), + ], +) +def test_check_variable_length_parameters( + signature: Signature, expected: Exception | dict, request +): + if isinstance(signature, str): + signature = request.getfixturevalue(signature) + + if isinstance(expected, Exception): + with pytest.raises(type(expected), match=re.escape(str(expected))): + _check_variable_length_params(signature) + else: + returned_names = _check_variable_length_params(signature) + + assert returned_names == expected diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index f4463e5..826a259 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -6,7 +6,6 @@ import pytest from causalprog.backend._convert_signature import ( - _check_variable_length_params, _signature_can_be_cast, convert_signature, ) @@ -19,61 +18,6 @@ def general_function( return posix, posix_def, vargs, kwo, kwo_def, kwargs -@pytest.mark.parametrize( - ("signature", "expected"), - [ - pytest.param( - Signature( - ( - Parameter("vargs1", Parameter.VAR_POSITIONAL), - Parameter("vargs2", Parameter.VAR_POSITIONAL), - ) - ), - ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), - id="Two variable-length positional arguments.", - ), - pytest.param( - Signature( - ( - Parameter("kwargs1", Parameter.VAR_KEYWORD), - Parameter("kwargs2", Parameter.VAR_KEYWORD), - ) - ), - ValueError("New signature takes more than 1 VAR_KEYWORD argument."), - id="Two variable-length keyword arguments.", - ), - pytest.param( - signature(general_function), - {Parameter.VAR_POSITIONAL: "vargs", Parameter.VAR_KEYWORD: "kwargs"}, - id="Valid, but complex, signature.", - ), - pytest.param( - Signature( - ( - Parameter("arg1", Parameter.POSITIONAL_OR_KEYWORD), - Parameter("arg2", Parameter.POSITIONAL_OR_KEYWORD, default=1), - Parameter("vargs1", Parameter.VAR_POSITIONAL), - Parameter("vargs2", Parameter.VAR_POSITIONAL), - Parameter("kwargs1", Parameter.VAR_KEYWORD), - ) - ), - ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), - id="Two variable-length positional arguments, mixed with others.", - ), - ], -) -def test_validate_variable_length_parameters( - signature: Signature, expected: Exception | dict -): - if isinstance(expected, Exception): - with pytest.raises(type(expected), match=re.escape(str(expected))): - _check_variable_length_params(signature) - else: - returned_names = _check_variable_length_params(signature) - - assert returned_names == expected - - @pytest.mark.parametrize( ( "signature_to_convert", From 1fd75f09ab1ea0635d41038142b1c8ed13d51cc7 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:41:28 +0000 Subject: [PATCH 10/38] Tidy convert_signature docstring --- src/causalprog/backend/_convert_signature.py | 59 ++++++++++++++++---- 1 file changed, 48 insertions(+), 11 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 4c76e9e..15c6586 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -194,27 +194,64 @@ def convert_signature( """ Convert the call signature of a function ``fn`` to that of ``new_signature``. + This function effectively allows ``fn`` to be called with ``new_signature``. It + returns a new ``Callable`` that uses the ``new_signature``, and returns the result + of ``fn`` after translating the ``new_signature`` back into that of ``fn`` and + making an appropriate call. + + Converting signatures into each other is, in general, not possible. However under + certain assumptions and conventions, it can be done. To that end, the following + assumptions are made about ``fn`` and ``new_signature``: + + 1. All parameters to ``fn`` are either; + 1. mapped to one non-variable-length parameter of ``new_signature``, or + 2. provided with a static value to be used in all calls. + 2. If ``fn`` takes a ``VAR_POSITIONAL`` parameter ``*args``, then either + 1. ``new_signature`` must also take a ``VAR_POSITIONAL`` parameter, and this + must map to identically to ``*args``, + 2. ``*args`` is provided with a static value to be used in all calls, and + ``new_signature`` must not take ``VAR_POSITIONAL`` arguments. + 3. If ``fn`` takes a ``VAR_KEYWORD`` parameter ``**kwargs``, then either + 1. ``new_signature`` must also take a ``VAR_KEYWORD`` parameter, and this + must map to identically to ``**kwargs``, + 2. ``**kwargs`` is provided with a static value to be used in all calls, and + ``new_signature`` must not take ``VAR_KEYWORD`` arguments. + + Mapping of parameters is done by name, from the signature of ``fn`` to + ``new_signature``, in the ``old_to_new_names`` argument. + + 4. If a parameter does not change name between the two signatures, it can be omitted + from this mapping and it will be inferred. Note that such a parameter may still + change kind, or adopt a new default value, in the ``new_signature``. + + Parameters can also be "dropped" from ``fn``'s signature in ``new_signature``, by + assigning them static values to be used in all cases. Such static values are given + in the ``give_static_value`` mapping, which maps (names of) parameters of ``fn`` to + a fixed value to be used for that parameter. This means that these parameters do not + need to be mapped to a parameter in ``new_signature``. + + 5. Parameters that have default values in ``fn``, and which are not mapped to a + parameter of ``new_signature``, will adopt their default value as a static value. + Args: fn (Callable): Callable object to change the signature of. new_signature (inspect.Signature): New signature to give to ``fn``. old_to_new_names (dict[str, str]): Maps the names of parameters in ``fn``s - signature to the corresponding parameter names in the new signature. - Parameter names that do not change can be omitted. Note that parameters that - are to be dropped should be supplied to ``give_static_value`` instead. + signature to the corresponding parameter names in ``new_signature``. give_static_value (dict[str, Any]): Maps names of parameters of ``fn`` to - default values that should be assigned to them. This means that not all - compulsory parameters of ``fn`` have to have a corresponding parameter in - ``new_signature`` - such parameters will use the value assigned to them in - ``give_static_value`` if they are lacking a counterpart parameter in - ``new_signature``. Parameters to ``fn`` that lack a counterpart in - ``new_signature``, and that have default values in ``fn``, will be added - automatically. + static values that should be assigned to them. + + Raises: + ValueError: If ``fn``'s signature cannot be cast to ``new_signature``, given the + information provided. Returns: Callable: Callable representing ``fn`` with ``new_signature``. See Also: - _signature_can_be_cast: Validation method used to check casting is possible. + _check_variable_length_params: Validation of number of variable-length + parameters. + _signature_can_be_cast: Validation method used to check signatures can be cast. """ fn_signature = inspect.signature(fn) From 44a8504ab082e7cfd856ecb452646f02819cdedd Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:46:28 +0000 Subject: [PATCH 11/38] Parameter naming and docstrings for _signature_can_be_cast --- src/causalprog/backend/_convert_signature.py | 52 ++++++++------------ tests/test_backend/test_convert_signature.py | 8 +-- 2 files changed, 25 insertions(+), 35 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 15c6586..81d9738 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -53,33 +53,20 @@ def _check_variable_length_params( def _signature_can_be_cast( signature_to_convert: Signature, new_signature: Signature, - param_name_map: ParamNameMap, + old_to_new_names: ParamNameMap, give_static_value: StaticValues, ) -> tuple[ParamNameMap, StaticValues]: """ Prepare a signature for conversion to another signature. - In order to map ``signature_to_convert`` to that of ``new_signature``, the following - assurances are needed: - - - Variable-length parameters in the two signatures are assumed to match (up to name - changes) or be provided explicit defaults. The function will attempt to match - variable-length parameters that are not explicitly matched in the - ``param_name_map``. Note that a signature can have, at most, only one - variable-length positional parameter and one variable-length keyword parameter. - - All parameters WITHOUT DEFAULT VALUES in ``signature_to_convert`` correspond to a - parameter in ``new_signature`` (that may or may not have a default value) OR are - given static values to use, via the ``give_static_value`` argument. - - If ``new_signature`` takes variable-keyword-argument (``**kwargs``), these - arguments are expanded to allow for possible matches to parameters of - ``signature_to_convert``, before passing any remaining parameters after this - unpacked to the variable-keyword-argument of ``signature_to_convert``. + This is a helper that handles the validation detailed in ``convert_signature``. + See the docstring of ``convert_signature`` for more details. Args: signature_to_convert (Signature): Function signature that will be cast to ``new_signature``. new_signature (Signature): See the homonymous argument to ``convert_signature``. - param_name_map (ParamNameMap): See the homonymous argument to + old_to_new_names (ParamNameMap): See the homonymous argument to ``convert_signature``. give_static_value (StaticValues): See the homonymous argument to ``convert_signature``. @@ -90,28 +77,31 @@ def _signature_can_be_cast( Returns: ParamNameMap: Mapping of parameter names in the ``signature_to_convert`` to - parameter names in ``new_signature``. Implicit mappings as per function - behaviour are explicitly included in the returned mapping. + parameter names in ``new_signature``. Implicit mappings as per behaviour of + ``convert_signature`` are explicitly included in the returned mapping. StaticValues: Mapping of parameter names in the ``signature_to_convert`` to static values to assign to these parameters, indicating omission from the - ``new_signature``. Implicit adoption of static values as per function - behaviour are explicitly included in the returned mapping. + ``new_signature``. Implicit adoption of static values as per behaviour of + ``convert_signature`` are explicitly included in the returned mapping. + + See Also: + convert_signature: Function for which setup is being performed. """ _check_variable_length_params(signature_to_convert) new_varlength_params = _check_variable_length_params(new_signature) - param_name_map = dict(param_name_map) + old_to_new_names = dict(old_to_new_names) give_static_value = dict(give_static_value) new_parameters_accounted_for = set() # Check mapping of parameters in old signature to new signature for p_name, param in signature_to_convert.parameters.items(): - is_explicitly_mapped = p_name in param_name_map + is_explicitly_mapped = p_name in old_to_new_names name_is_unchanged = ( - p_name not in param_name_map - and p_name not in param_name_map.values() + p_name not in old_to_new_names + and p_name not in old_to_new_names.values() and p_name in new_signature.parameters ) is_given_static = p_name in give_static_value @@ -121,20 +111,20 @@ def _signature_can_be_cast( if is_explicitly_mapped: # This parameter is explicitly mapped to another parameter - mapped_to = param_name_map[p_name] + mapped_to = old_to_new_names[p_name] elif name_is_unchanged: # Parameter is inferred not to change name, having been omitted from the # explicit mapping. mapped_to = p_name - param_name_map[p_name] = mapped_to + old_to_new_names[p_name] = mapped_to elif ( is_varlength_param and new_varlength_params[param.kind] is not None - and str(new_varlength_params[param.kind]) not in param_name_map.values() + and str(new_varlength_params[param.kind]) not in old_to_new_names.values() ): # Automatically map VAR_* parameters to their counterpart, if possible. mapped_to = str(new_varlength_params[param.kind]) - param_name_map[p_name] = mapped_to + old_to_new_names[p_name] = mapped_to elif is_given_static: # This parameter is given a static value to use. continue @@ -170,7 +160,7 @@ def _signature_can_be_cast( ) raise ValueError(msg) - new_parameters_accounted_for.add(param_name_map[p_name]) + new_parameters_accounted_for.add(old_to_new_names[p_name]) # Confirm all items in new_signature are also accounted for. unaccounted_new_parameters = ( @@ -182,7 +172,7 @@ def _signature_can_be_cast( ) raise ValueError(msg) - return param_name_map, give_static_value + return old_to_new_names, give_static_value def convert_signature( diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index 826a259..76fbb0d 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -22,7 +22,7 @@ def general_function( ( "signature_to_convert", "new_signature", - "param_name_map", + "old_to_new_names", "give_static_value", "expected_output", ), @@ -164,7 +164,7 @@ def general_function( def test_signature_can_be_cast( signature_to_convert: Signature, new_signature: Signature, - param_name_map: ParamNameMap, + old_to_new_names: ParamNameMap, give_static_value: StaticValues, expected_output: Exception | tuple[str | None, ParamNameMap, StaticValues], ) -> None: @@ -175,14 +175,14 @@ def test_signature_can_be_cast( _signature_can_be_cast( signature_to_convert, new_signature, - param_name_map, + old_to_new_names, give_static_value, ) else: computed_output = _signature_can_be_cast( signature_to_convert, new_signature, - param_name_map, + old_to_new_names, give_static_value, ) From dc283cf0c608410d5deea8c6091e448e8a9e4513 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:51:24 +0000 Subject: [PATCH 12/38] Refactor _signature_can_be_cast tests --- tests/test_backend/test_convert_signature.py | 179 +--------------- .../test_signature_can_be_cast.py | 194 ++++++++++++++++++ 2 files changed, 196 insertions(+), 177 deletions(-) create mode 100644 tests/test_backend/test_signature_can_be_cast.py diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index 76fbb0d..d6ee1cf 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -1,15 +1,11 @@ import re from collections.abc import Iterable -from inspect import Parameter, Signature, signature +from inspect import Parameter, Signature from typing import Any import pytest -from causalprog.backend._convert_signature import ( - _signature_can_be_cast, - convert_signature, -) -from causalprog.backend._typing import ParamNameMap, StaticValues +from causalprog.backend._convert_signature import convert_signature def general_function( @@ -18,177 +14,6 @@ def general_function( return posix, posix_def, vargs, kwo, kwo_def, kwargs -@pytest.mark.parametrize( - ( - "signature_to_convert", - "new_signature", - "old_to_new_names", - "give_static_value", - "expected_output", - ), - [ - pytest.param( - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - Parameter("b", Parameter.POSITIONAL_ONLY), - ] - ), - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - ] - ), - {}, - {}, - ValueError( - "Parameter 'b' has no counterpart in new_signature, " - "and does not take a static value." - ), - id="Parameter not matched.", - ), - pytest.param( - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - Parameter("b", Parameter.POSITIONAL_ONLY), - ] - ), - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - ] - ), - {"a": "a", "b": "a"}, - {}, - ValueError("Parameter 'a' is mapped to by multiple parameters."), - id="Two arguments mapped to a single parameter.", - ), - pytest.param( - Signature( - [ - Parameter("vargs", Parameter.VAR_POSITIONAL), - ] - ), - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - ] - ), - {"vargs": "a"}, - {}, - ValueError( - "Variable-length positional/keyword parameters must map to each other " - "('vargs' is type VAR_POSITIONAL, but 'a' is type POSITIONAL_ONLY)." - ), - id="Map *args to positional argument.", - ), - pytest.param( - Signature( - [ - Parameter("vargs", Parameter.VAR_POSITIONAL), - ] - ), - Signature( - [ - Parameter("kwarg", Parameter.VAR_KEYWORD), - ] - ), - {"vargs": "kwarg"}, - {}, - ValueError( - "Variable-length positional/keyword parameters must map to each other " - "('vargs' is type VAR_POSITIONAL, but 'kwarg' is type VAR_KEYWORD)." - ), - id="Map *args to **kwargs.", - ), - pytest.param( - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - ] - ), - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - Parameter("b", Parameter.POSITIONAL_ONLY), - ] - ), - {}, - {}, - ValueError("Some parameters in new_signature are not used: b"), - id="new_signature contains extra parameters.", - ), - pytest.param( - signature(general_function), - signature(general_function), - {}, - {}, - ({key: key for key in signature(general_function).parameters}, {}), - id="Can cast to yourself.", - ), - pytest.param( - Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), - Signature([Parameter("a", Parameter.KEYWORD_ONLY)]), - {}, - {}, - ({"a": "a"}, {}), - id="Infer identically named parameter (even with type change)", - ), - pytest.param( - Signature([Parameter("args", Parameter.VAR_POSITIONAL)]), - Signature([Parameter("new_args", Parameter.VAR_POSITIONAL)]), - {}, - {}, - ({"args": "new_args"}, {}), - id="Infer VAR_POSITIONAL matching.", - ), - pytest.param( - Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), - Signature([]), - {}, - {"a": 10}, - ({}, {"a": 10}), - id="Assign static value to argument without default.", - ), - pytest.param( - Signature([Parameter("a", Parameter.POSITIONAL_ONLY, default=10)]), - Signature([]), - {}, - {}, - ({}, {"a": 10}), - id="Infer static value from argument default.", - ), - ], -) -def test_signature_can_be_cast( - signature_to_convert: Signature, - new_signature: Signature, - old_to_new_names: ParamNameMap, - give_static_value: StaticValues, - expected_output: Exception | tuple[str | None, ParamNameMap, StaticValues], -) -> None: - if isinstance(expected_output, Exception): - with pytest.raises( - type(expected_output), match=re.escape(str(expected_output)) - ): - _signature_can_be_cast( - signature_to_convert, - new_signature, - old_to_new_names, - give_static_value, - ) - else: - computed_output = _signature_can_be_cast( - signature_to_convert, - new_signature, - old_to_new_names, - give_static_value, - ) - - assert computed_output == expected_output - - _kwargs_static_value = {"some": "keyword-arguments"} diff --git a/tests/test_backend/test_signature_can_be_cast.py b/tests/test_backend/test_signature_can_be_cast.py new file mode 100644 index 0000000..36f0f53 --- /dev/null +++ b/tests/test_backend/test_signature_can_be_cast.py @@ -0,0 +1,194 @@ +import re +from inspect import Parameter, Signature + +import pytest + +from causalprog.backend._convert_signature import _signature_can_be_cast +from causalprog.backend._typing import ParamNameMap, StaticValues + + +@pytest.mark.parametrize( + ( + "signature_to_convert", + "new_signature", + "old_to_new_names", + "give_static_value", + "expected_output", + ), + [ + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {}, + {}, + ValueError( + "Parameter 'b' has no counterpart in new_signature, " + "and does not take a static value." + ), + id="Parameter not matched.", + ), + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {"a": "a", "b": "a"}, + {}, + ValueError("Parameter 'a' is mapped to by multiple parameters."), + id="Two arguments mapped to a single parameter.", + ), + pytest.param( + Signature( + [ + Parameter("vargs", Parameter.VAR_POSITIONAL), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {"vargs": "a"}, + {}, + ValueError( + "Variable-length positional/keyword parameters must map to each other " + "('vargs' is type VAR_POSITIONAL, but 'a' is type POSITIONAL_ONLY)." + ), + id="Map *args to positional argument.", + ), + pytest.param( + Signature( + [ + Parameter("vargs", Parameter.VAR_POSITIONAL), + ] + ), + Signature( + [ + Parameter("kwarg", Parameter.VAR_KEYWORD), + ] + ), + {"vargs": "kwarg"}, + {}, + ValueError( + "Variable-length positional/keyword parameters must map to each other " + "('vargs' is type VAR_POSITIONAL, but 'kwarg' is type VAR_KEYWORD)." + ), + id="Map *args to **kwargs.", + ), + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + {}, + {}, + ValueError("Some parameters in new_signature are not used: b"), + id="new_signature contains extra parameters.", + ), + pytest.param( + "general_function_signature", + "general_function_signature", + {}, + {}, + ( + { + "posix": "posix", + "posix_def": "posix_def", + "vargs": "vargs", + "kwo": "kwo", + "kwo_def": "kwo_def", + "kwargs": "kwargs", + }, + {}, + ), + id="Can cast to yourself.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), + Signature([Parameter("a", Parameter.KEYWORD_ONLY)]), + {}, + {}, + ({"a": "a"}, {}), + id="Infer identically named parameter (even with type change)", + ), + pytest.param( + Signature([Parameter("args", Parameter.VAR_POSITIONAL)]), + Signature([Parameter("new_args", Parameter.VAR_POSITIONAL)]), + {}, + {}, + ({"args": "new_args"}, {}), + id="Infer VAR_POSITIONAL matching.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), + Signature([]), + {}, + {"a": 10}, + ({}, {"a": 10}), + id="Assign static value to argument without default.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY, default=10)]), + Signature([]), + {}, + {}, + ({}, {"a": 10}), + id="Infer static value from argument default.", + ), + ], +) +def test_signature_can_be_cast( # noqa: PLR0913 + signature_to_convert: Signature, + new_signature: Signature, + old_to_new_names: ParamNameMap, + give_static_value: StaticValues, + expected_output: Exception | tuple[str | None, ParamNameMap, StaticValues], + request, +) -> None: + if isinstance(signature_to_convert, str): + signature_to_convert = request.getfixturevalue(signature_to_convert) + if isinstance(new_signature, str): + new_signature = request.getfixturevalue(new_signature) + + if isinstance(expected_output, Exception): + with pytest.raises( + type(expected_output), match=re.escape(str(expected_output)) + ): + _signature_can_be_cast( + signature_to_convert, + new_signature, + old_to_new_names, + give_static_value, + ) + else: + computed_output = _signature_can_be_cast( + signature_to_convert, + new_signature, + old_to_new_names, + give_static_value, + ) + + assert computed_output == expected_output From 1dddcc7ca1bbdb209c53604dcad56f5a6b0bd748 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:54:34 +0000 Subject: [PATCH 13/38] Use fixtures for convert_signature test --- tests/test_backend/test_convert_signature.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index d6ee1cf..a983269 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -1,5 +1,5 @@ import re -from collections.abc import Iterable +from collections.abc import Callable, Iterable from inspect import Parameter, Signature from typing import Any @@ -7,13 +7,6 @@ from causalprog.backend._convert_signature import convert_signature - -def general_function( - posix, /, posix_def="posix_def", *vargs, kwo, kwo_def="kwo_def", **kwargs -): - return posix, posix_def, vargs, kwo, kwo_def, kwargs - - _kwargs_static_value = {"some": "keyword-arguments"} @@ -81,6 +74,7 @@ def test_convert_signature( posix_for_new_call: Iterable[Any], keyword_for_new_call: dict[str, Any], expected_assignments: dict[str, Any] | Exception, + general_function: Callable, ) -> None: """ To ease the burden of setting up and parametrising this test, From 5b28001ae71f0bf8bbdcfb5929dbe1624887fdeb Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:58:17 +0000 Subject: [PATCH 14/38] remove outdated comment --- src/causalprog/backend/_convert_signature.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 81d9738..6b96419 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -289,10 +289,6 @@ def fn_with_new_signature(*args: tuple, **kwargs: dict[str, Any]) -> ReturnType: }, **kwargs_to_pass_on, ) - # We can supply all arguments EXCEPT the variable-positional and positional-only - # arguments as keyword args. - # Positional-only arguments have to come first, followed by the - # variable-positional parameters. fn_args = [fn_kwargs.pop(p_name) for p_name in fn_posix_args] if fn_vargs_param: fn_args.extend(fn_kwargs.pop(fn_vargs_param, [])) From 6aedfcdd3e9a71c0309272cc083ffd9fd069106c Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 14:51:53 +0000 Subject: [PATCH 15/38] Write identity map --- src/causalprog/backend/translator.py | 128 +++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 src/causalprog/backend/translator.py diff --git a/src/causalprog/backend/translator.py b/src/causalprog/backend/translator.py new file mode 100644 index 0000000..0edbb2c --- /dev/null +++ b/src/causalprog/backend/translator.py @@ -0,0 +1,128 @@ +"""Translating backend object syntax to frontend syntax.""" + +from collections.abc import Callable +from dataclasses import dataclass, field +from inspect import Signature +from typing import Any + +from causalprog._abc.backend_agnostic import Backend, BackendAgnostic + +from ._convert_signature import convert_signature +from ._typing import ParamNameMap, ReturnType, StaticValues + + +@dataclass +class Translation: + """ + Helper class for mapping frontend signatures to backend signatures. + + Predominantly a convenience wrapper for working with different backends. + The attributes stored in an instance form the compulsory arguments that + need to be passed to ``convert_signature`` in order to map a backend + function to the frontend syntax. + """ + + target_signature: Signature + param_map: ParamNameMap + frozen_args: StaticValues = field(default_factory=dict) + target_name: str | None = None + + def __post_init__(self) -> None: + self.param_map = dict(self.param_map) + self.frozen_args = dict(self.frozen_args) + self.target_name = str(self.target_name) if self.target_name else None + + if not all( + isinstance(key, str) and isinstance(value, str) + for key, value in self.param_map.items() + ): + msg = "Parameter map must map names to names (str -> str)" + raise ValueError(msg) + if not all(isinstance(key, str) for key in self.frozen_args): + msg = "Frozen args must be specified by name (str)" + raise ValueError(msg) + + def translate(self, fn: Callable[..., ReturnType]) -> Callable[..., ReturnType]: + """Convert a (compatible) callable's signature into the target_signature.""" + return convert_signature( + fn, self.target_signature, self.param_map, self.frozen_args + ) + + +class Translator(BackendAgnostic[Backend]): + """ + Translates the methods of a backend object into frontend syntax. + + A ``Translator`` acts as an intermediary between a backend that is supplied by the + user and the frontend syntax that ``causalprog`` relies on. The default backend of + ``causalprog`` uses a syntax compatible with ``jax``. + + Other backends may not conform to the syntax that ``causalprog`` expects, but + nonetheless may provide the functionality that it requires. A ``Translator`` is able + to make calls to (the relevant methods of) this backend, whilst still conforming to + the frontend syntax of ``causalprog``. + + As an example, suppose that we have a frontend class ``C`` that needs to provide a + method ``do_this``. ``causalprog`` expects ``C`` to provide the functionality + of ``do_this`` via one of its methods, ``C.do_this(*c_args, **c_kwargs)``. + Now suppose that a class ``D`` from a different, external package might also + provides the functionality of ``do_this``, but it is done by calling + ``D.do_this_different(*d_args, **d_kwargs)``, where there is some mapping + ``m: *c_args, **c_kwargs -> *d_args, **d_kwargs``. In such a case, ``causalprog`` + needs to use a ``Translator`` ``T``, rather than ``D`` directly, where + + ``T.do_this(*c_args, **c_kwargs) = D.do_this_different(m(*c_args, **c_kwargs))``. + """ + + translations: dict[str, Callable] + + @staticmethod + def identity(*args: Any, **kwargs: Any) -> tuple[tuple, dict[str, Any]]: # noqa: ANN401 + """Identity map on positional and keyword arguments.""" + return args, kwargs + + def __init__( + self, + native: Backend, + **translations: Translation, + ) -> None: + """ + Translate a backend object into a frontend-compatible object. + + Args: + native (Backend): Backend object that must be translated to support frontend + syntax. + **translations (Translation): Keyword-specified ``Translation``s that map + the methods of ``native`` to the (signatures of the) methods that the + ``_frontend_provides``. Keyword names are interpreted as the name of the + backend method to translate, whilst ``Translation.target_name`` is + interpreted as the name of the frontend method that this backend method + performs the role of. + + """ + super().__init__(backend=native) + + self.translations = {} + for native_name, t in translations.items(): + translated_name = t.target_name if t.target_name else native_name + native_method = getattr(self._backend_obj, native_name) + + if translated_name in self.translations: + msg = f"Method {translated_name} provided twice." + raise ValueError(msg) + self.translations[translated_name] = convert_signature( + native_method, t.target_signature, t.param_map, t.frozen_args + ) + + # Methods without explicit translations are assumed to be the identity map + for method in self._frontend_provides: + if method not in self.translations: + self.translations[method] = self.identity + + self.validate() + + def __getattr__(self, name: str) -> Any: # noqa: ANN401 + # Check for translations before falling back on backend directly. + if name in self.translations: + return self.translations[name] + return super().__getattr__(name) From 2e4def5bb53db8fbf9e37b14d34b4fd58555f900 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 15:06:49 +0000 Subject: [PATCH 16/38] convert_signature now returns a mapping --- src/causalprog/backend/_convert_signature.py | 18 +++++++++--------- tests/test_backend/test_convert_signature.py | 11 ++++++----- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 6b96419..ef19509 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -180,14 +180,14 @@ def convert_signature( new_signature: Signature, old_to_new_names: ParamNameMap, give_static_value: StaticValues, -) -> Callable[..., ReturnType]: +) -> Callable[..., tuple[Any, Any]]: """ Convert the call signature of a function ``fn`` to that of ``new_signature``. This function effectively allows ``fn`` to be called with ``new_signature``. It - returns a new ``Callable`` that uses the ``new_signature``, and returns the result - of ``fn`` after translating the ``new_signature`` back into that of ``fn`` and - making an appropriate call. + returns a new ``Callable`` (denoted ``g``) that maps the parameters of + ``new_signature`` to the (corresponding) parameters of ``fn``. As such, ``fn`` + composed with ``g`` allows for calling ``fn`` with the ``new_signature``. Converting signatures into each other is, in general, not possible. However under certain assumptions and conventions, it can be done. To that end, the following @@ -236,7 +236,7 @@ def convert_signature( information provided. Returns: - Callable: Callable representing ``fn`` with ``new_signature``. + Callable: Callable mapping parameters in ``new_signature`` to those in ``fn``. See Also: _check_variable_length_params: Validation of number of variable-length @@ -270,7 +270,7 @@ def convert_signature( static_kwargs = give_static_value.pop(fn_kwargs_param) give_static_value = dict(give_static_value, **static_kwargs) - def fn_with_new_signature(*args: tuple, **kwargs: dict[str, Any]) -> ReturnType: + def new_sig_to_fn_sig(*args: Any, **kwargs: Any) -> tuple[list, dict[str, Any]]: # noqa: ANN401 bound = new_signature.bind(*args, **kwargs) bound.apply_defaults() @@ -292,7 +292,7 @@ def fn_with_new_signature(*args: tuple, **kwargs: dict[str, Any]) -> ReturnType: fn_args = [fn_kwargs.pop(p_name) for p_name in fn_posix_args] if fn_vargs_param: fn_args.extend(fn_kwargs.pop(fn_vargs_param, [])) - # Now we can call fn - return fn(*fn_args, **fn_kwargs) + # Arguments are now mapped + return fn_args, fn_kwargs - return fn_with_new_signature + return new_sig_to_fn_sig diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index a983269..f077beb 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -117,18 +117,19 @@ def test_convert_signature( ), ] ) - new_function = convert_signature( + argument_map = convert_signature( general_function, new_signature, param_name_map, give_static_value ) - if isinstance(expected_assignments, Exception): with pytest.raises( type(expected_assignments), match=re.escape(str(expected_assignments)) ): - new_function(*posix_for_new_call, **keyword_for_new_call) + args, kwargs = argument_map(*posix_for_new_call, **keyword_for_new_call) + else: - posix, posix_def, vargs, kwo, kwo_def, kwargs = new_function( - *posix_for_new_call, **keyword_for_new_call + args, kwargs = argument_map(*posix_for_new_call, **keyword_for_new_call) + posix, posix_def, vargs, kwo, kwo_def, kwargs = general_function( + *args, **kwargs ) assert posix == expected_assignments["posix"] assert posix_def == expected_assignments["posix_def"] From 24e7e05e449b35860c9b6d6d470781129963e339 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 15:28:47 +0000 Subject: [PATCH 17/38] Hidden method for calling the backend --- src/causalprog/backend/translator.py | 53 +++++++++++++++------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/src/causalprog/backend/translator.py b/src/causalprog/backend/translator.py index 0edbb2c..849016e 100644 --- a/src/causalprog/backend/translator.py +++ b/src/causalprog/backend/translator.py @@ -2,15 +2,16 @@ from collections.abc import Callable from dataclasses import dataclass, field -from inspect import Signature +from inspect import signature from typing import Any from causalprog._abc.backend_agnostic import Backend, BackendAgnostic from ._convert_signature import convert_signature -from ._typing import ParamNameMap, ReturnType, StaticValues +from ._typing import ParamNameMap, StaticValues +# TODO: Tests for this guy @dataclass class Translation: """ @@ -22,15 +23,16 @@ class Translation: function to the frontend syntax. """ - target_signature: Signature + backend_name: str + frontend_name: str param_map: ParamNameMap frozen_args: StaticValues = field(default_factory=dict) - target_name: str | None = None def __post_init__(self) -> None: - self.param_map = dict(self.param_map) + self.backend_name = str(self.backend_name) + self.frontend_name = str(self.frontend_name) self.frozen_args = dict(self.frozen_args) - self.target_name = str(self.target_name) if self.target_name else None + self.param_map = dict(self.param_map) if not all( isinstance(key, str) and isinstance(value, str) @@ -42,13 +44,8 @@ def __post_init__(self) -> None: msg = "Frozen args must be specified by name (str)" raise ValueError(msg) - def translate(self, fn: Callable[..., ReturnType]) -> Callable[..., ReturnType]: - """Convert a (compatible) callable's signature into the target_signature.""" - return convert_signature( - fn, self.target_signature, self.param_map, self.frozen_args - ) - +# TODO: tests for this guy after tests for the above guy! class Translator(BackendAgnostic[Backend]): """ Translates the methods of a backend object into frontend syntax. @@ -74,17 +71,18 @@ class Translator(BackendAgnostic[Backend]): ``T.do_this(*c_args, **c_kwargs) = D.do_this_different(m(*c_args, **c_kwargs))``. """ + frontend_to_native_names: dict[str, str] translations: dict[str, Callable] @staticmethod - def identity(*args: Any, **kwargs: Any) -> tuple[tuple, dict[str, Any]]: # noqa: ANN401 + def identity(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any]]: # noqa: ANN401 """Identity map on positional and keyword arguments.""" return args, kwargs def __init__( self, native: Backend, - **translations: Translation, + *translations: Translation, ) -> None: """ Translate a backend object into a frontend-compatible object. @@ -103,16 +101,17 @@ def __init__( super().__init__(backend=native) self.translations = {} - for native_name, t in translations.items(): - translated_name = t.target_name if t.target_name else native_name + self.frontend_to_native_names = {name: name for name in self._frontend_provides} + for t in translations: + native_name = t.backend_name + translated_name = t.frontend_name native_method = getattr(self._backend_obj, native_name) + target_signature = signature(getattr(self, translated_name)) - if translated_name in self.translations: - msg = f"Method {translated_name} provided twice." - raise ValueError(msg) self.translations[translated_name] = convert_signature( - native_method, t.target_signature, t.param_map, t.frozen_args + native_method, target_signature, t.param_map, t.frozen_args ) + self.frontend_to_native_names[translated_name] = native_name # Methods without explicit translations are assumed to be the identity map for method in self._frontend_provides: @@ -121,8 +120,12 @@ def __init__( self.validate() - def __getattr__(self, name: str) -> Any: # noqa: ANN401 - # Check for translations before falling back on backend directly. - if name in self.translations: - return self.translations[name] - return super().__getattr__(name) + def _call_backend_with(self, method: str, *args: Any, **kwargs: Any) -> Any: # noqa:ANN401 + """Translate arguments and then call the backend.""" + backend_method = getattr(self._backend_obj, method) + backend_args, backend_kwargs = self.translations[method](*args, **kwargs) + return backend_method(*backend_args, **backend_kwargs) + + # IDEA NOW is that we could now define + # def sample(*args, **kwargs): + # return self._call_backend_with("sample", *args, **kwargs) From f2aba0b686419479c033d07b0f3cd5e47497bae3 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 21 Mar 2025 11:17:25 +0000 Subject: [PATCH 18/38] Split Translation out just to compartmentalise --- src/causalprog/backend/translation.py | 49 +++++++++++++++ src/causalprog/backend/translator.py | 50 ++------------- tests/test_backend/test_translation.py | 86 ++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 44 deletions(-) create mode 100644 src/causalprog/backend/translation.py create mode 100644 tests/test_backend/test_translation.py diff --git a/src/causalprog/backend/translation.py b/src/causalprog/backend/translation.py new file mode 100644 index 0000000..ec590ab --- /dev/null +++ b/src/causalprog/backend/translation.py @@ -0,0 +1,49 @@ +"""Data structure for storing information mapping backend to frontend.""" + +from dataclasses import dataclass, field + +from ._typing import ParamNameMap, StaticValues + + +@dataclass +class Translation: + """ + Helper class for mapping frontend signatures to backend signatures. + + Attributes: + backend_name (str): Name of the backend method that is being translated into + a frontend method. + frontend_name (str): Name of the frontend method that the backend method will + be used as. + param_map (ParamNameMap): See ``old_to_new_names`` argument to + ``causalprog.backend._convert_signature``. + frozen_args (StaticValues): See ``give_static_value`` argument to + ``causalprog.backend._convert_signature``. + + """ + + backend_name: str + frontend_name: str + param_map: ParamNameMap + frozen_args: StaticValues = field(default_factory=dict) + + def __post_init__(self) -> None: + if not isinstance(self.backend_name, str): + msg = f"backend_name '{self.backend_name}' is not a string." + raise TypeError(msg) + if not isinstance(self.frontend_name, str): + msg = f"frontend_name '{self.frontend_name}' is not a string." + raise TypeError(msg) + + self.frozen_args = dict(self.frozen_args) + self.param_map = dict(self.param_map) + + if not all( + isinstance(key, str) and isinstance(value, str) + for key, value in self.param_map.items() + ): + msg = "Parameter map must map str -> str." + raise TypeError(msg) + if not all(isinstance(key, str) for key in self.frozen_args): + msg = "Frozen args must be specified by name (str)." + raise TypeError(msg) diff --git a/src/causalprog/backend/translator.py b/src/causalprog/backend/translator.py index 849016e..23b2f16 100644 --- a/src/causalprog/backend/translator.py +++ b/src/causalprog/backend/translator.py @@ -1,48 +1,13 @@ """Translating backend object syntax to frontend syntax.""" from collections.abc import Callable -from dataclasses import dataclass, field from inspect import signature from typing import Any from causalprog._abc.backend_agnostic import Backend, BackendAgnostic from ._convert_signature import convert_signature -from ._typing import ParamNameMap, StaticValues - - -# TODO: Tests for this guy -@dataclass -class Translation: - """ - Helper class for mapping frontend signatures to backend signatures. - - Predominantly a convenience wrapper for working with different backends. - The attributes stored in an instance form the compulsory arguments that - need to be passed to ``convert_signature`` in order to map a backend - function to the frontend syntax. - """ - - backend_name: str - frontend_name: str - param_map: ParamNameMap - frozen_args: StaticValues = field(default_factory=dict) - - def __post_init__(self) -> None: - self.backend_name = str(self.backend_name) - self.frontend_name = str(self.frontend_name) - self.frozen_args = dict(self.frozen_args) - self.param_map = dict(self.param_map) - - if not all( - isinstance(key, str) and isinstance(value, str) - for key, value in self.param_map.items() - ): - msg = "Parameter map must map names to names (str -> str)" - raise ValueError(msg) - if not all(isinstance(key, str) for key in self.frozen_args): - msg = "Frozen args must be specified by name (str)" - raise ValueError(msg) +from .translation import Translation # TODO: tests for this guy after tests for the above guy! @@ -81,7 +46,7 @@ def identity(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any] def __init__( self, - native: Backend, + backend: Backend, *translations: Translation, ) -> None: """ @@ -90,15 +55,12 @@ def __init__( Args: native (Backend): Backend object that must be translated to support frontend syntax. - **translations (Translation): Keyword-specified ``Translation``s that map - the methods of ``native`` to the (signatures of the) methods that the - ``_frontend_provides``. Keyword names are interpreted as the name of the - backend method to translate, whilst ``Translation.target_name`` is - interpreted as the name of the frontend method that this backend method - performs the role of. + *translations (Translation): ``Translation``s that map the methods of + ``backend`` to the (signatures of the) methods that the + ``_frontend_provides``. """ - super().__init__(backend=native) + super().__init__(backend=backend) self.translations = {} self.frontend_to_native_names = {name: name for name in self._frontend_provides} diff --git a/tests/test_backend/test_translation.py b/tests/test_backend/test_translation.py new file mode 100644 index 0000000..4cb50ab --- /dev/null +++ b/tests/test_backend/test_translation.py @@ -0,0 +1,86 @@ +import re +from typing import Any + +import pytest + +from causalprog.backend.translation import Translation + + +@pytest.mark.parametrize( + ("constructor_kwargs", "expected"), + [ + pytest.param( + { + "backend_name": "backend", + "frontend_name": "frontend", + "param_map": {"0": "0", "1": "1"}, + }, + None, + id="Respect default frozen args.", + ), + pytest.param( + { + "backend_name": "backend", + "frontend_name": "frontend", + "param_map": {}, + "frozen_args": {"0": 0, "1": 3.1415}, + }, + None, + id="frozen_args dict-values can be Any.", + ), + pytest.param( + { + "backend_name": 100, + "frontend_name": "frontend", + "param_map": {"0": "0", "1": "1"}, + }, + TypeError("backend_name '100' is not a string."), + id="Backend name must be string.", + ), + pytest.param( + { + "backend_name": "backend", + "frontend_name": [1, 2, 3], + "param_map": {"0": "0", "1": "1"}, + }, + TypeError("frontend_name '[1, 2, 3]' is not a string."), + id="Frontend name must be string.", + ), + pytest.param( + { + "backend_name": "backend", + "frontend_name": "frontend", + "param_map": {"0": "0", "1": 1}, + }, + TypeError("Parameter map must map str -> str."), + id="Parameter map value is not string.", + ), + pytest.param( + { + "backend_name": "backend", + "frontend_name": "frontend", + "param_map": {0: "0", "1": "1"}, + }, + TypeError("Parameter map must map str -> str."), + id="Parameter map key is not string.", + ), + pytest.param( + { + "backend_name": "backend", + "frontend_name": "frontend", + "param_map": {}, + "frozen_args": {0: 0, "1": "1"}, + }, + TypeError("Frozen args must be specified by name (str)."), + id="frozen_args dict-keys must be str.", + ), + ], +) +def test_translation( + constructor_kwargs: dict[str, Any], expected: None | Exception +) -> None: + if isinstance(expected, Exception): + with pytest.raises(type(expected), match=re.escape(str(expected))): + Translation(**constructor_kwargs) + else: + Translation(**constructor_kwargs) From f68ee906140aa171614a3f744774e9abe118faa0 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 21 Mar 2025 13:07:40 +0000 Subject: [PATCH 19/38] Tests for Translator itself --- src/causalprog/backend/translator.py | 28 ++++--- tests/test_backend/test_translator.py | 109 ++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 10 deletions(-) create mode 100644 tests/test_backend/test_translator.py diff --git a/src/causalprog/backend/translator.py b/src/causalprog/backend/translator.py index 23b2f16..8d02e6e 100644 --- a/src/causalprog/backend/translator.py +++ b/src/causalprog/backend/translator.py @@ -10,7 +10,6 @@ from .translation import Translation -# TODO: tests for this guy after tests for the above guy! class Translator(BackendAgnostic[Backend]): """ Translates the methods of a backend object into frontend syntax. @@ -46,8 +45,8 @@ def identity(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any] def __init__( self, - backend: Backend, *translations: Translation, + backend: Backend, ) -> None: """ Translate a backend object into a frontend-compatible object. @@ -68,26 +67,35 @@ def __init__( native_name = t.backend_name translated_name = t.frontend_name native_method = getattr(self._backend_obj, native_name) - target_signature = signature(getattr(self, translated_name)) + target_method = getattr(self, translated_name) + target_signature = signature(target_method) self.translations[translated_name] = convert_signature( native_method, target_signature, t.param_map, t.frozen_args ) self.frontend_to_native_names[translated_name] = native_name - # Methods without explicit translations are assumed to be the identity map + # Methods without explicit translations are assumed to be the identity map, + # provided they exist on the backend object. for method in self._frontend_provides: - if method not in self.translations: + method_has_translation = method in self.translations + backend_has_method = hasattr(self._backend_obj, method) + if not (method_has_translation or backend_has_method): + msg = ( + f"No translation provided for {method}, " + "which the backend is lacking." + ) + raise AttributeError(msg) + if not method_has_translation: + # Assume the identity mapping to teh backend method, otherwise. self.translations[method] = self.identity self.validate() def _call_backend_with(self, method: str, *args: Any, **kwargs: Any) -> Any: # noqa:ANN401 """Translate arguments and then call the backend.""" - backend_method = getattr(self._backend_obj, method) + backend_method = getattr( + self._backend_obj, self.frontend_to_native_names[method] + ) backend_args, backend_kwargs = self.translations[method](*args, **kwargs) return backend_method(*backend_args, **backend_kwargs) - - # IDEA NOW is that we could now define - # def sample(*args, **kwargs): - # return self._call_backend_with("sample", *args, **kwargs) diff --git a/tests/test_backend/test_translator.py b/tests/test_backend/test_translator.py new file mode 100644 index 0000000..304496e --- /dev/null +++ b/tests/test_backend/test_translator.py @@ -0,0 +1,109 @@ +import re +from collections.abc import Sequence + +import pytest + +from causalprog.backend.translation import Translation +from causalprog.backend.translator import Translator + + +class BackendObjNeedsNoTranslation: + def frontend_method(self, denominator: float, numerator: float) -> float: + return numerator / denominator + + +class BackendObjNameChangeOnly: + def backend_method(self, denominator: float, numerator: float) -> float: + return numerator / denominator + + +class BackendObjNeedsTranslation: + def backend_method(self, num: float, denom: float) -> float: + return num / denom + + +class BackendObjDropsArg: + def backend_method(self, num: float, denom: float, constant: float) -> float: + return num / denom + constant + + +class TranslatorForTesting(Translator): + @property + def _frontend_provides(self) -> tuple[str, ...]: + return ("frontend_method",) + + def frontend_method(self, denominator: float, numerator: float) -> float: + return self._call_backend_with("frontend_method", denominator, numerator) + + +@pytest.mark.parametrize( + ("translations", "backend", "created_method_is_identity"), + [ + pytest.param( + (), BackendObjNeedsNoTranslation(), True, id="Backend needs no translation." + ), + pytest.param( + ( + Translation( + backend_name="backend_method", + frontend_name="frontend_method", + param_map={}, + ), + ), + BackendObjNameChangeOnly(), + False, + id="Method name change, map is no longer identity.", + ), + pytest.param( + ( + Translation( + backend_name="backend_method", + frontend_name="frontend_method", + param_map={"num": "numerator", "denom": "denominator"}, + ), + ), + BackendObjNeedsTranslation(), + False, + id="Full translation required.", + ), + pytest.param( + ( + Translation( + backend_name="backend_method", + frontend_name="frontend_method", + param_map={"num": "numerator", "denom": "denominator"}, + frozen_args={"constant": 0.0}, + ), + ), + BackendObjDropsArg(), + False, + id="Drop an argument.", + ), + ], +) +def test_creation_and_methods( + translations: Sequence[Translation], backend, *, created_method_is_identity: bool +) -> None: + translator = TranslatorForTesting(*translations, backend=backend) + + assert callable(translator.translations["frontend_method"]) + assert ( + translator.translations["frontend_method"] is translator.identity + ) == created_method_is_identity + + input_denominator = 2.0 + input_numerator = 1.0 + assert translator.frontend_method( + input_denominator, input_numerator + ) == pytest.approx(input_numerator / input_denominator) + + +def test_must_provide_all_frontend_methods() -> None: + # You cannot get away without defining one of the frontend methods + with pytest.raises( + AttributeError, + match=re.escape( + "No translation provided for frontend_method, which the backend is lacking." + ), + ): + TranslatorForTesting(backend=BackendObjNameChangeOnly()) From 6a495a1994e7b2ee67f5014c9136fb95b194b942 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 21 Mar 2025 13:52:47 +0000 Subject: [PATCH 20/38] Cannot multiple inherit from two classes that both define __slots__ --- src/causalprog/_abc/backend_agnostic.py | 1 - src/causalprog/_abc/labelled.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/causalprog/_abc/backend_agnostic.py b/src/causalprog/_abc/backend_agnostic.py index 08aa40d..0d6d663 100644 --- a/src/causalprog/_abc/backend_agnostic.py +++ b/src/causalprog/_abc/backend_agnostic.py @@ -17,7 +17,6 @@ class BackendAgnostic(ABC, Generic[Backend]): calls to the ``_backend_obj`` as necessary. """ - __slots__ = ("_backend_obj",) _backend_obj: Backend def __getattr__(self, name: str) -> Any: # noqa: ANN401 diff --git a/src/causalprog/_abc/labelled.py b/src/causalprog/_abc/labelled.py index 92250bc..c4be7f7 100644 --- a/src/causalprog/_abc/labelled.py +++ b/src/causalprog/_abc/labelled.py @@ -11,7 +11,6 @@ class Labelled(ABC): ``label`` property of the class. """ - __slots__ = ("_label",) _label: str @property From f7732a723de8e39c6ab2a545c2781c59f2bb10c4 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 21 Mar 2025 13:58:49 +0000 Subject: [PATCH 21/38] Rework Distribution to be backend-agnostic and use translators --- src/causalprog/distribution/base.py | 97 ++++++++----------- src/causalprog/distribution/normal.py | 6 +- .../test_different_backends.py | 26 ++++- 3 files changed, 68 insertions(+), 61 deletions(-) diff --git a/src/causalprog/distribution/base.py b/src/causalprog/distribution/base.py index bbd21d6..fba47b4 100644 --- a/src/causalprog/distribution/base.py +++ b/src/causalprog/distribution/base.py @@ -1,77 +1,45 @@ """Base class for backend-agnostic distributions.""" -from collections.abc import Callable -from typing import Generic, TypeVar +from typing import TypeVar from numpy.typing import ArrayLike from causalprog._abc.labelled import Labelled -from causalprog.utils.translator import Translator +from causalprog.backend.translation import Translation +from causalprog.backend.translator import Translator SupportsRNG = TypeVar("SupportsRNG") SupportsSampling = TypeVar("SupportsSampling", bound=object) -class SampleTranslator(Translator): - """ - Translate methods for sampling from distributions. - - The ``Distribution`` class provides a ``sample`` method, that takes ``rng_key`` and - ``sample_shape`` as its arguments. Instances of this class transform the these - arguments to those that a backend distribution expects. - """ - - @property - def _frontend_method(self) -> str: - return "sample" - - @property - def compulsory_frontend_args(self) -> set[str]: - """Arguments that are required by the frontend function.""" - return {"rng_key", "sample_shape"} - - -class Distribution(Generic[SupportsSampling], Labelled): +class Distribution(Translator[SupportsSampling], Labelled): """A (backend-agnostic) distribution that can be sampled from.""" - _dist: SupportsSampling - _backend_translator: SampleTranslator + @property + def _frontend_provides(self) -> tuple[str, ...]: + return ("sample",) @property - def _sample(self) -> Callable[..., ArrayLike]: - """Method for drawing samples from the backend object.""" - return getattr(self._dist, self._backend_translator.backend_method) + def dist(self) -> SupportsSampling: + """Return the object representing the distribution.""" + return self.get_backend() def __init__( - self, - backend_distribution: SupportsSampling, - backend_translator: SampleTranslator | None = None, - *, - label: str = "Distribution", + self, *translations: Translation, backend: SupportsSampling, label: str ) -> None: """ - Create a new Distribution. + Create a new distribution, with a given backend. Args: - backend_distribution (SupportsSampling): Backend object that supports - drawing random samples. - backend_translator (SampleTranslator): Translator object mapping backend - sampling function to frontend arguments. + *translations (Translation): Information for mapping the methods of the + backend object to the frontend methods provided by this class. See + ``causalprog.backend.Translator`` for more details. + backend (SupportsSampling): Backend object that represents the distribution. + label (str): Name or label to attach to the distribution. """ - super().__init__(label=label) - - self._dist = backend_distribution - - # Setup sampling calls, and perform one-time check for compatibility - self._backend_translator = ( - backend_translator if backend_translator is not None else SampleTranslator() - ) - self._backend_translator.validate_compatible(backend_distribution) - - def get_dist(self) -> SupportsSampling: - """Access to the backend distribution.""" - return self._dist + Labelled.__init__(self, label=label) + Translator.__init__(self, *translations, backend=backend) def sample(self, rng_key: SupportsRNG, sample_shape: ArrayLike = ()) -> ArrayLike: """ @@ -85,7 +53,26 @@ def sample(self, rng_key: SupportsRNG, sample_shape: ArrayLike = ()) -> ArrayLik ArrayLike: Randomly-drawn samples from the distribution. """ - args_to_backend = self._backend_translator.translate_args( - rng_key=rng_key, sample_shape=sample_shape - ) - return self._sample(**args_to_backend) + return self._call_backend_with("sample", rng_key, sample_shape) + + +class NativeDistribution(Distribution[SupportsSampling]): + """ + A distribution that uses our native backend. + + These distributions do not require translations, since the backend objects + they use conform to our frontend syntax by design. + """ + + def __init__(self, *, backend: SupportsSampling, label: str) -> None: + """ + Create a new distribution, using a native backend. + + Args: + backend (SupportsSampling): Backend object that represents the distribution. + Must be a native backend object; that is a distribution provided by the + ``causalprog`` package. + label (str): Name or label to attach to the distribution. + + """ + super().__init__(backend=backend, label=label) diff --git a/src/causalprog/distribution/normal.py b/src/causalprog/distribution/normal.py index 349cffa..aa8742f 100644 --- a/src/causalprog/distribution/normal.py +++ b/src/causalprog/distribution/normal.py @@ -7,7 +7,7 @@ from jax import Array as JaxArray from numpy.typing import ArrayLike -from .base import Distribution +from .base import NativeDistribution from .family import DistributionFamily ArrayCompatible = TypeVar("ArrayCompatible", JaxArray, ArrayLike) @@ -26,7 +26,7 @@ def sample(self, rng_key: RNGKey, sample_shape: ArrayLike) -> JaxArray: return jrn.multivariate_normal(rng_key, self.mean, self.cov, shape=sample_shape) -class Normal(Distribution): +class Normal(NativeDistribution): r""" A (possibly multivaraiate) normal distribution, $\mathcal{N}(\mu, \Sigma)$. @@ -58,7 +58,7 @@ def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None: cov (ArrayCompatible): Matrix of covariates, $\Sigma$. """ - super().__init__(_Normal(mean, cov), label=f"({mean.ndim}-dim) Normal") + super().__init__(backend=_Normal(mean, cov), label=f"({mean.ndim}-dim) Normal") class NormalFamily(DistributionFamily): diff --git a/tests/test_distributions/test_different_backends.py b/tests/test_distributions/test_different_backends.py index e93c054..e8711e7 100644 --- a/tests/test_distributions/test_different_backends.py +++ b/tests/test_distributions/test_different_backends.py @@ -4,7 +4,9 @@ import jax.numpy as jnp from numpyro.distributions.continuous import MultivariateNormal -from causalprog.distribution.base import Distribution, SampleTranslator +from causalprog.backend.translation import Translation +from causalprog.distribution.base import Distribution +from causalprog.distribution.normal import Normal def test_different_backends(rng_key) -> None: @@ -22,11 +24,29 @@ def test_different_backends(rng_key) -> None: sample_size = (10, 5) distrax_normal = distrax.MultivariateNormalFullCovariance(mean, cov) - distrax_dist = Distribution(distrax_normal, SampleTranslator(rng_key="seed")) + distrax_dist = Distribution( + Translation( + backend_name="sample", + frontend_name="sample", + param_map={"seed": "rng_key"}, + ), + backend=distrax_normal, + label="Distrax normal", + ) distrax_samples = distrax_dist.sample(rng_key, sample_size) npyo_normal = MultivariateNormal(mean, cov) - npyo_dist = Distribution(npyo_normal, SampleTranslator(rng_key="key")) + npyo_dist = Distribution( + Translation( + backend_name="sample", frontend_name="sample", param_map={"key": "rng_key"} + ), + backend=npyo_normal, + label="NumPyro normal", + ) npyo_samples = npyo_dist.sample(rng_key, sample_size) + native_normal = Normal(mean, cov) + native_samples = native_normal.sample(rng_key, sample_size) + assert jnp.allclose(distrax_samples, npyo_samples) + assert jnp.allclose(distrax_samples, native_samples) From a0ec6e1e65a3904536413eb9e275d25a9c401e4d Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 21 Mar 2025 14:39:24 +0000 Subject: [PATCH 22/38] Fix up distribution families, which now just point to backend-agnostic distributions --- src/causalprog/distribution/family.py | 35 ++++++----------- src/causalprog/distribution/normal.py | 4 +- tests/test_distributions/test_family.py | 52 +++++++++++++++++++++---- 3 files changed, 58 insertions(+), 33 deletions(-) diff --git a/src/causalprog/distribution/family.py b/src/causalprog/distribution/family.py index c3e150d..3d14e7b 100644 --- a/src/causalprog/distribution/family.py +++ b/src/causalprog/distribution/family.py @@ -7,14 +7,13 @@ from causalprog._abc.labelled import Labelled from causalprog.distribution.base import Distribution, SupportsSampling -from causalprog.utils.translator import Translator -CreatesDistribution = TypeVar( - "CreatesDistribution", bound=Callable[..., SupportsSampling] +GenericDistribution = TypeVar( + "GenericDistribution", bound=Distribution[SupportsSampling] ) -class DistributionFamily(Generic[CreatesDistribution], Labelled): +class DistributionFamily(Generic[GenericDistribution], Labelled): r""" A family of ``Distributions``, that share the same parameters. @@ -33,23 +32,12 @@ class DistributionFamily(Generic[CreatesDistribution], Labelled): samples drawn from it. """ - _family: CreatesDistribution - _family_translator: Translator | None - - @property - def _member(self) -> Callable[..., Distribution]: - """Constructor method for family members, given parameters.""" - return lambda *parameters: Distribution( - self._family(*parameters), - backend_translator=self._family_translator, - ) + _family: Callable[..., GenericDistribution] def __init__( self, - backend_family: CreatesDistribution, - backend_translator: Translator | None = None, - *, - family_name: str = "DistributionFamily", + family: Callable[..., GenericDistribution], + label: str, ) -> None: """ Create a new family of distributions. @@ -62,12 +50,13 @@ def __init__( passed to the ``Distribution`` constructor. """ - super().__init__(label=family_name) + super().__init__(label=label) - self._family = backend_family - self._family_translator = backend_translator + self._family = family - def construct(self, *parameters: ArrayLike) -> Distribution: + def construct( + self, *pos_parameters: ArrayLike, **kw_parameters: ArrayLike + ) -> Distribution: """ Create a distribution from an explicit set of parameters. @@ -76,4 +65,4 @@ def construct(self, *parameters: ArrayLike) -> Distribution: passed as sequential arguments. """ - return self._member(*parameters) + return self._family(*pos_parameters, **kw_parameters) diff --git a/src/causalprog/distribution/normal.py b/src/causalprog/distribution/normal.py index aa8742f..2df527d 100644 --- a/src/causalprog/distribution/normal.py +++ b/src/causalprog/distribution/normal.py @@ -61,7 +61,7 @@ def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None: super().__init__(backend=_Normal(mean, cov), label=f"({mean.ndim}-dim) Normal") -class NormalFamily(DistributionFamily): +class NormalFamily(DistributionFamily[Normal]): r""" Constructor class for (possibly multivariate) normal distributions. @@ -74,7 +74,7 @@ class NormalFamily(DistributionFamily): def __init__(self) -> None: """Create a family of normal distributions.""" - super().__init__(Normal, family_name="Normal") + super().__init__(Normal, label="Normal family") def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: r""" diff --git a/tests/test_distributions/test_family.py b/tests/test_distributions/test_family.py index 9a95942..3525144 100644 --- a/tests/test_distributions/test_family.py +++ b/tests/test_distributions/test_family.py @@ -1,8 +1,43 @@ -import distrax +import jax.numpy as jnp import pytest +from distrax import MultivariateNormalFullCovariance as Mvn -from causalprog.distribution.base import SampleTranslator +from causalprog.backend.translation import Translation +from causalprog.distribution.base import Distribution from causalprog.distribution.family import DistributionFamily +from causalprog.distribution.normal import Normal, NormalFamily + + +@pytest.mark.parametrize( + ("n_dim_std_normal"), + [pytest.param(1, id="1D normal"), pytest.param(3, id="3D normal")], + indirect=["n_dim_std_normal"], +) +def test_sampling_consistency(rng_key, n_dim_std_normal) -> None: + """""" + sample_shape = (5, 10) + normal_family = NormalFamily() + + via_family = normal_family.construct(*n_dim_std_normal) + via_standard_class = Normal(*n_dim_std_normal) + + family_samples = via_family.sample(rng_key, sample_shape) + standard_class_samples = via_standard_class.sample(rng_key, sample_shape) + + assert jnp.allclose(family_samples, standard_class_samples) + + +class DistraxNormal(Distribution): + def __init__(self, mean, cov): + super().__init__( + Translation( + backend_name="sample", + frontend_name="sample", + param_map={"seed": "rng_key"}, + ), + backend=Mvn(mean, cov), + label="Distrax normal", + ) @pytest.mark.parametrize( @@ -16,11 +51,12 @@ def test_builder_matches_backend(n_dim_std_normal) -> None: to building via the backend explicitly. """ - mnv = distrax.MultivariateNormalFullCovariance - - mnv_family = DistributionFamily(mnv, SampleTranslator(rng_key="seed")) + mnv_family = DistributionFamily( + DistraxNormal, + label="Distrax normal family", + ) via_family = mnv_family.construct(*n_dim_std_normal) - via_backend = mnv(*n_dim_std_normal) + via_backend = Mvn(*n_dim_std_normal) - assert via_backend.kl_divergence(via_family.get_dist()) == pytest.approx(0.0) - assert via_family.get_dist().kl_divergence(via_backend) == pytest.approx(0.0) + assert via_backend.kl_divergence(via_family.dist) == pytest.approx(0.0) + assert via_family.dist.kl_divergence(via_backend) == pytest.approx(0.0) From a93071beec97d56076a3de3627f99c60f44143d4 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 21 Mar 2025 14:46:30 +0000 Subject: [PATCH 23/38] Remove now-defunct translator --- src/causalprog/utils/__init__.py | 1 - src/causalprog/utils/translator.py | 139 ----------------------------- 2 files changed, 140 deletions(-) delete mode 100644 src/causalprog/utils/__init__.py delete mode 100644 src/causalprog/utils/translator.py diff --git a/src/causalprog/utils/__init__.py b/src/causalprog/utils/__init__.py deleted file mode 100644 index c85c873..0000000 --- a/src/causalprog/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Utility classes and methods.""" diff --git a/src/causalprog/utils/translator.py b/src/causalprog/utils/translator.py deleted file mode 100644 index 0e26eba..0000000 --- a/src/causalprog/utils/translator.py +++ /dev/null @@ -1,139 +0,0 @@ -""" -Helper class to keep the codebase backend-agnostic. - -Our frontend (or user-facing) classes each use a syntax that applies across the package -codebase. By contrast, the various backends that we want to support will have different -syntaxes and call signatures for the functions that we want to support. As such, we need -a helper class that can store this "translation" information, allowing the user to -interact with the package in a standard way but also allowing them to choose their own -backend if desired. -""" - -import inspect -from abc import ABC, abstractmethod -from typing import Any - - -class Translator(ABC): - """ - Maps syntax of a backend function to our frontend syntax. - - Different backends have different syntax for drawing samples from the distributions - they support. In order to map these different syntaxes to our backend-agnostic - framework, we need a container class to map the names we have chosen for our - frontend methods to those used by their corresponding backend method. - - A ``Translator`` allows us to identify whether a user-provided backend object is - compatible with one of our frontend wrapper classes (and thus, call signatures). It - also allows users to write their own translators for any custom backends that we do - not explicitly support. - - The use case for a ``Translator`` is as follows. Suppose that we have a frontend - class ``C`` that needs to provide a method ``do_something``. ``C`` stores a - reference to a backend object ``obj`` that can provide the functionality of - ``do_something`` via one of its methods, ``obj.backend_method``. However, there is - no guarantee that the signature of ``do_something`` maps identically to that of - ``obj.backend_method``. A ``Translator`` allows us to encode a mapping of - ``obj.backend_method``s arguments to those of ``do_something``. - """ - - backend_method: str - corresponding_backend_arg: dict[str, str] - - @property - @abstractmethod - def _frontend_method(self) -> str: - """Name of the frontend method that the backend is to be translated into.""" - - @property - @abstractmethod - def compulsory_frontend_args(self) -> set[str]: - """Arguments that are required by the frontend function.""" - - @property - def compulsory_backend_args(self) -> set[str]: - """Arguments that are required to be taken by the backend function.""" - return { - self.corresponding_backend_arg[arg_name] - for arg_name in self.compulsory_frontend_args - } - - def __init__( - self, backend_method: str | None = None, **front_args_to_back_args: str - ) -> None: - """ - Create a new Translator. - - Args: - backend_method (str): Name of the backend method that the instance - translates. - **front_args_to_back_args (str): Mapping of frontend argument names to the - corresponding backend argument names. - - """ - # Assume backend name is identical to frontend name if not provided explicitly - self.backend_method = ( - backend_method if backend_method else self._frontend_method - ) - - # This should really be immutable after we fill defaults! - self.corresponding_backend_arg = dict(front_args_to_back_args) - # Assume compulsory frontend args that are not given translations - # retain their name in the backend. - for arg in self.compulsory_frontend_args: - if arg not in self.corresponding_backend_arg: - self.corresponding_backend_arg[arg] = arg - - def translate_args(self, **kwargs: Any) -> dict[str, Any]: # noqa: ANN401 - """ - Translate frontend arguments (with values) to backend arguments. - - Essentially transforms frontend keyword arguments into their backend keyword - arguments, preserving the value assigned to each argument. - """ - return { - self.corresponding_backend_arg[arg_name]: arg_value - for arg_name, arg_value in kwargs.items() - } - - def validate_compatible(self, obj: object) -> None: - """ - Determine if ``obj`` provides a compatible backend method. - - ``obj`` must provide a callable whose name matches ``self.backend_method``, - and the callable referenced must take arguments matching the names specified in - ``self.compulsory_backend_args``. - - Args: - obj (object): Object to check possesses a method that can be translated into - frontend syntax. - - """ - # Check that obj does provide a method of matching name - if not hasattr(obj, self.backend_method): - msg = f"{obj} has no method '{self.backend_method}'." - raise AttributeError(msg) - if not callable(getattr(obj, self.backend_method)): - msg = f"'{self.backend_method}' attribute of {obj} is not callable." - raise TypeError(msg) - - # Check that this method will be callable with the information given. - method_params = inspect.signature(getattr(obj, self.backend_method)).parameters - # The arguments that will be passed are actually taken by the method. - for compulsory_arg in self.compulsory_backend_args: - if compulsory_arg not in method_params: - msg = ( - f"'{self.backend_method}' does not " - f"take argument '{compulsory_arg}'." - ) - raise TypeError(msg) - # The method does not _require_ any additional arguments - method_requires = { - name for name, p in method_params.items() if p.default is p.empty - } - if not method_requires.issubset(self.compulsory_backend_args): - args_not_accounted_for = method_requires - self.compulsory_backend_args - raise TypeError( - f"'{self.backend_method}' not provided compulsory arguments " - "(missing " + ", ".join(args_not_accounted_for) + ")" - ) From 689b7982266144e3c65e2018173f85e2cc3693de Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 21 Mar 2025 14:46:42 +0000 Subject: [PATCH 24/38] Fix some outdated docstrings --- src/causalprog/backend/translator.py | 4 ++-- src/causalprog/distribution/family.py | 18 +++++------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/causalprog/backend/translator.py b/src/causalprog/backend/translator.py index 8d02e6e..a90bd92 100644 --- a/src/causalprog/backend/translator.py +++ b/src/causalprog/backend/translator.py @@ -52,8 +52,8 @@ def __init__( Translate a backend object into a frontend-compatible object. Args: - native (Backend): Backend object that must be translated to support frontend - syntax. + backend (Backend): Backend object that must be translated to support + frontend syntax. *translations (Translation): ``Translation``s that map the methods of ``backend`` to the (signatures of the) methods that the ``_frontend_provides``. diff --git a/src/causalprog/distribution/family.py b/src/causalprog/distribution/family.py index 3d14e7b..26a1de6 100644 --- a/src/causalprog/distribution/family.py +++ b/src/causalprog/distribution/family.py @@ -37,17 +37,16 @@ class DistributionFamily(Generic[GenericDistribution], Labelled): def __init__( self, family: Callable[..., GenericDistribution], + *, label: str, ) -> None: """ Create a new family of distributions. Args: - backend_family (CreatesDistribution): Backend callable that assembles the - distribution, given explicit parameter values. Currently, this callable - can only accept the parameters as a sequence of positional arguments. - backend_translator (Translator): ``Translator`` instance that to be - passed to the ``Distribution`` constructor. + family (Callable[..., GenericDistribution]): Backend callable that assembles + a member distribution of this family, from explicit parameter values. + label (str): Name to give to the distribution family. """ super().__init__(label=label) @@ -57,12 +56,5 @@ def __init__( def construct( self, *pos_parameters: ArrayLike, **kw_parameters: ArrayLike ) -> Distribution: - """ - Create a distribution from an explicit set of parameters. - - Args: - *parameters (ArrayLike): Parameters that define a member of this family, - passed as sequential arguments. - - """ + """Create a distribution from an explicit set of parameters.""" return self._family(*pos_parameters, **kw_parameters) From 5c27b0bcbaf8664ecf3eae65409f1969bb717430 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 21 Mar 2025 14:48:32 +0000 Subject: [PATCH 25/38] No longer need to test a deleted file --- tests/test_translator.py | 143 --------------------------------------- 1 file changed, 143 deletions(-) delete mode 100644 tests/test_translator.py diff --git a/tests/test_translator.py b/tests/test_translator.py deleted file mode 100644 index 7197992..0000000 --- a/tests/test_translator.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Tests for the SampleCompatibility class.""" - -import re - -import pytest - -from causalprog.utils.translator import Translator - - -class _TranslatorForTesting(Translator): - @property - def _frontend_method(self) -> str: - """Name of the frontend method that the backend is to be translated into.""" - return "method" - - @property - def compulsory_frontend_args(self) -> set[str]: - """Arguments that are required by the frontend function.""" - return {"arg1", "arg2"} - - -class DummyClass: - """ - Stub class for testing. - - Intended use is to provide a variety of method call signatures, that can be used to - verify whether the ``SampleCompatibility`` class is correctly able to determine if - it will be able to call an underlying method without error. - """ - - @property - def prop(self) -> None: - """Properties are not callable.""" - return - - def __init__(self) -> None: - """Create an instance.""" - return - - def __str__(self) -> str: - """Display, in case the object appears in an error string.""" - return "DummyClass instance" - - def method(self, arg1: int, arg2: int = 0, kwarg1: int = 0) -> int: - """Take 1 compulsory and 2 optional arguments.""" - return arg1 + arg2 + kwarg1 - - -@pytest.fixture -def dummy_class_instance() -> DummyClass: - """Instance of the ``DummyClass`` to use in testing.""" - return DummyClass() - - -@pytest.mark.parametrize( - ("info", "expected_result"), - [ - pytest.param( - _TranslatorForTesting(backend_method="method_does_not_exist"), - AttributeError( - "DummyClass instance has no method 'method_does_not_exist'." - ), - id="Object does not have the backend method.", - ), - pytest.param( - _TranslatorForTesting(backend_method="prop"), - TypeError("'prop' attribute of DummyClass instance is not callable."), - id="Object backend method is not callable.", - ), - pytest.param( - _TranslatorForTesting(arg1="not_an_arg"), - TypeError("'method' does not take argument 'not_an_arg'"), - id="Backend does not take compulsory argument.", - ), - pytest.param( - _TranslatorForTesting("method", arg1="arg2", arg2="kwarg1"), - TypeError("'method' not provided compulsory arguments (missing arg1)"), - id="Backend cannot have unspecified compulsory arguments.", - ), - pytest.param( - _TranslatorForTesting(), - None, - id="Fall back on defaults.", - ), - pytest.param( - _TranslatorForTesting(backend_method="method", arg2="kwarg1"), - None, - id="Match args out-of-order.", - ), - ], -) -def test_validate_compatible( - info: _TranslatorForTesting, - dummy_class_instance: DummyClass, - expected_result: Exception | None, -) -> None: - """ - Test the validate_compatible method. - - Test that a SampleCompatibility instance correctly determines if a given method - of a given object is callable, with the information stored in the instance. - """ - if expected_result is not None: - with pytest.raises( - type(expected_result), match=re.escape(str(expected_result)) - ): - info.validate_compatible(dummy_class_instance) - else: - info.validate_compatible(dummy_class_instance) - - -@pytest.mark.parametrize( - ("translator", "input_kwargs", "expected_kwargs"), - [ - pytest.param( - _TranslatorForTesting(), - {"arg1": 0, "arg2": 1}, - {"arg1": 0, "arg2": 1}, - id="Args unchanged.", - ), - pytest.param( - _TranslatorForTesting(arg1="arg2", arg2="arg1"), - {"arg1": 0, "arg2": 1}, - {"arg1": 1, "arg2": 0}, - id="Order of args is swapped.", - ), - pytest.param( - _TranslatorForTesting(arg2="very_different_name"), - {"arg1": 0, "arg2": 1}, - {"arg1": 0, "very_different_name": 1}, - id="Backend names replaced where necessary.", - ), - ], -) -def test_translation( - translator: _TranslatorForTesting, - input_kwargs: dict[str, str], - expected_kwargs: dict[str, str], -) -> None: - """Test the mapping of (compatible) frontend args to backend args.""" - computed_output = translator.translate_args(**input_kwargs) - - assert computed_output == expected_kwargs From a5d198846ace7a8c5b6a514f18aa9b7f1a8a512d Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:06:00 +0000 Subject: [PATCH 26/38] Pull signature conversions from branch --- src/causalprog/backend/__init__.py | 1 + src/causalprog/backend/_convert_signature.py | 270 +++++++++++++++++++ src/causalprog/backend/_typing.py | 5 + 3 files changed, 276 insertions(+) create mode 100644 src/causalprog/backend/__init__.py create mode 100644 src/causalprog/backend/_convert_signature.py create mode 100644 src/causalprog/backend/_typing.py diff --git a/src/causalprog/backend/__init__.py b/src/causalprog/backend/__init__.py new file mode 100644 index 0000000..e028405 --- /dev/null +++ b/src/causalprog/backend/__init__.py @@ -0,0 +1 @@ +"""Helper functionality for incorporating different backends.""" diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py new file mode 100644 index 0000000..7f87063 --- /dev/null +++ b/src/causalprog/backend/_convert_signature.py @@ -0,0 +1,270 @@ +"""Convert a function signature to a different signature.""" + +import inspect +from collections.abc import Callable +from inspect import Parameter, Signature +from typing import Any + +from ._typing import ParamNameMap, ReturnType, StaticValues + +_VARLENGTH_PARAM_TYPES = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) + + +def _validate_variable_length_parameters( + sig: Signature, +) -> dict[inspect._ParameterKind, str | None]: + """ + Check signature contains at most one variable-length parameter of each kind. + + ``Signature`` objects can contain more than one variable-length parameter, despite + the fact that in practice such a signature cannot exist and be valid Python syntax. + This function checks for such cases, and raises an appropriate error, should they + arise. + + Args: + sig (Signature): Function signature to check for variable-length parameters. + + Returns: + dict[inspect._ParameterKind, str | None]: Mapping of variable-length parameter + kinds to the corresponding parameter name in ``sig``, or to ``None`` if no + parameter of that type exists in the signature. + + """ + named_args: dict[inspect._ParameterKind, str | None] = { + kind: None for kind in _VARLENGTH_PARAM_TYPES + } + for kind in _VARLENGTH_PARAM_TYPES: + possible_parameters = [ + p_name for p_name, p in sig.parameters.items() if p.kind == kind + ] + if len(possible_parameters) > 1: + msg = f"New signature takes more than 1 {kind} argument." + raise ValueError(msg) + if len(possible_parameters) > 0: + named_args[kind] = possible_parameters[0] + return named_args + + +def _signature_can_be_cast( + signature_to_convert: Signature, + new_signature: Signature, + param_name_map: ParamNameMap, + give_static_value: StaticValues, +) -> tuple[ParamNameMap, StaticValues]: + """ + Prepare a signature for conversion to another signature. + + In order to map ``signature_to_convert`` to that of ``new_signature``, the following + assurances are needed: + + - Variable-length parameters in the two signatures are assumed to match (up to name + changes) or be provided explicit defaults. The function will attempt to match + variable-length parameters that are not explicitly matched in the + ``param_name_map``. Note that a signature can have, at most, only one + variable-length positional parameter and one variable-length keyword parameter. + - All parameters WITHOUT DEFAULT VALUES in ``signature_to_convert`` correspond to a + parameter in ``new_signature`` (that may or may not have a default value) OR are + given static values to use, via the ``give_static_value`` argument. + - If ``new_signature`` takes variable-keyword-argument (``**kwargs``), these + arguments are expanded to allow for possible matches to parameters of + ``signature_to_convert``, before passing any remaining parameters after this + unpacked to the variable-keyword-argument of ``signature_to_convert``. + + Args: + signature_to_convert (Signature): Function signature that will be cast to + ``new_signature``. + new_signature (Signature): See the homonymous argument to ``convert_signature``. + param_name_map (ParamNameMap): See the homonymous argument to + ``convert_signature``. + give_static_value (StaticValues): See the homonymous argument to + ``convert_signature``. + + Raises: + ValueError: If the two signatures cannot be cast, even given + the additional information. + + Returns: + ParamNameMap: Mapping of parameter names in the ``signature_to_convert`` to + parameter names in ``new_signature``. Implicit mappings as per function + behaviour are explicitly included in the returned mapping. + StaticValues: Mapping of parameter names in the ``signature_to_convert`` to + static values to assign to these parameters, indicating omission from the + ``new_signature``. Implicit adoption of static values as per function + behaviour are explicitly included in the returned mapping. + + """ + _validate_variable_length_parameters(signature_to_convert) + new_varlength_params = _validate_variable_length_parameters(new_signature) + + param_name_map = dict(param_name_map) + give_static_value = dict(give_static_value) + + new_parameters_accounted_for = set() + + # Check mapping of parameters in old signature to new signature + for p_name, param in signature_to_convert.parameters.items(): + is_explicitly_mapped = p_name in param_name_map + name_is_unchanged = ( + p_name not in param_name_map + and p_name not in param_name_map.values() + and p_name in new_signature.parameters + ) + is_given_static = p_name in give_static_value + can_take_default = param.default is not param.empty + is_varlength_param = param.kind in _VARLENGTH_PARAM_TYPES + mapped_to = None + + if is_explicitly_mapped: + # This parameter is explicitly mapped to another parameter + mapped_to = param_name_map[p_name] + elif name_is_unchanged: + # Parameter is inferred not to change name, having been omitted from the + # explicit mapping. + mapped_to = p_name + param_name_map[p_name] = mapped_to + elif ( + is_varlength_param + and new_varlength_params[param.kind] is not None + and str(new_varlength_params[param.kind]) not in param_name_map.values() + ): + # Automatically map VAR_* parameters to their counterpart, if possible. + mapped_to = str(new_varlength_params[param.kind]) + param_name_map[p_name] = mapped_to + elif is_given_static: + # This parameter is given a static value to use. + continue + elif can_take_default: + # This parameter has a default value in the old signature. + # Since it is not explicitly mapped to another parameter, nor given an + # explicit static value, infer that the default value should be set as the + # static value. + give_static_value[p_name] = param.default + else: + msg = ( + f"Parameter '{p_name}' has no counterpart in new_signature, " + "and does not take a static value." + ) + raise ValueError(msg) + + # Record that any parameter mapped_to in the new_signature is now accounted for, + # to avoid many -> one mappings. + if mapped_to: + if mapped_to in new_parameters_accounted_for: + msg = f"Parameter '{mapped_to}' is mapped to by multiple parameters." + raise ValueError(msg) + # Confirm that variable-length parameters are mapped to variable-length + # parameters (of the same type). + if ( + is_varlength_param + and new_signature.parameters[mapped_to].kind != param.kind + ): + msg = ( + "Variable-length positional/keyword parameters must map to each " + f"other ('{p_name}' is type {param.kind}, but '{mapped_to}' is " + f"type {new_signature.parameters[mapped_to].kind})." + ) + raise ValueError(msg) + + new_parameters_accounted_for.add(param_name_map[p_name]) + + # Confirm all items in new_signature are also accounted for. + unaccounted_new_parameters = ( + set(new_signature.parameters) - new_parameters_accounted_for + ) + if unaccounted_new_parameters: + msg = "Some parameters in new_signature are not used: " + ", ".join( + unaccounted_new_parameters + ) + raise ValueError(msg) + + return param_name_map, give_static_value + + +def convert_signature( + fn: Callable[..., ReturnType], + new_signature: Signature, + old_to_new_names: ParamNameMap, + give_static_value: StaticValues, +) -> Callable[..., ReturnType]: + """ + Convert the call signature of a function ``fn`` to that of ``new_signature``. + + Args: + fn (Callable): Callable object to change the signature of. + new_signature (inspect.Signature): New signature to give to ``fn``. + old_to_new_names (dict[str, str]): Maps the names of parameters in ``fn``s + signature to the corresponding parameter names in the new signature. + Parameter names that do not change can be omitted. Note that parameters that + are to be dropped should be supplied to ``give_static_value`` instead. + give_static_value (dict[str, Any]): Maps names of parameters of ``fn`` to + default values that should be assigned to them. This means that not all + compulsory parameters of ``fn`` have to have a corresponding parameter in + ``new_signature`` - such parameters will use the value assigned to them in + ``give_static_value`` if they are lacking a counterpart parameter in + ``new_signature``. Parameters to ``fn`` that lack a counterpart in + ``new_signature``, and that have default values in ``fn``, will be added + automatically. + + Returns: + Callable: Callable representing ``fn`` with ``new_signature``. + + See Also: + _signature_can_be_cast: Validation method used to check casting is possible. + + """ + fn_signature = inspect.signature(fn) + old_to_new_names, give_static_value = _signature_can_be_cast( + fn_signature, new_signature, old_to_new_names, give_static_value + ) + new_to_old_names = {value: key for key, value in old_to_new_names.items()} + + fn_varlength_params = _validate_variable_length_parameters(fn_signature) + fn_vargs_param = fn_varlength_params[Parameter.VAR_POSITIONAL] + fn_kwargs_param = fn_varlength_params[Parameter.VAR_KEYWORD] + + new_varlength_params = _validate_variable_length_parameters(new_signature) + new_kwargs_param = new_varlength_params[Parameter.VAR_KEYWORD] + + fn_posix_args = [ + p_name + for p_name, param in fn_signature.parameters.items() + if param.kind <= param.POSITIONAL_OR_KEYWORD + ] + + # If fn's VAR_KEYWORD parameter is dropped from the new_signature, + # it must have been given a default value to use. We need to expand + # these values now so that they get passed correctly as keyword arguments. + if fn_kwargs_param and fn_kwargs_param in give_static_value: + static_kwargs = give_static_value.pop(fn_kwargs_param) + give_static_value = dict(give_static_value, **static_kwargs) + + def fn_with_new_signature(*args: tuple, **kwargs: dict[str, Any]) -> ReturnType: + bound = new_signature.bind(*args, **kwargs) + bound.apply_defaults() + + all_args_received = bound.arguments + kwargs_to_pass_on = ( + all_args_received.pop(new_kwargs_param, {}) if new_kwargs_param else {} + ) + # Maps the name of a parameter to fn to the value that should be supplied, + # as obtained from the arguments provided to this function. + # Calling dict with give_static_value FIRST is important, as defaults will get + # overwritten by any passed arguments! + fn_kwargs = dict( + give_static_value, + **{ + new_to_old_names[key]: value for key, value in all_args_received.items() + }, + **kwargs_to_pass_on, + ) + # We can supply all arguments EXCEPT the variable-positional and positional-only + # arguments as keyword args. + # Positional-only arguments have to come first, followed by the + # variable-positional parameters. + fn_args = [fn_kwargs.pop(p_name) for p_name in fn_posix_args] + if fn_vargs_param: + fn_args.extend(fn_kwargs.pop(fn_vargs_param, [])) + # Now we can call fn + return fn(*fn_args, **fn_kwargs) + + return fn_with_new_signature diff --git a/src/causalprog/backend/_typing.py b/src/causalprog/backend/_typing.py new file mode 100644 index 0000000..e7ab7fb --- /dev/null +++ b/src/causalprog/backend/_typing.py @@ -0,0 +1,5 @@ +from typing import Any, TypeAlias, TypeVar + +ReturnType = TypeVar("ReturnType") +ParamNameMap: TypeAlias = dict[str, str] +StaticValues: TypeAlias = dict[str, Any] From 9fe5f3d0fa06be9fcd3d679c1cb6f620a197504a Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:06:45 +0000 Subject: [PATCH 27/38] Pull tests for converting signatures --- tests/test_backend/test_convert_signature.py | 375 +++++++++++++++++++ 1 file changed, 375 insertions(+) create mode 100644 tests/test_backend/test_convert_signature.py diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py new file mode 100644 index 0000000..c418545 --- /dev/null +++ b/tests/test_backend/test_convert_signature.py @@ -0,0 +1,375 @@ +import re +from collections.abc import Iterable +from inspect import Parameter, Signature, signature +from typing import Any + +import pytest + +from causalprog.backend._convert_signature import ( + _signature_can_be_cast, + _validate_variable_length_parameters, + convert_signature, +) +from causalprog.backend._typing import ParamNameMap, StaticValues + + +def general_function( + posix, /, posix_def="posix_def", *vargs, kwo, kwo_def="kwo_def", **kwargs +): + return posix, posix_def, vargs, kwo, kwo_def, kwargs + + +@pytest.mark.parametrize( + ("signature", "expected"), + [ + pytest.param( + Signature( + ( + Parameter("vargs1", Parameter.VAR_POSITIONAL), + Parameter("vargs2", Parameter.VAR_POSITIONAL), + ) + ), + ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), + id="Two variable-length positional arguments.", + ), + pytest.param( + Signature( + ( + Parameter("kwargs1", Parameter.VAR_KEYWORD), + Parameter("kwargs2", Parameter.VAR_KEYWORD), + ) + ), + ValueError("New signature takes more than 1 VAR_KEYWORD argument."), + id="Two variable-length keyword arguments.", + ), + pytest.param( + signature(general_function), + {Parameter.VAR_POSITIONAL: "vargs", Parameter.VAR_KEYWORD: "kwargs"}, + id="Valid, but complex, signature.", + ), + pytest.param( + Signature( + ( + Parameter("arg1", Parameter.POSITIONAL_OR_KEYWORD), + Parameter("arg2", Parameter.POSITIONAL_OR_KEYWORD, default=1), + Parameter("vargs1", Parameter.VAR_POSITIONAL), + Parameter("vargs2", Parameter.VAR_POSITIONAL), + Parameter("kwargs1", Parameter.VAR_KEYWORD), + ) + ), + ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), + id="Two variable-length positional arguments, mixed with others.", + ), + ], +) +def test_validate_variable_length_parameters( + signature: Signature, expected: Exception | dict +): + if isinstance(expected, Exception): + with pytest.raises(type(expected), match=re.escape(str(expected))): + _validate_variable_length_parameters(signature) + else: + returned_names = _validate_variable_length_parameters(signature) + + assert returned_names == expected + + +@pytest.mark.parametrize( + ( + "signature_to_convert", + "new_signature", + "param_name_map", + "give_static_value", + "expected_output", + ), + [ + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {}, + {}, + ValueError( + "Parameter 'b' has no counterpart in new_signature, " + "and does not take a static value." + ), + id="Parameter not matched.", + ), + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {"a": "a", "b": "a"}, + {}, + ValueError("Parameter 'a' is mapped to by multiple parameters."), + id="Two arguments mapped to a single parameter.", + ), + pytest.param( + Signature( + [ + Parameter("vargs", Parameter.VAR_POSITIONAL), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {"vargs": "a"}, + {}, + ValueError( + "Variable-length positional/keyword parameters must map to each other " + "('vargs' is type VAR_POSITIONAL, but 'a' is type POSITIONAL_ONLY)." + ), + id="Map *args to positional argument.", + ), + pytest.param( + Signature( + [ + Parameter("vargs", Parameter.VAR_POSITIONAL), + ] + ), + Signature( + [ + Parameter("kwarg", Parameter.VAR_KEYWORD), + ] + ), + {"vargs": "kwarg"}, + {}, + ValueError( + "Variable-length positional/keyword parameters must map to each other " + "('vargs' is type VAR_POSITIONAL, but 'kwarg' is type VAR_KEYWORD)." + ), + id="Map *args to **kwargs.", + ), + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + {}, + {}, + ValueError("Some parameters in new_signature are not used: b"), + id="new_signature contains extra parameters.", + ), + pytest.param( + signature(general_function), + signature(general_function), + {}, + {}, + ({key: key for key in signature(general_function).parameters}, {}), + id="Can cast to yourself.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), + Signature([Parameter("a", Parameter.KEYWORD_ONLY)]), + {}, + {}, + ({"a": "a"}, {}), + id="Infer identically named parameter (even with type change)", + ), + pytest.param( + Signature([Parameter("args", Parameter.VAR_POSITIONAL)]), + Signature([Parameter("new_args", Parameter.VAR_POSITIONAL)]), + {}, + {}, + ({"args": "new_args"}, {}), + id="Infer VAR_POSITIONAL matching.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), + Signature([]), + {}, + {"a": 10}, + ({}, {"a": 10}), + id="Assign static value to argument without default.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY, default=10)]), + Signature([]), + {}, + {}, + ({}, {"a": 10}), + id="Infer static value from argument default.", + ), + ], +) +def test_signature_can_be_cast( + signature_to_convert: Signature, + new_signature: Signature, + param_name_map: ParamNameMap, + give_static_value: StaticValues, + expected_output: Exception | tuple[str | None, ParamNameMap, StaticValues], +) -> None: + if isinstance(expected_output, Exception): + with pytest.raises( + type(expected_output), match=re.escape(str(expected_output)) + ): + _signature_can_be_cast( + signature_to_convert, + new_signature, + param_name_map, + give_static_value, + ) + else: + computed_output = _signature_can_be_cast( + signature_to_convert, + new_signature, + param_name_map, + give_static_value, + ) + + assert computed_output == expected_output + + +_kwargs_static_value = {"some": "keyword-arguments"} + + +@pytest.mark.parametrize( + ( + "posix_for_new_call", + "keyword_for_new_call", + "expected_assignments", + ), + [ + pytest.param( + [1, 2], + {"kwo_n": 3, "kwo_def_n": 4}, + { + "posix": 3, + "posix_def": 4, + "vargs": (), + "kwo": 1, + "kwo_def": 2, + "kwargs": _kwargs_static_value, + }, + id="No vargs supplied.", + ), + pytest.param( + [1, 2, 10, 11, 12], + {"kwo_n": 3, "kwo_def_n": 4}, + { + "posix": 3, + "posix_def": 4, + "vargs": (10, 11, 12), + "kwo": 1, + "kwo_def": 2, + "kwargs": _kwargs_static_value, + }, + id="Supply vargs.", + ), + pytest.param( + [1], + {"kwo_n": 3}, + { + "posix": 3, + "posix_def": "default_for_kwo_def_n", + "vargs": (), + "kwo": 1, + "kwo_def": "default_for_posix_def_n", + "kwargs": _kwargs_static_value, + }, + id="New default values respected.", + ), + pytest.param( + [1], + {"kwo_n": 3, "extra_kwarg": "not allowed"}, + TypeError("got an unexpected keyword argument 'extra_kwarg'"), + id="kwargs not allowed in new signature.", + ), + pytest.param( + [1, 2], + {"kwo_n": 3, "posix_def_n": 2}, + TypeError("multiple values for argument 'posix_def_n'"), + id="Multiple values for new parameter.", + ), + ], +) +def test_convert_signature( + posix_for_new_call: Iterable[Any], + keyword_for_new_call: dict[str, Any], + expected_assignments: dict[str, Any] | Exception, +) -> None: + """ + To ease the burden of setting up and parametrising this test, + we will always use the general_function signature as the target and source + signature. + + However, the target signature will swap the roles of the positional and keyword + parameters, essentially mapping: + + ``posix, posix_def, *vargs, kwo, kwo_def, **kwargs`` + + to + + ``kwo_n, kwo_def_n, *vargs_n, posix_n, posix_def_n``. + + ``give_static_value`` will give kwargs a default value. + + We can then make calls to this new signature, and since ``general_function`` returns + the arguments it received, we can validate that correct passing of arguments occurs. + """ + param_name_map = { + "posix": "kwo_n", + "posix_def": "kwo_def_n", + "kwo": "posix_n", + "kwo_def": "posix_def_n", + } + give_static_value = {"kwargs": _kwargs_static_value} + new_signature = Signature( + [ + Parameter("posix_n", Parameter.POSITIONAL_ONLY), + Parameter( + "posix_def_n", + Parameter.POSITIONAL_OR_KEYWORD, + default="default_for_posix_def_n", + ), + Parameter("vargs_n", Parameter.VAR_POSITIONAL), + Parameter("kwo_n", Parameter.KEYWORD_ONLY), + Parameter( + "kwo_def_n", Parameter.KEYWORD_ONLY, default="default_for_kwo_def_n" + ), + ] + ) + new_function = convert_signature( + general_function, new_signature, param_name_map, give_static_value + ) + + if isinstance(expected_assignments, Exception): + with pytest.raises( + type(expected_assignments), match=re.escape(str(expected_assignments)) + ): + new_function(*posix_for_new_call, **keyword_for_new_call) + else: + posix, posix_def, vargs, kwo, kwo_def, kwargs = new_function( + *posix_for_new_call, **keyword_for_new_call + ) + assert posix == expected_assignments["posix"] + assert posix_def == expected_assignments["posix_def"] + assert vargs == expected_assignments["vargs"] + assert kwo == expected_assignments["kwo"] + assert kwo_def == expected_assignments["kwo_def"] + assert kwargs == expected_assignments["kwargs"] From 5b2dbae4d2f576013ecd253d798067e7dda09373 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:09:03 +0000 Subject: [PATCH 28/38] Hide hidden variable import behind typehint --- src/causalprog/backend/_convert_signature.py | 8 ++++---- src/causalprog/backend/_typing.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 7f87063..d0a2616 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -5,14 +5,14 @@ from inspect import Parameter, Signature from typing import Any -from ._typing import ParamNameMap, ReturnType, StaticValues +from ._typing import ParamKind, ParamNameMap, ReturnType, StaticValues _VARLENGTH_PARAM_TYPES = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) def _validate_variable_length_parameters( sig: Signature, -) -> dict[inspect._ParameterKind, str | None]: +) -> dict[ParamKind, str | None]: """ Check signature contains at most one variable-length parameter of each kind. @@ -25,12 +25,12 @@ def _validate_variable_length_parameters( sig (Signature): Function signature to check for variable-length parameters. Returns: - dict[inspect._ParameterKind, str | None]: Mapping of variable-length parameter + dict[ParamKind, str | None]: Mapping of variable-length parameter kinds to the corresponding parameter name in ``sig``, or to ``None`` if no parameter of that type exists in the signature. """ - named_args: dict[inspect._ParameterKind, str | None] = { + named_args: dict[ParamKind, str | None] = { kind: None for kind in _VARLENGTH_PARAM_TYPES } for kind in _VARLENGTH_PARAM_TYPES: diff --git a/src/causalprog/backend/_typing.py b/src/causalprog/backend/_typing.py index e7ab7fb..c214967 100644 --- a/src/causalprog/backend/_typing.py +++ b/src/causalprog/backend/_typing.py @@ -1,5 +1,7 @@ +from inspect import _ParameterKind from typing import Any, TypeAlias, TypeVar ReturnType = TypeVar("ReturnType") ParamNameMap: TypeAlias = dict[str, str] +ParamKind: TypeAlias = _ParameterKind StaticValues: TypeAlias = dict[str, Any] From efc931516f553ded868ec73aa4575f8ee80811b0 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:12:16 +0000 Subject: [PATCH 29/38] Tidy vargs and kwargs checker function --- src/causalprog/backend/_convert_signature.py | 25 ++++++++++++-------- tests/test_backend/test_convert_signature.py | 6 ++--- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index d0a2616..4c76e9e 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -10,16 +10,21 @@ _VARLENGTH_PARAM_TYPES = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) -def _validate_variable_length_parameters( +def _check_variable_length_params( sig: Signature, ) -> dict[ParamKind, str | None]: """ - Check signature contains at most one variable-length parameter of each kind. + Return the names of variable-length parameters in a signature. - ``Signature`` objects can contain more than one variable-length parameter, despite - the fact that in practice such a signature cannot exist and be valid Python syntax. - This function checks for such cases, and raises an appropriate error, should they - arise. + There are two types of variable-length parameters; positional (VAR_POSITIONAL) which + are typically denoted ``*args`` or ``*vargs``, and keyword (VAR_KEYWORD) which are + typically denoted ``**kwargs``. + + ``Signature`` objects can contain more than one variable-length parameter of each + kind, despite the fact that in practice such a signature cannot exist and be valid + Python syntax. This function checks for such cases, and raises an appropriate error, + should they arise. Otherwise, it simply identifies the parameters in ``sig`` which + correspond to these two variable-length parameter kinds. Args: sig (Signature): Function signature to check for variable-length parameters. @@ -93,8 +98,8 @@ def _signature_can_be_cast( behaviour are explicitly included in the returned mapping. """ - _validate_variable_length_parameters(signature_to_convert) - new_varlength_params = _validate_variable_length_parameters(new_signature) + _check_variable_length_params(signature_to_convert) + new_varlength_params = _check_variable_length_params(new_signature) param_name_map = dict(param_name_map) give_static_value = dict(give_static_value) @@ -218,11 +223,11 @@ def convert_signature( ) new_to_old_names = {value: key for key, value in old_to_new_names.items()} - fn_varlength_params = _validate_variable_length_parameters(fn_signature) + fn_varlength_params = _check_variable_length_params(fn_signature) fn_vargs_param = fn_varlength_params[Parameter.VAR_POSITIONAL] fn_kwargs_param = fn_varlength_params[Parameter.VAR_KEYWORD] - new_varlength_params = _validate_variable_length_parameters(new_signature) + new_varlength_params = _check_variable_length_params(new_signature) new_kwargs_param = new_varlength_params[Parameter.VAR_KEYWORD] fn_posix_args = [ diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index c418545..f4463e5 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -6,8 +6,8 @@ import pytest from causalprog.backend._convert_signature import ( + _check_variable_length_params, _signature_can_be_cast, - _validate_variable_length_parameters, convert_signature, ) from causalprog.backend._typing import ParamNameMap, StaticValues @@ -67,9 +67,9 @@ def test_validate_variable_length_parameters( ): if isinstance(expected, Exception): with pytest.raises(type(expected), match=re.escape(str(expected))): - _validate_variable_length_parameters(signature) + _check_variable_length_params(signature) else: - returned_names = _validate_variable_length_parameters(signature) + returned_names = _check_variable_length_params(signature) assert returned_names == expected From 51eaa8bae9cb48af9fb32b2172bc5b9740bd5ff9 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:18:52 +0000 Subject: [PATCH 30/38] Refactor _check tests --- tests/test_backend/conftest.py | 21 ++++++ .../test_check_variable_length_parameters.py | 64 +++++++++++++++++++ tests/test_backend/test_convert_signature.py | 56 ---------------- 3 files changed, 85 insertions(+), 56 deletions(-) create mode 100644 tests/test_backend/conftest.py create mode 100644 tests/test_backend/test_check_variable_length_parameters.py diff --git a/tests/test_backend/conftest.py b/tests/test_backend/conftest.py new file mode 100644 index 0000000..00734e4 --- /dev/null +++ b/tests/test_backend/conftest.py @@ -0,0 +1,21 @@ +from collections.abc import Callable +from inspect import Signature, signature + +import pytest + + +@pytest.fixture +def general_function() -> Callable: + def _general_function( + posix, /, posix_def="posix_def", *vargs, kwo, kwo_def="kwo_def", **kwargs + ): + """Return the provided arguments.""" + return posix, posix_def, vargs, kwo, kwo_def, kwargs + + return _general_function + + +@pytest.fixture +def general_function_signature(general_function: Callable) -> Signature: + """Signature of the ``general_function`` callable.""" + return signature(general_function) diff --git a/tests/test_backend/test_check_variable_length_parameters.py b/tests/test_backend/test_check_variable_length_parameters.py new file mode 100644 index 0000000..cf7093f --- /dev/null +++ b/tests/test_backend/test_check_variable_length_parameters.py @@ -0,0 +1,64 @@ +import re +from inspect import Parameter, Signature + +import pytest + +from causalprog.backend._convert_signature import _check_variable_length_params + + +@pytest.mark.parametrize( + ("signature", "expected"), + [ + pytest.param( + Signature( + ( + Parameter("vargs1", Parameter.VAR_POSITIONAL), + Parameter("vargs2", Parameter.VAR_POSITIONAL), + ) + ), + ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), + id="Two variable-length positional arguments.", + ), + pytest.param( + Signature( + ( + Parameter("kwargs1", Parameter.VAR_KEYWORD), + Parameter("kwargs2", Parameter.VAR_KEYWORD), + ) + ), + ValueError("New signature takes more than 1 VAR_KEYWORD argument."), + id="Two variable-length keyword arguments.", + ), + pytest.param( + "general_function_signature", + {Parameter.VAR_POSITIONAL: "vargs", Parameter.VAR_KEYWORD: "kwargs"}, + id="Valid, but complex, signature.", + ), + pytest.param( + Signature( + ( + Parameter("arg1", Parameter.POSITIONAL_OR_KEYWORD), + Parameter("arg2", Parameter.POSITIONAL_OR_KEYWORD, default=1), + Parameter("vargs1", Parameter.VAR_POSITIONAL), + Parameter("vargs2", Parameter.VAR_POSITIONAL), + Parameter("kwargs1", Parameter.VAR_KEYWORD), + ) + ), + ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), + id="Two variable-length positional arguments, mixed with others.", + ), + ], +) +def test_check_variable_length_parameters( + signature: Signature, expected: Exception | dict, request +): + if isinstance(signature, str): + signature = request.getfixturevalue(signature) + + if isinstance(expected, Exception): + with pytest.raises(type(expected), match=re.escape(str(expected))): + _check_variable_length_params(signature) + else: + returned_names = _check_variable_length_params(signature) + + assert returned_names == expected diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index f4463e5..826a259 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -6,7 +6,6 @@ import pytest from causalprog.backend._convert_signature import ( - _check_variable_length_params, _signature_can_be_cast, convert_signature, ) @@ -19,61 +18,6 @@ def general_function( return posix, posix_def, vargs, kwo, kwo_def, kwargs -@pytest.mark.parametrize( - ("signature", "expected"), - [ - pytest.param( - Signature( - ( - Parameter("vargs1", Parameter.VAR_POSITIONAL), - Parameter("vargs2", Parameter.VAR_POSITIONAL), - ) - ), - ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), - id="Two variable-length positional arguments.", - ), - pytest.param( - Signature( - ( - Parameter("kwargs1", Parameter.VAR_KEYWORD), - Parameter("kwargs2", Parameter.VAR_KEYWORD), - ) - ), - ValueError("New signature takes more than 1 VAR_KEYWORD argument."), - id="Two variable-length keyword arguments.", - ), - pytest.param( - signature(general_function), - {Parameter.VAR_POSITIONAL: "vargs", Parameter.VAR_KEYWORD: "kwargs"}, - id="Valid, but complex, signature.", - ), - pytest.param( - Signature( - ( - Parameter("arg1", Parameter.POSITIONAL_OR_KEYWORD), - Parameter("arg2", Parameter.POSITIONAL_OR_KEYWORD, default=1), - Parameter("vargs1", Parameter.VAR_POSITIONAL), - Parameter("vargs2", Parameter.VAR_POSITIONAL), - Parameter("kwargs1", Parameter.VAR_KEYWORD), - ) - ), - ValueError("New signature takes more than 1 VAR_POSITIONAL argument."), - id="Two variable-length positional arguments, mixed with others.", - ), - ], -) -def test_validate_variable_length_parameters( - signature: Signature, expected: Exception | dict -): - if isinstance(expected, Exception): - with pytest.raises(type(expected), match=re.escape(str(expected))): - _check_variable_length_params(signature) - else: - returned_names = _check_variable_length_params(signature) - - assert returned_names == expected - - @pytest.mark.parametrize( ( "signature_to_convert", From 7a5d1b6a7d635a50d97befee2be654e795f13b2d Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:41:28 +0000 Subject: [PATCH 31/38] Tidy convert_signature docstring --- src/causalprog/backend/_convert_signature.py | 59 ++++++++++++++++---- 1 file changed, 48 insertions(+), 11 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 4c76e9e..15c6586 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -194,27 +194,64 @@ def convert_signature( """ Convert the call signature of a function ``fn`` to that of ``new_signature``. + This function effectively allows ``fn`` to be called with ``new_signature``. It + returns a new ``Callable`` that uses the ``new_signature``, and returns the result + of ``fn`` after translating the ``new_signature`` back into that of ``fn`` and + making an appropriate call. + + Converting signatures into each other is, in general, not possible. However under + certain assumptions and conventions, it can be done. To that end, the following + assumptions are made about ``fn`` and ``new_signature``: + + 1. All parameters to ``fn`` are either; + 1. mapped to one non-variable-length parameter of ``new_signature``, or + 2. provided with a static value to be used in all calls. + 2. If ``fn`` takes a ``VAR_POSITIONAL`` parameter ``*args``, then either + 1. ``new_signature`` must also take a ``VAR_POSITIONAL`` parameter, and this + must map to identically to ``*args``, + 2. ``*args`` is provided with a static value to be used in all calls, and + ``new_signature`` must not take ``VAR_POSITIONAL`` arguments. + 3. If ``fn`` takes a ``VAR_KEYWORD`` parameter ``**kwargs``, then either + 1. ``new_signature`` must also take a ``VAR_KEYWORD`` parameter, and this + must map to identically to ``**kwargs``, + 2. ``**kwargs`` is provided with a static value to be used in all calls, and + ``new_signature`` must not take ``VAR_KEYWORD`` arguments. + + Mapping of parameters is done by name, from the signature of ``fn`` to + ``new_signature``, in the ``old_to_new_names`` argument. + + 4. If a parameter does not change name between the two signatures, it can be omitted + from this mapping and it will be inferred. Note that such a parameter may still + change kind, or adopt a new default value, in the ``new_signature``. + + Parameters can also be "dropped" from ``fn``'s signature in ``new_signature``, by + assigning them static values to be used in all cases. Such static values are given + in the ``give_static_value`` mapping, which maps (names of) parameters of ``fn`` to + a fixed value to be used for that parameter. This means that these parameters do not + need to be mapped to a parameter in ``new_signature``. + + 5. Parameters that have default values in ``fn``, and which are not mapped to a + parameter of ``new_signature``, will adopt their default value as a static value. + Args: fn (Callable): Callable object to change the signature of. new_signature (inspect.Signature): New signature to give to ``fn``. old_to_new_names (dict[str, str]): Maps the names of parameters in ``fn``s - signature to the corresponding parameter names in the new signature. - Parameter names that do not change can be omitted. Note that parameters that - are to be dropped should be supplied to ``give_static_value`` instead. + signature to the corresponding parameter names in ``new_signature``. give_static_value (dict[str, Any]): Maps names of parameters of ``fn`` to - default values that should be assigned to them. This means that not all - compulsory parameters of ``fn`` have to have a corresponding parameter in - ``new_signature`` - such parameters will use the value assigned to them in - ``give_static_value`` if they are lacking a counterpart parameter in - ``new_signature``. Parameters to ``fn`` that lack a counterpart in - ``new_signature``, and that have default values in ``fn``, will be added - automatically. + static values that should be assigned to them. + + Raises: + ValueError: If ``fn``'s signature cannot be cast to ``new_signature``, given the + information provided. Returns: Callable: Callable representing ``fn`` with ``new_signature``. See Also: - _signature_can_be_cast: Validation method used to check casting is possible. + _check_variable_length_params: Validation of number of variable-length + parameters. + _signature_can_be_cast: Validation method used to check signatures can be cast. """ fn_signature = inspect.signature(fn) From 5f62aba5a589befeb1b661a731f5f99f46d9c8d4 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:46:28 +0000 Subject: [PATCH 32/38] Parameter naming and docstrings for _signature_can_be_cast --- src/causalprog/backend/_convert_signature.py | 52 ++++++++------------ tests/test_backend/test_convert_signature.py | 8 +-- 2 files changed, 25 insertions(+), 35 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 15c6586..81d9738 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -53,33 +53,20 @@ def _check_variable_length_params( def _signature_can_be_cast( signature_to_convert: Signature, new_signature: Signature, - param_name_map: ParamNameMap, + old_to_new_names: ParamNameMap, give_static_value: StaticValues, ) -> tuple[ParamNameMap, StaticValues]: """ Prepare a signature for conversion to another signature. - In order to map ``signature_to_convert`` to that of ``new_signature``, the following - assurances are needed: - - - Variable-length parameters in the two signatures are assumed to match (up to name - changes) or be provided explicit defaults. The function will attempt to match - variable-length parameters that are not explicitly matched in the - ``param_name_map``. Note that a signature can have, at most, only one - variable-length positional parameter and one variable-length keyword parameter. - - All parameters WITHOUT DEFAULT VALUES in ``signature_to_convert`` correspond to a - parameter in ``new_signature`` (that may or may not have a default value) OR are - given static values to use, via the ``give_static_value`` argument. - - If ``new_signature`` takes variable-keyword-argument (``**kwargs``), these - arguments are expanded to allow for possible matches to parameters of - ``signature_to_convert``, before passing any remaining parameters after this - unpacked to the variable-keyword-argument of ``signature_to_convert``. + This is a helper that handles the validation detailed in ``convert_signature``. + See the docstring of ``convert_signature`` for more details. Args: signature_to_convert (Signature): Function signature that will be cast to ``new_signature``. new_signature (Signature): See the homonymous argument to ``convert_signature``. - param_name_map (ParamNameMap): See the homonymous argument to + old_to_new_names (ParamNameMap): See the homonymous argument to ``convert_signature``. give_static_value (StaticValues): See the homonymous argument to ``convert_signature``. @@ -90,28 +77,31 @@ def _signature_can_be_cast( Returns: ParamNameMap: Mapping of parameter names in the ``signature_to_convert`` to - parameter names in ``new_signature``. Implicit mappings as per function - behaviour are explicitly included in the returned mapping. + parameter names in ``new_signature``. Implicit mappings as per behaviour of + ``convert_signature`` are explicitly included in the returned mapping. StaticValues: Mapping of parameter names in the ``signature_to_convert`` to static values to assign to these parameters, indicating omission from the - ``new_signature``. Implicit adoption of static values as per function - behaviour are explicitly included in the returned mapping. + ``new_signature``. Implicit adoption of static values as per behaviour of + ``convert_signature`` are explicitly included in the returned mapping. + + See Also: + convert_signature: Function for which setup is being performed. """ _check_variable_length_params(signature_to_convert) new_varlength_params = _check_variable_length_params(new_signature) - param_name_map = dict(param_name_map) + old_to_new_names = dict(old_to_new_names) give_static_value = dict(give_static_value) new_parameters_accounted_for = set() # Check mapping of parameters in old signature to new signature for p_name, param in signature_to_convert.parameters.items(): - is_explicitly_mapped = p_name in param_name_map + is_explicitly_mapped = p_name in old_to_new_names name_is_unchanged = ( - p_name not in param_name_map - and p_name not in param_name_map.values() + p_name not in old_to_new_names + and p_name not in old_to_new_names.values() and p_name in new_signature.parameters ) is_given_static = p_name in give_static_value @@ -121,20 +111,20 @@ def _signature_can_be_cast( if is_explicitly_mapped: # This parameter is explicitly mapped to another parameter - mapped_to = param_name_map[p_name] + mapped_to = old_to_new_names[p_name] elif name_is_unchanged: # Parameter is inferred not to change name, having been omitted from the # explicit mapping. mapped_to = p_name - param_name_map[p_name] = mapped_to + old_to_new_names[p_name] = mapped_to elif ( is_varlength_param and new_varlength_params[param.kind] is not None - and str(new_varlength_params[param.kind]) not in param_name_map.values() + and str(new_varlength_params[param.kind]) not in old_to_new_names.values() ): # Automatically map VAR_* parameters to their counterpart, if possible. mapped_to = str(new_varlength_params[param.kind]) - param_name_map[p_name] = mapped_to + old_to_new_names[p_name] = mapped_to elif is_given_static: # This parameter is given a static value to use. continue @@ -170,7 +160,7 @@ def _signature_can_be_cast( ) raise ValueError(msg) - new_parameters_accounted_for.add(param_name_map[p_name]) + new_parameters_accounted_for.add(old_to_new_names[p_name]) # Confirm all items in new_signature are also accounted for. unaccounted_new_parameters = ( @@ -182,7 +172,7 @@ def _signature_can_be_cast( ) raise ValueError(msg) - return param_name_map, give_static_value + return old_to_new_names, give_static_value def convert_signature( diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index 826a259..76fbb0d 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -22,7 +22,7 @@ def general_function( ( "signature_to_convert", "new_signature", - "param_name_map", + "old_to_new_names", "give_static_value", "expected_output", ), @@ -164,7 +164,7 @@ def general_function( def test_signature_can_be_cast( signature_to_convert: Signature, new_signature: Signature, - param_name_map: ParamNameMap, + old_to_new_names: ParamNameMap, give_static_value: StaticValues, expected_output: Exception | tuple[str | None, ParamNameMap, StaticValues], ) -> None: @@ -175,14 +175,14 @@ def test_signature_can_be_cast( _signature_can_be_cast( signature_to_convert, new_signature, - param_name_map, + old_to_new_names, give_static_value, ) else: computed_output = _signature_can_be_cast( signature_to_convert, new_signature, - param_name_map, + old_to_new_names, give_static_value, ) From e4b99a89482b468923aa4e6ebae6696b69e7240c Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:51:24 +0000 Subject: [PATCH 33/38] Refactor _signature_can_be_cast tests --- tests/test_backend/test_convert_signature.py | 179 +--------------- .../test_signature_can_be_cast.py | 194 ++++++++++++++++++ 2 files changed, 196 insertions(+), 177 deletions(-) create mode 100644 tests/test_backend/test_signature_can_be_cast.py diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index 76fbb0d..d6ee1cf 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -1,15 +1,11 @@ import re from collections.abc import Iterable -from inspect import Parameter, Signature, signature +from inspect import Parameter, Signature from typing import Any import pytest -from causalprog.backend._convert_signature import ( - _signature_can_be_cast, - convert_signature, -) -from causalprog.backend._typing import ParamNameMap, StaticValues +from causalprog.backend._convert_signature import convert_signature def general_function( @@ -18,177 +14,6 @@ def general_function( return posix, posix_def, vargs, kwo, kwo_def, kwargs -@pytest.mark.parametrize( - ( - "signature_to_convert", - "new_signature", - "old_to_new_names", - "give_static_value", - "expected_output", - ), - [ - pytest.param( - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - Parameter("b", Parameter.POSITIONAL_ONLY), - ] - ), - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - ] - ), - {}, - {}, - ValueError( - "Parameter 'b' has no counterpart in new_signature, " - "and does not take a static value." - ), - id="Parameter not matched.", - ), - pytest.param( - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - Parameter("b", Parameter.POSITIONAL_ONLY), - ] - ), - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - ] - ), - {"a": "a", "b": "a"}, - {}, - ValueError("Parameter 'a' is mapped to by multiple parameters."), - id="Two arguments mapped to a single parameter.", - ), - pytest.param( - Signature( - [ - Parameter("vargs", Parameter.VAR_POSITIONAL), - ] - ), - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - ] - ), - {"vargs": "a"}, - {}, - ValueError( - "Variable-length positional/keyword parameters must map to each other " - "('vargs' is type VAR_POSITIONAL, but 'a' is type POSITIONAL_ONLY)." - ), - id="Map *args to positional argument.", - ), - pytest.param( - Signature( - [ - Parameter("vargs", Parameter.VAR_POSITIONAL), - ] - ), - Signature( - [ - Parameter("kwarg", Parameter.VAR_KEYWORD), - ] - ), - {"vargs": "kwarg"}, - {}, - ValueError( - "Variable-length positional/keyword parameters must map to each other " - "('vargs' is type VAR_POSITIONAL, but 'kwarg' is type VAR_KEYWORD)." - ), - id="Map *args to **kwargs.", - ), - pytest.param( - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - ] - ), - Signature( - [ - Parameter("a", Parameter.POSITIONAL_ONLY), - Parameter("b", Parameter.POSITIONAL_ONLY), - ] - ), - {}, - {}, - ValueError("Some parameters in new_signature are not used: b"), - id="new_signature contains extra parameters.", - ), - pytest.param( - signature(general_function), - signature(general_function), - {}, - {}, - ({key: key for key in signature(general_function).parameters}, {}), - id="Can cast to yourself.", - ), - pytest.param( - Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), - Signature([Parameter("a", Parameter.KEYWORD_ONLY)]), - {}, - {}, - ({"a": "a"}, {}), - id="Infer identically named parameter (even with type change)", - ), - pytest.param( - Signature([Parameter("args", Parameter.VAR_POSITIONAL)]), - Signature([Parameter("new_args", Parameter.VAR_POSITIONAL)]), - {}, - {}, - ({"args": "new_args"}, {}), - id="Infer VAR_POSITIONAL matching.", - ), - pytest.param( - Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), - Signature([]), - {}, - {"a": 10}, - ({}, {"a": 10}), - id="Assign static value to argument without default.", - ), - pytest.param( - Signature([Parameter("a", Parameter.POSITIONAL_ONLY, default=10)]), - Signature([]), - {}, - {}, - ({}, {"a": 10}), - id="Infer static value from argument default.", - ), - ], -) -def test_signature_can_be_cast( - signature_to_convert: Signature, - new_signature: Signature, - old_to_new_names: ParamNameMap, - give_static_value: StaticValues, - expected_output: Exception | tuple[str | None, ParamNameMap, StaticValues], -) -> None: - if isinstance(expected_output, Exception): - with pytest.raises( - type(expected_output), match=re.escape(str(expected_output)) - ): - _signature_can_be_cast( - signature_to_convert, - new_signature, - old_to_new_names, - give_static_value, - ) - else: - computed_output = _signature_can_be_cast( - signature_to_convert, - new_signature, - old_to_new_names, - give_static_value, - ) - - assert computed_output == expected_output - - _kwargs_static_value = {"some": "keyword-arguments"} diff --git a/tests/test_backend/test_signature_can_be_cast.py b/tests/test_backend/test_signature_can_be_cast.py new file mode 100644 index 0000000..36f0f53 --- /dev/null +++ b/tests/test_backend/test_signature_can_be_cast.py @@ -0,0 +1,194 @@ +import re +from inspect import Parameter, Signature + +import pytest + +from causalprog.backend._convert_signature import _signature_can_be_cast +from causalprog.backend._typing import ParamNameMap, StaticValues + + +@pytest.mark.parametrize( + ( + "signature_to_convert", + "new_signature", + "old_to_new_names", + "give_static_value", + "expected_output", + ), + [ + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {}, + {}, + ValueError( + "Parameter 'b' has no counterpart in new_signature, " + "and does not take a static value." + ), + id="Parameter not matched.", + ), + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {"a": "a", "b": "a"}, + {}, + ValueError("Parameter 'a' is mapped to by multiple parameters."), + id="Two arguments mapped to a single parameter.", + ), + pytest.param( + Signature( + [ + Parameter("vargs", Parameter.VAR_POSITIONAL), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + {"vargs": "a"}, + {}, + ValueError( + "Variable-length positional/keyword parameters must map to each other " + "('vargs' is type VAR_POSITIONAL, but 'a' is type POSITIONAL_ONLY)." + ), + id="Map *args to positional argument.", + ), + pytest.param( + Signature( + [ + Parameter("vargs", Parameter.VAR_POSITIONAL), + ] + ), + Signature( + [ + Parameter("kwarg", Parameter.VAR_KEYWORD), + ] + ), + {"vargs": "kwarg"}, + {}, + ValueError( + "Variable-length positional/keyword parameters must map to each other " + "('vargs' is type VAR_POSITIONAL, but 'kwarg' is type VAR_KEYWORD)." + ), + id="Map *args to **kwargs.", + ), + pytest.param( + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + ] + ), + Signature( + [ + Parameter("a", Parameter.POSITIONAL_ONLY), + Parameter("b", Parameter.POSITIONAL_ONLY), + ] + ), + {}, + {}, + ValueError("Some parameters in new_signature are not used: b"), + id="new_signature contains extra parameters.", + ), + pytest.param( + "general_function_signature", + "general_function_signature", + {}, + {}, + ( + { + "posix": "posix", + "posix_def": "posix_def", + "vargs": "vargs", + "kwo": "kwo", + "kwo_def": "kwo_def", + "kwargs": "kwargs", + }, + {}, + ), + id="Can cast to yourself.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), + Signature([Parameter("a", Parameter.KEYWORD_ONLY)]), + {}, + {}, + ({"a": "a"}, {}), + id="Infer identically named parameter (even with type change)", + ), + pytest.param( + Signature([Parameter("args", Parameter.VAR_POSITIONAL)]), + Signature([Parameter("new_args", Parameter.VAR_POSITIONAL)]), + {}, + {}, + ({"args": "new_args"}, {}), + id="Infer VAR_POSITIONAL matching.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY)]), + Signature([]), + {}, + {"a": 10}, + ({}, {"a": 10}), + id="Assign static value to argument without default.", + ), + pytest.param( + Signature([Parameter("a", Parameter.POSITIONAL_ONLY, default=10)]), + Signature([]), + {}, + {}, + ({}, {"a": 10}), + id="Infer static value from argument default.", + ), + ], +) +def test_signature_can_be_cast( # noqa: PLR0913 + signature_to_convert: Signature, + new_signature: Signature, + old_to_new_names: ParamNameMap, + give_static_value: StaticValues, + expected_output: Exception | tuple[str | None, ParamNameMap, StaticValues], + request, +) -> None: + if isinstance(signature_to_convert, str): + signature_to_convert = request.getfixturevalue(signature_to_convert) + if isinstance(new_signature, str): + new_signature = request.getfixturevalue(new_signature) + + if isinstance(expected_output, Exception): + with pytest.raises( + type(expected_output), match=re.escape(str(expected_output)) + ): + _signature_can_be_cast( + signature_to_convert, + new_signature, + old_to_new_names, + give_static_value, + ) + else: + computed_output = _signature_can_be_cast( + signature_to_convert, + new_signature, + old_to_new_names, + give_static_value, + ) + + assert computed_output == expected_output From 7df9988b6701358f2f2bd264b11ecab7e50de13e Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:54:34 +0000 Subject: [PATCH 34/38] Use fixtures for convert_signature test --- tests/test_backend/test_convert_signature.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index d6ee1cf..a983269 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -1,5 +1,5 @@ import re -from collections.abc import Iterable +from collections.abc import Callable, Iterable from inspect import Parameter, Signature from typing import Any @@ -7,13 +7,6 @@ from causalprog.backend._convert_signature import convert_signature - -def general_function( - posix, /, posix_def="posix_def", *vargs, kwo, kwo_def="kwo_def", **kwargs -): - return posix, posix_def, vargs, kwo, kwo_def, kwargs - - _kwargs_static_value = {"some": "keyword-arguments"} @@ -81,6 +74,7 @@ def test_convert_signature( posix_for_new_call: Iterable[Any], keyword_for_new_call: dict[str, Any], expected_assignments: dict[str, Any] | Exception, + general_function: Callable, ) -> None: """ To ease the burden of setting up and parametrising this test, From 775489f185d66905dd69e95d5a054d9f62a4b7cd Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 20 Mar 2025 11:58:17 +0000 Subject: [PATCH 35/38] remove outdated comment --- src/causalprog/backend/_convert_signature.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 81d9738..6b96419 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -289,10 +289,6 @@ def fn_with_new_signature(*args: tuple, **kwargs: dict[str, Any]) -> ReturnType: }, **kwargs_to_pass_on, ) - # We can supply all arguments EXCEPT the variable-positional and positional-only - # arguments as keyword args. - # Positional-only arguments have to come first, followed by the - # variable-positional parameters. fn_args = [fn_kwargs.pop(p_name) for p_name in fn_posix_args] if fn_vargs_param: fn_args.extend(fn_kwargs.pop(fn_vargs_param, [])) From a463a140824667b7b05708e8c5cc44b4f046a50f Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 21 Mar 2025 14:50:35 +0000 Subject: [PATCH 36/38] Ruff correcting ruff... --- tests/test_backend/test_signature_can_be_cast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_backend/test_signature_can_be_cast.py b/tests/test_backend/test_signature_can_be_cast.py index 36f0f53..3da928d 100644 --- a/tests/test_backend/test_signature_can_be_cast.py +++ b/tests/test_backend/test_signature_can_be_cast.py @@ -160,7 +160,7 @@ ), ], ) -def test_signature_can_be_cast( # noqa: PLR0913 +def test_signature_can_be_cast( signature_to_convert: Signature, new_signature: Signature, old_to_new_names: ParamNameMap, From 7cd3806d46c6df1b87aecd4f8abb19b06f649cdc Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 21 Mar 2025 14:58:46 +0000 Subject: [PATCH 37/38] Remove utils import that has now disappeared --- src/causalprog/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/causalprog/__init__.py b/src/causalprog/__init__.py index 902fe83..9023c92 100644 --- a/src/causalprog/__init__.py +++ b/src/causalprog/__init__.py @@ -1,4 +1,4 @@ """causalprog package.""" -from . import algorithms, distribution, graph, utils +from . import algorithms, distribution, graph from ._version import __version__ From e0d22b9c64d9a41d3ff0159bac8d52e959080a6b Mon Sep 17 00:00:00 2001 From: Will Graham <32364977+willGraham01@users.noreply.github.com> Date: Mon, 24 Mar 2025 10:04:55 +0000 Subject: [PATCH 38/38] Apply suggestions from code review --- src/causalprog/backend/translator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/causalprog/backend/translator.py b/src/causalprog/backend/translator.py index a90bd92..75dbfd6 100644 --- a/src/causalprog/backend/translator.py +++ b/src/causalprog/backend/translator.py @@ -87,13 +87,13 @@ def __init__( ) raise AttributeError(msg) if not method_has_translation: - # Assume the identity mapping to teh backend method, otherwise. + # Assume the identity mapping to the backend method, otherwise. self.translations[method] = self.identity self.validate() def _call_backend_with(self, method: str, *args: Any, **kwargs: Any) -> Any: # noqa:ANN401 - """Translate arguments and then call the backend.""" + """Translate arguments, then call the backend.""" backend_method = getattr( self._backend_obj, self.frontend_to_native_names[method] )