diff --git a/src/causalprog/__init__.py b/src/causalprog/__init__.py index 902fe83..9023c92 100644 --- a/src/causalprog/__init__.py +++ b/src/causalprog/__init__.py @@ -1,4 +1,4 @@ """causalprog package.""" -from . import algorithms, distribution, graph, utils +from . import algorithms, distribution, graph from ._version import __version__ diff --git a/src/causalprog/_abc/backend_agnostic.py b/src/causalprog/_abc/backend_agnostic.py index 08aa40d..0d6d663 100644 --- a/src/causalprog/_abc/backend_agnostic.py +++ b/src/causalprog/_abc/backend_agnostic.py @@ -17,7 +17,6 @@ class BackendAgnostic(ABC, Generic[Backend]): calls to the ``_backend_obj`` as necessary. """ - __slots__ = ("_backend_obj",) _backend_obj: Backend def __getattr__(self, name: str) -> Any: # noqa: ANN401 diff --git a/src/causalprog/_abc/labelled.py b/src/causalprog/_abc/labelled.py index 92250bc..c4be7f7 100644 --- a/src/causalprog/_abc/labelled.py +++ b/src/causalprog/_abc/labelled.py @@ -11,7 +11,6 @@ class Labelled(ABC): ``label`` property of the class. """ - __slots__ = ("_label",) _label: str @property diff --git a/src/causalprog/backend/_convert_signature.py b/src/causalprog/backend/_convert_signature.py index 6b96419..ef19509 100644 --- a/src/causalprog/backend/_convert_signature.py +++ b/src/causalprog/backend/_convert_signature.py @@ -180,14 +180,14 @@ def convert_signature( new_signature: Signature, old_to_new_names: ParamNameMap, give_static_value: StaticValues, -) -> Callable[..., ReturnType]: +) -> Callable[..., tuple[Any, Any]]: """ Convert the call signature of a function ``fn`` to that of ``new_signature``. This function effectively allows ``fn`` to be called with ``new_signature``. It - returns a new ``Callable`` that uses the ``new_signature``, and returns the result - of ``fn`` after translating the ``new_signature`` back into that of ``fn`` and - making an appropriate call. + returns a new ``Callable`` (denoted ``g``) that maps the parameters of + ``new_signature`` to the (corresponding) parameters of ``fn``. As such, ``fn`` + composed with ``g`` allows for calling ``fn`` with the ``new_signature``. Converting signatures into each other is, in general, not possible. However under certain assumptions and conventions, it can be done. To that end, the following @@ -236,7 +236,7 @@ def convert_signature( information provided. Returns: - Callable: Callable representing ``fn`` with ``new_signature``. + Callable: Callable mapping parameters in ``new_signature`` to those in ``fn``. See Also: _check_variable_length_params: Validation of number of variable-length @@ -270,7 +270,7 @@ def convert_signature( static_kwargs = give_static_value.pop(fn_kwargs_param) give_static_value = dict(give_static_value, **static_kwargs) - def fn_with_new_signature(*args: tuple, **kwargs: dict[str, Any]) -> ReturnType: + def new_sig_to_fn_sig(*args: Any, **kwargs: Any) -> tuple[list, dict[str, Any]]: # noqa: ANN401 bound = new_signature.bind(*args, **kwargs) bound.apply_defaults() @@ -292,7 +292,7 @@ def fn_with_new_signature(*args: tuple, **kwargs: dict[str, Any]) -> ReturnType: fn_args = [fn_kwargs.pop(p_name) for p_name in fn_posix_args] if fn_vargs_param: fn_args.extend(fn_kwargs.pop(fn_vargs_param, [])) - # Now we can call fn - return fn(*fn_args, **fn_kwargs) + # Arguments are now mapped + return fn_args, fn_kwargs - return fn_with_new_signature + return new_sig_to_fn_sig diff --git a/src/causalprog/backend/translation.py b/src/causalprog/backend/translation.py new file mode 100644 index 0000000..ec590ab --- /dev/null +++ b/src/causalprog/backend/translation.py @@ -0,0 +1,49 @@ +"""Data structure for storing information mapping backend to frontend.""" + +from dataclasses import dataclass, field + +from ._typing import ParamNameMap, StaticValues + + +@dataclass +class Translation: + """ + Helper class for mapping frontend signatures to backend signatures. + + Attributes: + backend_name (str): Name of the backend method that is being translated into + a frontend method. + frontend_name (str): Name of the frontend method that the backend method will + be used as. + param_map (ParamNameMap): See ``old_to_new_names`` argument to + ``causalprog.backend._convert_signature``. + frozen_args (StaticValues): See ``give_static_value`` argument to + ``causalprog.backend._convert_signature``. + + """ + + backend_name: str + frontend_name: str + param_map: ParamNameMap + frozen_args: StaticValues = field(default_factory=dict) + + def __post_init__(self) -> None: + if not isinstance(self.backend_name, str): + msg = f"backend_name '{self.backend_name}' is not a string." + raise TypeError(msg) + if not isinstance(self.frontend_name, str): + msg = f"frontend_name '{self.frontend_name}' is not a string." + raise TypeError(msg) + + self.frozen_args = dict(self.frozen_args) + self.param_map = dict(self.param_map) + + if not all( + isinstance(key, str) and isinstance(value, str) + for key, value in self.param_map.items() + ): + msg = "Parameter map must map str -> str." + raise TypeError(msg) + if not all(isinstance(key, str) for key in self.frozen_args): + msg = "Frozen args must be specified by name (str)." + raise TypeError(msg) diff --git a/src/causalprog/backend/translator.py b/src/causalprog/backend/translator.py new file mode 100644 index 0000000..75dbfd6 --- /dev/null +++ b/src/causalprog/backend/translator.py @@ -0,0 +1,101 @@ +"""Translating backend object syntax to frontend syntax.""" + +from collections.abc import Callable +from inspect import signature +from typing import Any + +from causalprog._abc.backend_agnostic import Backend, BackendAgnostic + +from ._convert_signature import convert_signature +from .translation import Translation + + +class Translator(BackendAgnostic[Backend]): + """ + Translates the methods of a backend object into frontend syntax. + + A ``Translator`` acts as an intermediary between a backend that is supplied by the + user and the frontend syntax that ``causalprog`` relies on. The default backend of + ``causalprog`` uses a syntax compatible with ``jax``. + + Other backends may not conform to the syntax that ``causalprog`` expects, but + nonetheless may provide the functionality that it requires. A ``Translator`` is able + to make calls to (the relevant methods of) this backend, whilst still conforming to + the frontend syntax of ``causalprog``. + + As an example, suppose that we have a frontend class ``C`` that needs to provide a + method ``do_this``. ``causalprog`` expects ``C`` to provide the functionality + of ``do_this`` via one of its methods, ``C.do_this(*c_args, **c_kwargs)``. + Now suppose that a class ``D`` from a different, external package might also + provides the functionality of ``do_this``, but it is done by calling + ``D.do_this_different(*d_args, **d_kwargs)``, where there is some mapping + ``m: *c_args, **c_kwargs -> *d_args, **d_kwargs``. In such a case, ``causalprog`` + needs to use a ``Translator`` ``T``, rather than ``D`` directly, where + + ``T.do_this(*c_args, **c_kwargs) = D.do_this_different(m(*c_args, **c_kwargs))``. + """ + + frontend_to_native_names: dict[str, str] + translations: dict[str, Callable] + + @staticmethod + def identity(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any]]: # noqa: ANN401 + """Identity map on positional and keyword arguments.""" + return args, kwargs + + def __init__( + self, + *translations: Translation, + backend: Backend, + ) -> None: + """ + Translate a backend object into a frontend-compatible object. + + Args: + backend (Backend): Backend object that must be translated to support + frontend syntax. + *translations (Translation): ``Translation``s that map the methods of + ``backend`` to the (signatures of the) methods that the + ``_frontend_provides``. + + """ + super().__init__(backend=backend) + + self.translations = {} + self.frontend_to_native_names = {name: name for name in self._frontend_provides} + for t in translations: + native_name = t.backend_name + translated_name = t.frontend_name + native_method = getattr(self._backend_obj, native_name) + target_method = getattr(self, translated_name) + target_signature = signature(target_method) + + self.translations[translated_name] = convert_signature( + native_method, target_signature, t.param_map, t.frozen_args + ) + self.frontend_to_native_names[translated_name] = native_name + + # Methods without explicit translations are assumed to be the identity map, + # provided they exist on the backend object. + for method in self._frontend_provides: + method_has_translation = method in self.translations + backend_has_method = hasattr(self._backend_obj, method) + if not (method_has_translation or backend_has_method): + msg = ( + f"No translation provided for {method}, " + "which the backend is lacking." + ) + raise AttributeError(msg) + if not method_has_translation: + # Assume the identity mapping to the backend method, otherwise. + self.translations[method] = self.identity + + self.validate() + + def _call_backend_with(self, method: str, *args: Any, **kwargs: Any) -> Any: # noqa:ANN401 + """Translate arguments, then call the backend.""" + backend_method = getattr( + self._backend_obj, self.frontend_to_native_names[method] + ) + backend_args, backend_kwargs = self.translations[method](*args, **kwargs) + return backend_method(*backend_args, **backend_kwargs) diff --git a/src/causalprog/distribution/base.py b/src/causalprog/distribution/base.py index bbd21d6..fba47b4 100644 --- a/src/causalprog/distribution/base.py +++ b/src/causalprog/distribution/base.py @@ -1,77 +1,45 @@ """Base class for backend-agnostic distributions.""" -from collections.abc import Callable -from typing import Generic, TypeVar +from typing import TypeVar from numpy.typing import ArrayLike from causalprog._abc.labelled import Labelled -from causalprog.utils.translator import Translator +from causalprog.backend.translation import Translation +from causalprog.backend.translator import Translator SupportsRNG = TypeVar("SupportsRNG") SupportsSampling = TypeVar("SupportsSampling", bound=object) -class SampleTranslator(Translator): - """ - Translate methods for sampling from distributions. - - The ``Distribution`` class provides a ``sample`` method, that takes ``rng_key`` and - ``sample_shape`` as its arguments. Instances of this class transform the these - arguments to those that a backend distribution expects. - """ - - @property - def _frontend_method(self) -> str: - return "sample" - - @property - def compulsory_frontend_args(self) -> set[str]: - """Arguments that are required by the frontend function.""" - return {"rng_key", "sample_shape"} - - -class Distribution(Generic[SupportsSampling], Labelled): +class Distribution(Translator[SupportsSampling], Labelled): """A (backend-agnostic) distribution that can be sampled from.""" - _dist: SupportsSampling - _backend_translator: SampleTranslator + @property + def _frontend_provides(self) -> tuple[str, ...]: + return ("sample",) @property - def _sample(self) -> Callable[..., ArrayLike]: - """Method for drawing samples from the backend object.""" - return getattr(self._dist, self._backend_translator.backend_method) + def dist(self) -> SupportsSampling: + """Return the object representing the distribution.""" + return self.get_backend() def __init__( - self, - backend_distribution: SupportsSampling, - backend_translator: SampleTranslator | None = None, - *, - label: str = "Distribution", + self, *translations: Translation, backend: SupportsSampling, label: str ) -> None: """ - Create a new Distribution. + Create a new distribution, with a given backend. Args: - backend_distribution (SupportsSampling): Backend object that supports - drawing random samples. - backend_translator (SampleTranslator): Translator object mapping backend - sampling function to frontend arguments. + *translations (Translation): Information for mapping the methods of the + backend object to the frontend methods provided by this class. See + ``causalprog.backend.Translator`` for more details. + backend (SupportsSampling): Backend object that represents the distribution. + label (str): Name or label to attach to the distribution. """ - super().__init__(label=label) - - self._dist = backend_distribution - - # Setup sampling calls, and perform one-time check for compatibility - self._backend_translator = ( - backend_translator if backend_translator is not None else SampleTranslator() - ) - self._backend_translator.validate_compatible(backend_distribution) - - def get_dist(self) -> SupportsSampling: - """Access to the backend distribution.""" - return self._dist + Labelled.__init__(self, label=label) + Translator.__init__(self, *translations, backend=backend) def sample(self, rng_key: SupportsRNG, sample_shape: ArrayLike = ()) -> ArrayLike: """ @@ -85,7 +53,26 @@ def sample(self, rng_key: SupportsRNG, sample_shape: ArrayLike = ()) -> ArrayLik ArrayLike: Randomly-drawn samples from the distribution. """ - args_to_backend = self._backend_translator.translate_args( - rng_key=rng_key, sample_shape=sample_shape - ) - return self._sample(**args_to_backend) + return self._call_backend_with("sample", rng_key, sample_shape) + + +class NativeDistribution(Distribution[SupportsSampling]): + """ + A distribution that uses our native backend. + + These distributions do not require translations, since the backend objects + they use conform to our frontend syntax by design. + """ + + def __init__(self, *, backend: SupportsSampling, label: str) -> None: + """ + Create a new distribution, using a native backend. + + Args: + backend (SupportsSampling): Backend object that represents the distribution. + Must be a native backend object; that is a distribution provided by the + ``causalprog`` package. + label (str): Name or label to attach to the distribution. + + """ + super().__init__(backend=backend, label=label) diff --git a/src/causalprog/distribution/family.py b/src/causalprog/distribution/family.py index 3ee2175..93d81ae 100644 --- a/src/causalprog/distribution/family.py +++ b/src/causalprog/distribution/family.py @@ -7,14 +7,13 @@ from causalprog._abc.labelled import Labelled from causalprog.distribution.base import Distribution, SupportsSampling -from causalprog.utils.translator import Translator -CreatesDistribution = TypeVar( - "CreatesDistribution", bound=Callable[..., SupportsSampling] +GenericDistribution = TypeVar( + "GenericDistribution", bound=Distribution[SupportsSampling] ) -class DistributionFamily(Generic[CreatesDistribution], Labelled): +class DistributionFamily(Generic[GenericDistribution], Labelled): r""" A family of ``Distributions``, that share the same parameters. @@ -33,47 +32,38 @@ class DistributionFamily(Generic[CreatesDistribution], Labelled): samples drawn from it. """ - _family: CreatesDistribution - _family_translator: Translator | None - - @property - def _member(self) -> Callable[..., Distribution]: - """Constructor method for family members, given parameters.""" - return lambda **parameters: Distribution( - self._family(**parameters), - backend_translator=self._family_translator, - ) + _family: Callable[..., GenericDistribution] def __init__( self, - backend_family: CreatesDistribution, - backend_translator: Translator | None = None, + family: Callable[..., GenericDistribution], *, - family_name: str = "DistributionFamily", + label: str, ) -> None: """ Create a new family of distributions. Args: - backend_family (CreatesDistribution): Backend callable that assembles the - distribution, given explicit parameter values. Currently, this callable - can only accept the parameters as a sequence of positional arguments. - backend_translator (Translator): ``Translator`` instance that to be - passed to the ``Distribution`` constructor. + family (Callable[..., GenericDistribution]): Backend callable that assembles + a member distribution of this family, from explicit parameter values. + label (str): Name to give to the distribution family. """ - super().__init__(label=family_name) + super().__init__(label=label) - self._family = backend_family - self._family_translator = backend_translator + self._family = family - def construct(self, **parameters: ArrayLike) -> Distribution: + def construct( + self, *pos_parameters: ArrayLike, **kw_parameters: ArrayLike + ) -> Distribution: """ Create a distribution from an explicit set of parameters. Args: - **parameters (ArrayLike): Parameters that define a member of this family, - passed as sequential arguments. + *pos_parameters (ArrayLike): Positional parameter values that define a + member of this family. + **kw_parameters (ArrayLike): Keyword parameter values that define a member + of this family. """ - return self._member(**parameters) + return self._family(*pos_parameters, **kw_parameters) diff --git a/src/causalprog/distribution/normal.py b/src/causalprog/distribution/normal.py index a8dfbb6..9140ae5 100644 --- a/src/causalprog/distribution/normal.py +++ b/src/causalprog/distribution/normal.py @@ -7,7 +7,7 @@ from jax import Array as JaxArray from numpy.typing import ArrayLike -from .base import Distribution +from .base import NativeDistribution from .family import DistributionFamily ArrayCompatible = TypeVar("ArrayCompatible", JaxArray, ArrayLike) @@ -26,7 +26,7 @@ def sample(self, rng_key: RNGKey, sample_shape: ArrayLike) -> JaxArray: return jrn.multivariate_normal(rng_key, self.mean, self.cov, shape=sample_shape) -class Normal(Distribution): +class Normal(NativeDistribution): r""" A (possibly multivaraiate) normal distribution, $\mathcal{N}(\mu, \Sigma)$. @@ -60,10 +60,10 @@ def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None: """ mean = jnp.atleast_1d(mean) cov = jnp.atleast_2d(cov) - super().__init__(_Normal(mean, cov), label=f"({mean.ndim}-dim) Normal") + super().__init__(backend=_Normal(mean, cov), label=f"({mean.ndim}-dim) Normal") -class NormalFamily(DistributionFamily): +class NormalFamily(DistributionFamily[Normal]): r""" Constructor class for (possibly multivariate) normal distributions. @@ -76,7 +76,7 @@ class NormalFamily(DistributionFamily): def __init__(self) -> None: """Create a family of normal distributions.""" - super().__init__(Normal, family_name="Normal") + super().__init__(Normal, label="Normal family") def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003 r""" diff --git a/src/causalprog/utils/__init__.py b/src/causalprog/utils/__init__.py deleted file mode 100644 index c85c873..0000000 --- a/src/causalprog/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Utility classes and methods.""" diff --git a/src/causalprog/utils/translator.py b/src/causalprog/utils/translator.py deleted file mode 100644 index 0e26eba..0000000 --- a/src/causalprog/utils/translator.py +++ /dev/null @@ -1,139 +0,0 @@ -""" -Helper class to keep the codebase backend-agnostic. - -Our frontend (or user-facing) classes each use a syntax that applies across the package -codebase. By contrast, the various backends that we want to support will have different -syntaxes and call signatures for the functions that we want to support. As such, we need -a helper class that can store this "translation" information, allowing the user to -interact with the package in a standard way but also allowing them to choose their own -backend if desired. -""" - -import inspect -from abc import ABC, abstractmethod -from typing import Any - - -class Translator(ABC): - """ - Maps syntax of a backend function to our frontend syntax. - - Different backends have different syntax for drawing samples from the distributions - they support. In order to map these different syntaxes to our backend-agnostic - framework, we need a container class to map the names we have chosen for our - frontend methods to those used by their corresponding backend method. - - A ``Translator`` allows us to identify whether a user-provided backend object is - compatible with one of our frontend wrapper classes (and thus, call signatures). It - also allows users to write their own translators for any custom backends that we do - not explicitly support. - - The use case for a ``Translator`` is as follows. Suppose that we have a frontend - class ``C`` that needs to provide a method ``do_something``. ``C`` stores a - reference to a backend object ``obj`` that can provide the functionality of - ``do_something`` via one of its methods, ``obj.backend_method``. However, there is - no guarantee that the signature of ``do_something`` maps identically to that of - ``obj.backend_method``. A ``Translator`` allows us to encode a mapping of - ``obj.backend_method``s arguments to those of ``do_something``. - """ - - backend_method: str - corresponding_backend_arg: dict[str, str] - - @property - @abstractmethod - def _frontend_method(self) -> str: - """Name of the frontend method that the backend is to be translated into.""" - - @property - @abstractmethod - def compulsory_frontend_args(self) -> set[str]: - """Arguments that are required by the frontend function.""" - - @property - def compulsory_backend_args(self) -> set[str]: - """Arguments that are required to be taken by the backend function.""" - return { - self.corresponding_backend_arg[arg_name] - for arg_name in self.compulsory_frontend_args - } - - def __init__( - self, backend_method: str | None = None, **front_args_to_back_args: str - ) -> None: - """ - Create a new Translator. - - Args: - backend_method (str): Name of the backend method that the instance - translates. - **front_args_to_back_args (str): Mapping of frontend argument names to the - corresponding backend argument names. - - """ - # Assume backend name is identical to frontend name if not provided explicitly - self.backend_method = ( - backend_method if backend_method else self._frontend_method - ) - - # This should really be immutable after we fill defaults! - self.corresponding_backend_arg = dict(front_args_to_back_args) - # Assume compulsory frontend args that are not given translations - # retain their name in the backend. - for arg in self.compulsory_frontend_args: - if arg not in self.corresponding_backend_arg: - self.corresponding_backend_arg[arg] = arg - - def translate_args(self, **kwargs: Any) -> dict[str, Any]: # noqa: ANN401 - """ - Translate frontend arguments (with values) to backend arguments. - - Essentially transforms frontend keyword arguments into their backend keyword - arguments, preserving the value assigned to each argument. - """ - return { - self.corresponding_backend_arg[arg_name]: arg_value - for arg_name, arg_value in kwargs.items() - } - - def validate_compatible(self, obj: object) -> None: - """ - Determine if ``obj`` provides a compatible backend method. - - ``obj`` must provide a callable whose name matches ``self.backend_method``, - and the callable referenced must take arguments matching the names specified in - ``self.compulsory_backend_args``. - - Args: - obj (object): Object to check possesses a method that can be translated into - frontend syntax. - - """ - # Check that obj does provide a method of matching name - if not hasattr(obj, self.backend_method): - msg = f"{obj} has no method '{self.backend_method}'." - raise AttributeError(msg) - if not callable(getattr(obj, self.backend_method)): - msg = f"'{self.backend_method}' attribute of {obj} is not callable." - raise TypeError(msg) - - # Check that this method will be callable with the information given. - method_params = inspect.signature(getattr(obj, self.backend_method)).parameters - # The arguments that will be passed are actually taken by the method. - for compulsory_arg in self.compulsory_backend_args: - if compulsory_arg not in method_params: - msg = ( - f"'{self.backend_method}' does not " - f"take argument '{compulsory_arg}'." - ) - raise TypeError(msg) - # The method does not _require_ any additional arguments - method_requires = { - name for name, p in method_params.items() if p.default is p.empty - } - if not method_requires.issubset(self.compulsory_backend_args): - args_not_accounted_for = method_requires - self.compulsory_backend_args - raise TypeError( - f"'{self.backend_method}' not provided compulsory arguments " - "(missing " + ", ".join(args_not_accounted_for) + ")" - ) diff --git a/tests/test_backend/test_convert_signature.py b/tests/test_backend/test_convert_signature.py index a983269..7964317 100644 --- a/tests/test_backend/test_convert_signature.py +++ b/tests/test_backend/test_convert_signature.py @@ -117,18 +117,18 @@ def test_convert_signature( ), ] ) - new_function = convert_signature( + argument_map = convert_signature( general_function, new_signature, param_name_map, give_static_value ) - if isinstance(expected_assignments, Exception): with pytest.raises( type(expected_assignments), match=re.escape(str(expected_assignments)) ): - new_function(*posix_for_new_call, **keyword_for_new_call) + args, kwargs = argument_map(*posix_for_new_call, **keyword_for_new_call) else: - posix, posix_def, vargs, kwo, kwo_def, kwargs = new_function( - *posix_for_new_call, **keyword_for_new_call + args, kwargs = argument_map(*posix_for_new_call, **keyword_for_new_call) + posix, posix_def, vargs, kwo, kwo_def, kwargs = general_function( + *args, **kwargs ) assert posix == expected_assignments["posix"] assert posix_def == expected_assignments["posix_def"] diff --git a/tests/test_backend/test_translation.py b/tests/test_backend/test_translation.py new file mode 100644 index 0000000..4cb50ab --- /dev/null +++ b/tests/test_backend/test_translation.py @@ -0,0 +1,86 @@ +import re +from typing import Any + +import pytest + +from causalprog.backend.translation import Translation + + +@pytest.mark.parametrize( + ("constructor_kwargs", "expected"), + [ + pytest.param( + { + "backend_name": "backend", + "frontend_name": "frontend", + "param_map": {"0": "0", "1": "1"}, + }, + None, + id="Respect default frozen args.", + ), + pytest.param( + { + "backend_name": "backend", + "frontend_name": "frontend", + "param_map": {}, + "frozen_args": {"0": 0, "1": 3.1415}, + }, + None, + id="frozen_args dict-values can be Any.", + ), + pytest.param( + { + "backend_name": 100, + "frontend_name": "frontend", + "param_map": {"0": "0", "1": "1"}, + }, + TypeError("backend_name '100' is not a string."), + id="Backend name must be string.", + ), + pytest.param( + { + "backend_name": "backend", + "frontend_name": [1, 2, 3], + "param_map": {"0": "0", "1": "1"}, + }, + TypeError("frontend_name '[1, 2, 3]' is not a string."), + id="Frontend name must be string.", + ), + pytest.param( + { + "backend_name": "backend", + "frontend_name": "frontend", + "param_map": {"0": "0", "1": 1}, + }, + TypeError("Parameter map must map str -> str."), + id="Parameter map value is not string.", + ), + pytest.param( + { + "backend_name": "backend", + "frontend_name": "frontend", + "param_map": {0: "0", "1": "1"}, + }, + TypeError("Parameter map must map str -> str."), + id="Parameter map key is not string.", + ), + pytest.param( + { + "backend_name": "backend", + "frontend_name": "frontend", + "param_map": {}, + "frozen_args": {0: 0, "1": "1"}, + }, + TypeError("Frozen args must be specified by name (str)."), + id="frozen_args dict-keys must be str.", + ), + ], +) +def test_translation( + constructor_kwargs: dict[str, Any], expected: None | Exception +) -> None: + if isinstance(expected, Exception): + with pytest.raises(type(expected), match=re.escape(str(expected))): + Translation(**constructor_kwargs) + else: + Translation(**constructor_kwargs) diff --git a/tests/test_backend/test_translator.py b/tests/test_backend/test_translator.py new file mode 100644 index 0000000..304496e --- /dev/null +++ b/tests/test_backend/test_translator.py @@ -0,0 +1,109 @@ +import re +from collections.abc import Sequence + +import pytest + +from causalprog.backend.translation import Translation +from causalprog.backend.translator import Translator + + +class BackendObjNeedsNoTranslation: + def frontend_method(self, denominator: float, numerator: float) -> float: + return numerator / denominator + + +class BackendObjNameChangeOnly: + def backend_method(self, denominator: float, numerator: float) -> float: + return numerator / denominator + + +class BackendObjNeedsTranslation: + def backend_method(self, num: float, denom: float) -> float: + return num / denom + + +class BackendObjDropsArg: + def backend_method(self, num: float, denom: float, constant: float) -> float: + return num / denom + constant + + +class TranslatorForTesting(Translator): + @property + def _frontend_provides(self) -> tuple[str, ...]: + return ("frontend_method",) + + def frontend_method(self, denominator: float, numerator: float) -> float: + return self._call_backend_with("frontend_method", denominator, numerator) + + +@pytest.mark.parametrize( + ("translations", "backend", "created_method_is_identity"), + [ + pytest.param( + (), BackendObjNeedsNoTranslation(), True, id="Backend needs no translation." + ), + pytest.param( + ( + Translation( + backend_name="backend_method", + frontend_name="frontend_method", + param_map={}, + ), + ), + BackendObjNameChangeOnly(), + False, + id="Method name change, map is no longer identity.", + ), + pytest.param( + ( + Translation( + backend_name="backend_method", + frontend_name="frontend_method", + param_map={"num": "numerator", "denom": "denominator"}, + ), + ), + BackendObjNeedsTranslation(), + False, + id="Full translation required.", + ), + pytest.param( + ( + Translation( + backend_name="backend_method", + frontend_name="frontend_method", + param_map={"num": "numerator", "denom": "denominator"}, + frozen_args={"constant": 0.0}, + ), + ), + BackendObjDropsArg(), + False, + id="Drop an argument.", + ), + ], +) +def test_creation_and_methods( + translations: Sequence[Translation], backend, *, created_method_is_identity: bool +) -> None: + translator = TranslatorForTesting(*translations, backend=backend) + + assert callable(translator.translations["frontend_method"]) + assert ( + translator.translations["frontend_method"] is translator.identity + ) == created_method_is_identity + + input_denominator = 2.0 + input_numerator = 1.0 + assert translator.frontend_method( + input_denominator, input_numerator + ) == pytest.approx(input_numerator / input_denominator) + + +def test_must_provide_all_frontend_methods() -> None: + # You cannot get away without defining one of the frontend methods + with pytest.raises( + AttributeError, + match=re.escape( + "No translation provided for frontend_method, which the backend is lacking." + ), + ): + TranslatorForTesting(backend=BackendObjNameChangeOnly()) diff --git a/tests/test_distributions/test_different_backends.py b/tests/test_distributions/test_different_backends.py index e93c054..e8711e7 100644 --- a/tests/test_distributions/test_different_backends.py +++ b/tests/test_distributions/test_different_backends.py @@ -4,7 +4,9 @@ import jax.numpy as jnp from numpyro.distributions.continuous import MultivariateNormal -from causalprog.distribution.base import Distribution, SampleTranslator +from causalprog.backend.translation import Translation +from causalprog.distribution.base import Distribution +from causalprog.distribution.normal import Normal def test_different_backends(rng_key) -> None: @@ -22,11 +24,29 @@ def test_different_backends(rng_key) -> None: sample_size = (10, 5) distrax_normal = distrax.MultivariateNormalFullCovariance(mean, cov) - distrax_dist = Distribution(distrax_normal, SampleTranslator(rng_key="seed")) + distrax_dist = Distribution( + Translation( + backend_name="sample", + frontend_name="sample", + param_map={"seed": "rng_key"}, + ), + backend=distrax_normal, + label="Distrax normal", + ) distrax_samples = distrax_dist.sample(rng_key, sample_size) npyo_normal = MultivariateNormal(mean, cov) - npyo_dist = Distribution(npyo_normal, SampleTranslator(rng_key="key")) + npyo_dist = Distribution( + Translation( + backend_name="sample", frontend_name="sample", param_map={"key": "rng_key"} + ), + backend=npyo_normal, + label="NumPyro normal", + ) npyo_samples = npyo_dist.sample(rng_key, sample_size) + native_normal = Normal(mean, cov) + native_samples = native_normal.sample(rng_key, sample_size) + assert jnp.allclose(distrax_samples, npyo_samples) + assert jnp.allclose(distrax_samples, native_samples) diff --git a/tests/test_distributions/test_family.py b/tests/test_distributions/test_family.py index 72a29e7..f972350 100644 --- a/tests/test_distributions/test_family.py +++ b/tests/test_distributions/test_family.py @@ -1,8 +1,43 @@ -import distrax +import jax.numpy as jnp import pytest +from distrax import MultivariateNormalFullCovariance as Mvn -from causalprog.distribution.base import SampleTranslator +from causalprog.backend.translation import Translation +from causalprog.distribution.base import Distribution from causalprog.distribution.family import DistributionFamily +from causalprog.distribution.normal import Normal, NormalFamily + + +@pytest.mark.parametrize( + ("n_dim_std_normal"), + [pytest.param(1, id="1D normal"), pytest.param(3, id="3D normal")], + indirect=["n_dim_std_normal"], +) +def test_sampling_consistency(rng_key, n_dim_std_normal) -> None: + """""" + sample_shape = (5, 10) + normal_family = NormalFamily() + + via_family = normal_family.construct(**n_dim_std_normal) + via_standard_class = Normal(**n_dim_std_normal) + + family_samples = via_family.sample(rng_key, sample_shape) + standard_class_samples = via_standard_class.sample(rng_key, sample_shape) + + assert jnp.allclose(family_samples, standard_class_samples) + + +class DistraxNormal(Distribution): + def __init__(self, mean, cov): + super().__init__( + Translation( + backend_name="sample", + frontend_name="sample", + param_map={"seed": "rng_key"}, + ), + backend=Mvn(mean, cov), + label="Distrax normal", + ) @pytest.mark.parametrize( @@ -16,13 +51,14 @@ def test_builder_matches_backend(n_dim_std_normal) -> None: to building via the backend explicitly. """ - mnv = distrax.MultivariateNormalFullCovariance - - mnv_family = DistributionFamily(mnv, SampleTranslator(rng_key="seed")) - via_family = mnv_family.construct( + mnv_family = DistributionFamily( + DistraxNormal, + label="Distrax normal family", + ) + via_family = mnv_family.construct(**n_dim_std_normal) + via_backend = Mvn( loc=n_dim_std_normal["mean"], covariance_matrix=n_dim_std_normal["cov"] ) - via_backend = mnv(n_dim_std_normal["mean"], n_dim_std_normal["cov"]) - assert via_backend.kl_divergence(via_family.get_dist()) == pytest.approx(0.0) - assert via_family.get_dist().kl_divergence(via_backend) == pytest.approx(0.0) + assert via_backend.kl_divergence(via_family.dist) == pytest.approx(0.0) + assert via_family.dist.kl_divergence(via_backend) == pytest.approx(0.0) diff --git a/tests/test_translator.py b/tests/test_translator.py deleted file mode 100644 index 7197992..0000000 --- a/tests/test_translator.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Tests for the SampleCompatibility class.""" - -import re - -import pytest - -from causalprog.utils.translator import Translator - - -class _TranslatorForTesting(Translator): - @property - def _frontend_method(self) -> str: - """Name of the frontend method that the backend is to be translated into.""" - return "method" - - @property - def compulsory_frontend_args(self) -> set[str]: - """Arguments that are required by the frontend function.""" - return {"arg1", "arg2"} - - -class DummyClass: - """ - Stub class for testing. - - Intended use is to provide a variety of method call signatures, that can be used to - verify whether the ``SampleCompatibility`` class is correctly able to determine if - it will be able to call an underlying method without error. - """ - - @property - def prop(self) -> None: - """Properties are not callable.""" - return - - def __init__(self) -> None: - """Create an instance.""" - return - - def __str__(self) -> str: - """Display, in case the object appears in an error string.""" - return "DummyClass instance" - - def method(self, arg1: int, arg2: int = 0, kwarg1: int = 0) -> int: - """Take 1 compulsory and 2 optional arguments.""" - return arg1 + arg2 + kwarg1 - - -@pytest.fixture -def dummy_class_instance() -> DummyClass: - """Instance of the ``DummyClass`` to use in testing.""" - return DummyClass() - - -@pytest.mark.parametrize( - ("info", "expected_result"), - [ - pytest.param( - _TranslatorForTesting(backend_method="method_does_not_exist"), - AttributeError( - "DummyClass instance has no method 'method_does_not_exist'." - ), - id="Object does not have the backend method.", - ), - pytest.param( - _TranslatorForTesting(backend_method="prop"), - TypeError("'prop' attribute of DummyClass instance is not callable."), - id="Object backend method is not callable.", - ), - pytest.param( - _TranslatorForTesting(arg1="not_an_arg"), - TypeError("'method' does not take argument 'not_an_arg'"), - id="Backend does not take compulsory argument.", - ), - pytest.param( - _TranslatorForTesting("method", arg1="arg2", arg2="kwarg1"), - TypeError("'method' not provided compulsory arguments (missing arg1)"), - id="Backend cannot have unspecified compulsory arguments.", - ), - pytest.param( - _TranslatorForTesting(), - None, - id="Fall back on defaults.", - ), - pytest.param( - _TranslatorForTesting(backend_method="method", arg2="kwarg1"), - None, - id="Match args out-of-order.", - ), - ], -) -def test_validate_compatible( - info: _TranslatorForTesting, - dummy_class_instance: DummyClass, - expected_result: Exception | None, -) -> None: - """ - Test the validate_compatible method. - - Test that a SampleCompatibility instance correctly determines if a given method - of a given object is callable, with the information stored in the instance. - """ - if expected_result is not None: - with pytest.raises( - type(expected_result), match=re.escape(str(expected_result)) - ): - info.validate_compatible(dummy_class_instance) - else: - info.validate_compatible(dummy_class_instance) - - -@pytest.mark.parametrize( - ("translator", "input_kwargs", "expected_kwargs"), - [ - pytest.param( - _TranslatorForTesting(), - {"arg1": 0, "arg2": 1}, - {"arg1": 0, "arg2": 1}, - id="Args unchanged.", - ), - pytest.param( - _TranslatorForTesting(arg1="arg2", arg2="arg1"), - {"arg1": 0, "arg2": 1}, - {"arg1": 1, "arg2": 0}, - id="Order of args is swapped.", - ), - pytest.param( - _TranslatorForTesting(arg2="very_different_name"), - {"arg1": 0, "arg2": 1}, - {"arg1": 0, "very_different_name": 1}, - id="Backend names replaced where necessary.", - ), - ], -) -def test_translation( - translator: _TranslatorForTesting, - input_kwargs: dict[str, str], - expected_kwargs: dict[str, str], -) -> None: - """Test the mapping of (compatible) frontend args to backend args.""" - computed_output = translator.translate_args(**input_kwargs) - - assert computed_output == expected_kwargs