Skip to content

Finalise backend-agnosticity of Distributions #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 42 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
38c9c18
Pull backend_agnostic class
willGraham01 Mar 20, 2025
61eb331
Pull backend_agnostic tests
willGraham01 Mar 20, 2025
59f6af3
Test positive and negative case
willGraham01 Mar 20, 2025
2e91b2a
Tidy docstrings and methods
willGraham01 Mar 20, 2025
d2f40d6
Pull signature conversions from branch
willGraham01 Mar 20, 2025
9b3728a
Pull tests for converting signatures
willGraham01 Mar 20, 2025
1958d98
Hide hidden variable import behind typehint
willGraham01 Mar 20, 2025
89d08ab
Tidy vargs and kwargs checker function
willGraham01 Mar 20, 2025
aa7d06f
Refactor _check tests
willGraham01 Mar 20, 2025
1fd75f0
Tidy convert_signature docstring
willGraham01 Mar 20, 2025
44a8504
Parameter naming and docstrings for _signature_can_be_cast
willGraham01 Mar 20, 2025
dc283cf
Refactor _signature_can_be_cast tests
willGraham01 Mar 20, 2025
1dddcc7
Use fixtures for convert_signature test
willGraham01 Mar 20, 2025
5b28001
remove outdated comment
willGraham01 Mar 20, 2025
20a6dc6
Merge branch 'wgraham/backend-agnostic-abc' into wgraham/backend-tran…
willGraham01 Mar 20, 2025
6aedfcd
Write identity map
willGraham01 Mar 20, 2025
2e4def5
convert_signature now returns a mapping
willGraham01 Mar 20, 2025
24e7e05
Hidden method for calling the backend
willGraham01 Mar 20, 2025
f2aba0b
Split Translation out just to compartmentalise
willGraham01 Mar 21, 2025
f68ee90
Tests for Translator itself
willGraham01 Mar 21, 2025
6a495a1
Cannot multiple inherit from two classes that both define __slots__
willGraham01 Mar 21, 2025
f7732a7
Rework Distribution to be backend-agnostic and use translators
willGraham01 Mar 21, 2025
a0ec6e1
Fix up distribution families, which now just point to backend-agnosti…
willGraham01 Mar 21, 2025
a93071b
Remove now-defunct translator
willGraham01 Mar 21, 2025
689b798
Fix some outdated docstrings
willGraham01 Mar 21, 2025
5c27b0b
No longer need to test a deleted file
willGraham01 Mar 21, 2025
a5d1988
Pull signature conversions from branch
willGraham01 Mar 20, 2025
9fe5f3d
Pull tests for converting signatures
willGraham01 Mar 20, 2025
5b2dbae
Hide hidden variable import behind typehint
willGraham01 Mar 20, 2025
efc9315
Tidy vargs and kwargs checker function
willGraham01 Mar 20, 2025
51eaa8b
Refactor _check tests
willGraham01 Mar 20, 2025
7a5d1b6
Tidy convert_signature docstring
willGraham01 Mar 20, 2025
5f62aba
Parameter naming and docstrings for _signature_can_be_cast
willGraham01 Mar 20, 2025
e4b99a8
Refactor _signature_can_be_cast tests
willGraham01 Mar 20, 2025
7df9988
Use fixtures for convert_signature test
willGraham01 Mar 20, 2025
775489f
remove outdated comment
willGraham01 Mar 20, 2025
a463a14
Ruff correcting ruff...
willGraham01 Mar 21, 2025
0972b3a
Merge branch 'wgraham/signature-converting' into wgraham/distribution…
willGraham01 Mar 21, 2025
7cd3806
Remove utils import that has now disappeared
willGraham01 Mar 21, 2025
e0d22b9
Apply suggestions from code review
willGraham01 Mar 24, 2025
89ead04
Merge branch 'main' into wgraham/fix-conflicts-with-30
willGraham01 Mar 24, 2025
61748b4
Merge branch 'main' into wgraham/distributions-are-backend-agnostic
willGraham01 Mar 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/causalprog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""causalprog package."""

from . import algorithms, distribution, graph, utils
from . import algorithms, distribution, graph
from ._version import __version__
1 change: 0 additions & 1 deletion src/causalprog/_abc/backend_agnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class BackendAgnostic(ABC, Generic[Backend]):
calls to the ``_backend_obj`` as necessary.
"""

