diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9ac9eff143..600898979e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -61,6 +61,7 @@ jobs: tests/distributions/test_shape_utils.py tests/distributions/test_mixture.py tests/test_testing.py + tests/dispatch/test_jax.py - | tests/distributions/test_continuous.py diff --git a/pymc/dispatch/dispatch_jax.py b/pymc/dispatch/dispatch_jax.py new file mode 100644 index 0000000000..694c297363 --- /dev/null +++ b/pymc/dispatch/dispatch_jax.py @@ -0,0 +1,44 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import jax +import jax.numpy as jnp + +from pytensor.link.jax.dispatch import jax_funcify + +from pymc.distributions.continuous import TruncatedNormalRV + + +@jax_funcify.register(TruncatedNormalRV) +def jax_funcify_TruncatedNormalRV(op, **kwargs): + def trunc_normal_fn(key, size, mu, sigma, lower, upper): + rng_key = key["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + key["jax_state"] = rng_key + + if lower is None: + lower = -jnp.inf + if upper is None: + upper = jnp.inf + else: + new_lower, new_upper = (lower - mu) / sigma, (upper - mu) / sigma + + if size is None: + size = jnp.broadcast_arrays(jnp.array(mu), jnp.array(sigma))[0].shape + + res = jax.random.truncated_normal(key["jax_state"], new_lower, new_upper, shape=size) + res = res * sigma + mu + + return key, res + + return trunc_normal_fn diff --git a/tests/dispatch/test_jax.py b/tests/dispatch/test_jax.py new file mode 100644 index 0000000000..0ff536eec7 --- /dev/null +++ b/tests/dispatch/test_jax.py @@ -0,0 +1,56 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from pytensor import function + +import pymc as pm + +from pymc.dispatch import dispatch_jax # noqa: F401 + +jax = pytest.importorskip("jax") + + +@pytest.mark.parametrize("sigma", [0.02, 5]) +def test_jax_TruncatedNormal(sigma): + with pm.Model() as m: + lower = 5 + upper = 8 + mu = 6 + + a = pm.TruncatedNormal( + "a", mu, sigma, lower=lower, upper=upper, rng=np.random.default_rng(seed=123) + ) + + f_jax = function( + [], + [ + pm.TruncatedNormal( + "b", + mu, + sigma, + lower=lower, + upper=upper, + rng=np.random.default_rng(seed=123), + ) + ], + mode="JAX", + ) + res = f_jax() + + draws = pm.draw(a, draws=100, mode="JAX") + + assert jax.numpy.all((draws >= lower) & (draws <= upper)) + assert jax.numpy.all((res[0] >= lower) & (res[0] <= upper))