__slots__ = ("_backend_obj",)
_backend_obj: Backend

def __getattr__(self, name: str) -> Any: # noqa: ANN401
Expand Down
1 change: 0 additions & 1 deletion src/causalprog/_abc/labelled.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class Labelled(ABC):
``label`` property of the class.
"""

__slots__ = ("_label",)
_label: str

@property
Expand Down
18 changes: 9 additions & 9 deletions src/causalprog/backend/_convert_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,14 @@ def convert_signature(
new_signature: Signature,
old_to_new_names: ParamNameMap,
give_static_value: StaticValues,
) -> Callable[..., ReturnType]:
) -> Callable[..., tuple[Any, Any]]:
"""
Convert the call signature of a function ``fn`` to that of ``new_signature``.

This function effectively allows ``fn`` to be called with ``new_signature``. It
returns a new ``Callable`` that uses the ``new_signature``, and returns the result
of ``fn`` after translating the ``new_signature`` back into that of ``fn`` and
making an appropriate call.
returns a new ``Callable`` (denoted ``g``) that maps the parameters of
``new_signature`` to the (corresponding) parameters of ``fn``. As such, ``fn``
composed with ``g`` allows for calling ``fn`` with the ``new_signature``.

Converting signatures into each other is, in general, not possible. However under
certain assumptions and conventions, it can be done. To that end, the following
Expand Down Expand Up @@ -236,7 +236,7 @@ def convert_signature(
information provided.

Returns:
Callable: Callable representing ``fn`` with ``new_signature``.
Callable: Callable mapping parameters in ``new_signature`` to those in ``fn``.

See Also:
_check_variable_length_params: Validation of number of variable-length
Expand Down Expand Up @@ -270,7 +270,7 @@ def convert_signature(
static_kwargs = give_static_value.pop(fn_kwargs_param)
give_static_value = dict(give_static_value, **static_kwargs)

def fn_with_new_signature(*args: tuple, **kwargs: dict[str, Any]) -> ReturnType:
def new_sig_to_fn_sig(*args: Any, **kwargs: Any) -> tuple[list, dict[str, Any]]: # noqa: ANN401
bound = new_signature.bind(*args, **kwargs)
bound.apply_defaults()

Expand All @@ -292,7 +292,7 @@ def fn_with_new_signature(*args: tuple, **kwargs: dict[str, Any]) -> ReturnType:
fn_args = [fn_kwargs.pop(p_name) for p_name in fn_posix_args]
if fn_vargs_param:
fn_args.extend(fn_kwargs.pop(fn_vargs_param, []))
# Now we can call fn
return fn(*fn_args, **fn_kwargs)
# Arguments are now mapped
return fn_args, fn_kwargs

return fn_with_new_signature
return new_sig_to_fn_sig
49 changes: 49 additions & 0 deletions src/causalprog/backend/translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Data structure for storing information mapping backend to frontend."""

from dataclasses import dataclass, field

from ._typing import ParamNameMap, StaticValues


@dataclass
class Translation:
"""
Helper class for mapping frontend signatures to backend signatures.

Attributes:
backend_name (str): Name of the backend method that is being translated into
a frontend method.
frontend_name (str): Name of the frontend method that the backend method will
be used as.
param_map (ParamNameMap): See ``old_to_new_names`` argument to
``causalprog.backend._convert_signature``.
frozen_args (StaticValues): See ``give_static_value`` argument to
``causalprog.backend._convert_signature``.

"""

backend_name: str
frontend_name: str
param_map: ParamNameMap
frozen_args: StaticValues = field(default_factory=dict)

def __post_init__(self) -> None:
if not isinstance(self.backend_name, str):
msg = f"backend_name '{self.backend_name}' is not a string."
raise TypeError(msg)
if not isinstance(self.frontend_name, str):
msg = f"frontend_name '{self.frontend_name}' is not a string."
raise TypeError(msg)

self.frozen_args = dict(self.frozen_args)
self.param_map = dict(self.param_map)

if not all(
isinstance(key, str) and isinstance(value, str)
for key, value in self.param_map.items()
):
msg = "Parameter map must map str -> str."
raise TypeError(msg)
if not all(isinstance(key, str) for key in self.frozen_args):
msg = "Frozen args must be specified by name (str)."
raise TypeError(msg)
101 changes: 101 additions & 0 deletions src/causalprog/backend/translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Translating backend object syntax to frontend syntax."""

from collections.abc import Callable
from inspect import signature
from typing import Any

from causalprog._abc.backend_agnostic import Backend, BackendAgnostic

from ._convert_signature import convert_signature
from .translation import Translation


class Translator(BackendAgnostic[Backend]):
"""
Translates the methods of a backend object into frontend syntax.

A ``Translator`` acts as an intermediary between a backend that is supplied by the
user and the frontend syntax that ``causalprog`` relies on. The default backend of
``causalprog`` uses a syntax compatible with ``jax``.

Other backends may not conform to the syntax that ``causalprog`` expects, but
nonetheless may provide the functionality that it requires. A ``Translator`` is able
to make calls to (the relevant methods of) this backend, whilst still conforming to
the frontend syntax of ``causalprog``.

As an example, suppose that we have a frontend class ``C`` that needs to provide a
method ``do_this``. ``causalprog`` expects ``C`` to provide the functionality
of ``do_this`` via one of its methods, ``C.do_this(*c_args, **c_kwargs)``.
Now suppose that a class ``D`` from a different, external package might also
provides the functionality of ``do_this``, but it is done by calling
``D.do_this_different(*d_args, **d_kwargs)``, where there is some mapping
``m: *c_args, **c_kwargs -> *d_args, **d_kwargs``. In such a case, ``causalprog``
needs to use a ``Translator`` ``T``, rather than ``D`` directly, where

``T.do_this(*c_args, **c_kwargs) = D.do_this_different(m(*c_args, **c_kwargs))``.
"""

frontend_to_native_names: dict[str, str]
translations: dict[str, Callable]

@staticmethod
def identity(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any]]: # noqa: ANN401
"""Identity map on positional and keyword arguments."""
return args, kwargs

def __init__(
self,
*translations: Translation,
backend: Backend,
) -> None:
"""
Translate a backend object into a frontend-compatible object.

Args:
backend (Backend): Backend object that must be translated to support
frontend syntax.
*translations (Translation): ``Translation``s that map the methods of
``backend`` to the (signatures of the) methods that the
``_frontend_provides``.

"""
super().__init__(backend=backend)

self.translations = {}
self.frontend_to_native_names = {name: name for name in self._frontend_provides}
for t in translations:
native_name = t.backend_name
translated_name = t.frontend_name
native_method = getattr(self._backend_obj, native_name)
target_method = getattr(self, translated_name)
target_signature = signature(target_method)

self.translations[translated_name] = convert_signature(
native_method, target_signature, t.param_map, t.frozen_args
)
self.frontend_to_native_names[translated_name] = native_name

# Methods without explicit translations are assumed to be the identity map,
# provided they exist on the backend object.
for method in self._frontend_provides:
method_has_translation = method in self.translations
backend_has_method = hasattr(self._backend_obj, method)
if not (method_has_translation or backend_has_method):
msg = (
f"No translation provided for {method}, "
"which the backend is lacking."
)
raise AttributeError(msg)
if not method_has_translation:
# Assume the identity mapping to the backend method, otherwise.
self.translations[method] = self.identity

self.validate()

def _call_backend_with(self, method: str, *args: Any, **kwargs: Any) -> Any: # noqa:ANN401
"""Translate arguments, then call the backend."""
backend_method = getattr(
self._backend_obj, self.frontend_to_native_names[method]
)
backend_args, backend_kwargs = self.translations[method](*args, **kwargs)
return backend_method(*backend_args, **backend_kwargs)
97 changes: 42 additions & 55 deletions src/causalprog/distribution/base.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,45 @@
"""Base class for backend-agnostic distributions."""

from collections.abc import Callable
from typing import Generic, TypeVar
from typing import TypeVar

from numpy.typing import ArrayLike

from causalprog._abc.labelled import Labelled
from causalprog.utils.translator import Translator
from causalprog.backend.translation import Translation
from causalprog.backend.translator import Translator

SupportsRNG = TypeVar("SupportsRNG")
SupportsSampling = TypeVar("SupportsSampling", bound=object)


class SampleTranslator(Translator):
"""
Translate methods for sampling from distributions.

The ``Distribution`` class provides a ``sample`` method, that takes ``rng_key`` and
``sample_shape`` as its arguments. Instances of this class transform the these
arguments to those that a backend distribution expects.
"""

@property
def _frontend_method(self) -> str:
return "sample"

@property
def compulsory_frontend_args(self) -> set[str]:
"""Arguments that are required by the frontend function."""
return {"rng_key", "sample_shape"}


class Distribution(Generic[SupportsSampling], Labelled):
class Distribution(Translator[SupportsSampling], Labelled):
"""A (backend-agnostic) distribution that can be sampled from."""

_dist: SupportsSampling
_backend_translator: SampleTranslator
@property
def _frontend_provides(self) -> tuple[str, ...]:
return ("sample",)

@property
def _sample(self) -> Callable[..., ArrayLike]:
"""Method for drawing samples from the backend object."""
return getattr(self._dist, self._backend_translator.backend_method)
def dist(self) -> SupportsSampling:
"""Return the object representing the distribution."""
return self.get_backend()

def __init__(
self,
backend_distribution: SupportsSampling,
backend_translator: SampleTranslator | None = None,
*,
label: str = "Distribution",
self, *translations: Translation, backend: SupportsSampling, label: str
) -> None:
"""
Create a new Distribution.
Create a new distribution, with a given backend.

Args:
backend_distribution (SupportsSampling): Backend object that supports
drawing random samples.
backend_translator (SampleTranslator): Translator object mapping backend
sampling function to frontend arguments.
*translations (Translation): Information for mapping the methods of the
backend object to the frontend methods provided by this class. See
``causalprog.backend.Translator`` for more details.
backend (SupportsSampling): Backend object that represents the distribution.
label (str): Name or label to attach to the distribution.

"""
super().__init__(label=label)

self._dist = backend_distribution

# Setup sampling calls, and perform one-time check for compatibility
self._backend_translator = (
backend_translator if backend_translator is not None else SampleTranslator()
)
self._backend_translator.validate_compatible(backend_distribution)

def get_dist(self) -> SupportsSampling:
"""Access to the backend distribution."""
return self._dist
Labelled.__init__(self, label=label)
Translator.__init__(self, *translations, backend=backend)

def sample(self, rng_key: SupportsRNG, sample_shape: ArrayLike = ()) -> ArrayLike:
"""
Expand All @@ -85,7 +53,26 @@ def sample(self, rng_key: SupportsRNG, sample_shape: ArrayLike = ()) -> ArrayLik
ArrayLike: Randomly-drawn samples from the distribution.

"""
args_to_backend = self._backend_translator.translate_args(
rng_key=rng_key, sample_shape=sample_shape
)
return self._sample(**args_to_backend)
return self._call_backend_with("sample", rng_key, sample_shape)


class NativeDistribution(Distribution[SupportsSampling]):
"""
A distribution that uses our native backend.

These distributions do not require translations, since the backend objects
they use conform to our frontend syntax by design.
"""

def __init__(self, *, backend: SupportsSampling, label: str) -> None:
"""
Create a new distribution, using a native backend.

Args:
backend (SupportsSampling): Backend object that represents the distribution.
Must be a native backend object; that is a distribution provided by the
``causalprog`` package.
label (str): Name or label to attach to the distribution.

"""
super().__init__(backend=backend, label=label)
Loading