diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index a413224a9..07d108539 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -81,12 +81,16 @@ from numpyro.distributions.transforms import biject_to from numpyro.distributions.truncated import ( LeftTruncatedDistribution, + LeftTruncatedGamma, RightTruncatedDistribution, + RightTruncatedGamma, TruncatedCauchy, TruncatedDistribution, + TruncatedGamma, TruncatedNormal, TruncatedPolyaGamma, TwoSidedTruncatedDistribution, + TwoSidedTruncatedGamma, ) from . import constraints, transforms @@ -141,6 +145,7 @@ "MultinomialLogits", "MultinomialProbs", "MultivariateNormal", + "LeftTruncatedGamma", "LowRankMultivariateNormal", "Normal", "NegativeBinomialProbs", @@ -152,6 +157,7 @@ "ProjectedNormal", "PRNGIdentity", "RightTruncatedDistribution", + "RightTruncatedGamma", "SineBivariateVonMises", "SineSkewed", "SoftLaplace", @@ -159,9 +165,11 @@ "TransformedDistribution", "TruncatedCauchy", "TruncatedDistribution", + "TruncatedGamma", "TruncatedNormal", "TruncatedPolyaGamma", "TwoSidedTruncatedDistribution", + "TwoSidedTruncatedGamma", "Uniform", "Unit", "VonMises", diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 75b727872..1194c9eb9 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -31,7 +31,16 @@ import jax.numpy as jnp import jax.random as random from jax.scipy.linalg import cho_solve, solve_triangular -from jax.scipy.special import betainc, expit, gammaln, logit, multigammaln, ndtr, ndtri +from jax.scipy.special import ( + betainc, + expit, + gammainc, + gammaln, + logit, + multigammaln, + ndtr, + ndtri, +) from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution, TransformedDistribution @@ -282,6 +291,15 @@ def mean(self): def variance(self): return self.concentration / jnp.power(self.rate, 2) + def cdf(self, value): + return gammainc(self.concentration, value * self.rate) + + def icdf(self, q): + # https://github.com/pyro-ppl/numpyro/issues/969 + from numpyro.distributions.util import gammaincinv + + return gammaincinv(self.concentration, q) / self.rate + class Chi2(Gamma): arg_constraints = {"df": constraints.positive} diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index 26333adc9..e0331f073 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -4,12 +4,13 @@ from jax import lax import jax.numpy as jnp import jax.random as random -from jax.scipy.special import logsumexp +from jax.scipy.special import gammainc, logsumexp from jax.tree_util import tree_map from numpyro.distributions import constraints from numpyro.distributions.continuous import ( Cauchy, + Gamma, Laplace, Logistic, Normal, @@ -411,3 +412,327 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, params): return cls(batch_shape=aux_data) + + +def TruncatedGamma(base_gamma, low=None, high=None, validate_args=None): + """ + A function to generate a truncated gamma distribution. + + :param base_gamma: The base Gamma distribution to be truncated. + :param low: the value which is used to truncate the base distribution from below. + Setting this parameter to None to not truncate from below. + :param high: the value which is used to truncate the base distribution from above. + Setting this parameter to None to not truncate from above. + """ + if high is None: + if low is None: + return base_gamma + else: + return LeftTruncatedGamma(base_gamma, low=low, validate_args=validate_args) + elif low is None: + return RightTruncatedGamma(base_gamma, high=high, validate_args=validate_args) + else: + return TwoSidedTruncatedGamma( + base_gamma, low=low, high=high, validate_args=validate_args + ) + + +class LeftTruncatedGamma(Distribution): + arg_constraints = {"low": constraints.positive} + reparametrized_params = ["low"] + + def __init__(self, base_gamma, low, validate_args=None): + assert isinstance(base_gamma, Gamma) + batch_shape = lax.broadcast_shapes(base_gamma.batch_shape, jnp.shape(low)) + self.base_gamma = tree_map( + lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma + ) + (self.low,) = promote_shapes(low, shape=batch_shape) + self._support = constraints.greater_than(low) + super().__init__(batch_shape, validate_args=validate_args) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return self._support + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + u = random.uniform(key, sample_shape + self.batch_shape) + return self.icdf(u) + + @validate_sample + def log_prob(self, value): + lprob = self.base_gamma.log_prob(value) + lscale = self.base_gamma.cdf(self.low) + return lprob - jnp.log(1.0 - lscale) + + def _scale_moment(self, t): + assert t > -self.base_gamma.concentration + s_lscale = gammainc( + self.base_gamma.concentration + t, self.low * self.base_gamma.rate + ) + lscale = self.base_gamma.cdf(self.low) + return (1.0 - s_lscale) / (1.0 - lscale) + + @property + def mean(self): + base_mean = self.base_gamma.mean + rescale = self._scale_moment(1.0) + return rescale * base_mean + + @property + def variance(self): + # compute E[X]^2 + fst_m_sq = jnp.power(self.mean, 2.0) + + # compute E[X^2] + base_sec_mt = ( + (self.base_gamma.concentration + 1) + * self.base_gamma.concentration + * jnp.power(self.base_gamma.rate, -2.0) + ) + rescale = self._scale_moment(2.0) + sec_mt = base_sec_mt * rescale + + # V[X] = E[X^2] - E[X]^2 + return sec_mt - fst_m_sq + + def tree_flatten(self): + base_flatten, base_aux = self.base_gamma.tree_flatten() + if isinstance(self._support.lower_bound, (int, float)): + return base_flatten, ( + type(self.base_gamma), + base_aux, + self._support.lower_bound, + ) + else: + return (base_flatten, self.low), (type(self.base_gamma), base_aux) + + @classmethod + def tree_unflatten(cls, aux_data, params): + if len(aux_data) == 2: + base_flatten, low = params + base_cls, base_aux = aux_data + else: + base_flatten = params + base_cls, base_aux, low = aux_data + base_gamma = Gamma.tree_unflatten(base_aux, base_flatten) + return cls(base_gamma, low=low) + + @validate_sample + def cdf(self, value): + gcdf = self.base_gamma.cdf(value) + lscale = self.base_gamma.cdf(self.low) + return (gcdf - lscale) / (1.0 - lscale) + + def icdf(self, q): + lscale = self.base_gamma.cdf(self.low) + q = q * (1.0 - lscale) + lscale + return self.base_gamma.icdf(q) + + +class RightTruncatedGamma(Distribution): + arg_constraints = {"high": constraints.positive} + reparametrized_params = ["high"] + + def __init__(self, base_gamma, high, validate_args=None): + assert isinstance(base_gamma, Gamma) + batch_shape = lax.broadcast_shapes(base_gamma.batch_shape, jnp.shape(high)) + self.base_gamma = tree_map( + lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma + ) + (self.high,) = promote_shapes(high, shape=batch_shape) + self._support = constraints.interval(0.0, high) + super().__init__(batch_shape, validate_args=validate_args) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return self._support + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + u = random.uniform(key, sample_shape + self.batch_shape) + return self.icdf(u) + + @validate_sample + def log_prob(self, value): + lprob = self.base_gamma.log_prob(value) + hscale = self.base_gamma.cdf(self.high) + return lprob - jnp.log(hscale) + + def _scale_moment(self, t): + assert t > -self.base_gamma.concentration + s_hscale = gammainc( + self.base_gamma.concentration + t, self.high * self.base_gamma.rate + ) + hscale = self.base_gamma.cdf(self.high) + return s_hscale / hscale + + @property + def mean(self): + base_mean = self.base_gamma.mean + rescale = self._scale_moment(1.0) + return rescale * base_mean + + @property + def variance(self): + # compute E[X]^2 + fst_m_sq = jnp.power(self.mean, 2.0) + + # compute E[X^2] + base_sec_mt = ( + (self.base_gamma.concentration + 1) + * self.base_gamma.concentration + * jnp.power(self.base_gamma.rate, -2.0) + ) + rescale = self._scale_moment(2.0) + sec_mt = base_sec_mt * rescale + + # V[X] = E[X^2] - E[X]^2 + return sec_mt - fst_m_sq + + def tree_flatten(self): + base_flatten, base_aux = self.base_gamma.tree_flatten() + if isinstance(self._support.upper_bound, (int, float)): + return base_flatten, ( + type(self.base_gamma), + base_aux, + self._support.upper_bound, + ) + else: + return (base_flatten, self.high), (type(self.base_gamma), base_aux) + + @classmethod + def tree_unflatten(cls, aux_data, params): + if len(aux_data) == 2: + base_flatten, high = params + base_cls, base_aux = aux_data + else: + base_flatten = params + base_cls, base_aux, high = aux_data + base_gamma = Gamma.tree_unflatten(base_aux, base_flatten) + return cls(base_gamma, high=high) + + @validate_sample + def cdf(self, value): + gcdf = self.base_gamma.cdf(value) + hscale = self.base_gamma.cdf(self.high) + return gcdf / hscale + + def icdf(self, q): + hscale = self.base_gamma.cdf(self.high) + q = q * hscale + return self.base_gamma.icdf(q) + + +class TwoSidedTruncatedGamma(Distribution): + arg_constraints = { + "low": constraints.positive, + "high": constraints.dependent, + } + reparametrized_params = ["low", "high"] + + def __init__(self, base_gamma, low, high, validate_args=None): + assert isinstance(base_gamma, Gamma) + batch_shape = lax.broadcast_shapes( + base_gamma.batch_shape, jnp.shape(low), jnp.shape(high) + ) + self.base_gamma = tree_map( + lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma + ) + (self.low,) = promote_shapes(low, shape=batch_shape) + (self.high,) = promote_shapes(high, shape=batch_shape) + self._support = constraints.interval(low, high) + super().__init__(batch_shape, validate_args=validate_args) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return self._support + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + u = random.uniform(key, sample_shape + self.batch_shape) + return self.icdf(u) + + @validate_sample + def log_prob(self, value): + lprob = self.base_gamma.log_prob(value) + lscale = self.base_gamma.cdf(self.low) + hscale = self.base_gamma.cdf(self.high) + return lprob - jnp.log(hscale - lscale) + + def _scale_moment(self, t): + assert t > -self.base_gamma.concentration + s_lscale = gammainc( + self.base_gamma.concentration + t, self.low * self.base_gamma.rate + ) + s_hscale = gammainc( + self.base_gamma.concentration + t, self.high * self.base_gamma.rate + ) + lscale = self.base_gamma.cdf(self.low) + hscale = self.base_gamma.cdf(self.high) + return (s_hscale - s_lscale) / (hscale - lscale) + + @property + def mean(self): + base_mean = self.base_gamma.mean + rescale = self._scale_moment(1.0) + return rescale * base_mean + + @property + def variance(self): + # compute E[X]^2 + fst_m_sq = jnp.power(self.mean, 2.0) + + # compute E[X^2] + base_sec_mt = ( + (self.base_gamma.concentration + 1) + * self.base_gamma.concentration + * jnp.power(self.base_gamma.rate, -2.0) + ) + rescale = self._scale_moment(2.0) + sec_mt = base_sec_mt * rescale + + # V[X] = E[X^2] - E[X]^2 + return sec_mt - fst_m_sq + + def tree_flatten(self): + base_flatten, base_aux = self.base_gamma.tree_flatten() + if isinstance(self._support.lower_bound, (int, float)) and isinstance( + self._support.upper_bound, (int, float) + ): + return base_flatten, ( + type(self.base_gamma), + base_aux, + self._support.lower_bound, + self._support.upper_bound, + ) + else: + return (base_flatten, self.low, self.high), ( + type(self.base_gamma), + base_aux, + ) + + @classmethod + def tree_unflatten(cls, aux_data, params): + if len(aux_data) == 2: + base_flatten, low, high = params + base_cls, base_aux = aux_data + else: + base_flatten = params + base_cls, base_aux, low, high = aux_data + base_gamma = Gamma.tree_unflatten(base_aux, base_flatten) + return cls(base_gamma, low=low, high=high) + + @validate_sample + def cdf(self, value): + gcdf = self.base_gamma.cdf(value) + lscale = self.base_gamma.cdf(self.low) + hscale = self.base_gamma.cdf(self.high) + return (gcdf - lscale) / (hscale - lscale) + + def icdf(self, q): + lscale = self.base_gamma.cdf(self.low) + hscale = self.base_gamma.cdf(self.high) + q = q * (hscale - lscale) + lscale + return self.base_gamma.icdf(q) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index c24c6bd96..57ca9757d 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -560,6 +560,20 @@ def is_prng_key(key): return False +def gammaincinv(a, p): + # until jax/lax has direct implementation we'll need to rely on tfp + # https://github.com/pyro-ppl/numpyro/issues/969 + try: + import tensorflow_probability as tfpm + except ImportError as e: + raise ImportError( + "To use gammaincinv, please install TensorFlow Probability. It can be" + " installed with `pip install tensorflow_probability`" + ) from e + + return tfpm.substrates.jax.math.igammainv(a, p) + + # The is sourced from: torch.distributions.util.py # # Copyright (c) 2016- Facebook, Inc (Adam Paszke) diff --git a/test/test_distributions.py b/test/test_distributions.py index 706d8420f..879aea644 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -272,6 +272,7 @@ def get_sp_dist(jax_dist): T(dist.Laplace, 0.0, 1.0), T(dist.Laplace, 0.5, jnp.array([1.0, 2.5])), T(dist.Laplace, jnp.array([1.0, -0.5]), jnp.array([2.3, 3.0])), + T(dist.LeftTruncatedGamma, dist.Gamma(2.0, 2.0), 1.0), T(dist.LKJ, 2, 0.5, "onion"), T(dist.LKJ, 5, jnp.array([0.5, 1.0, 2.0]), "cvine"), T(dist.LKJCholesky, 2, 0.5, "onion"), @@ -346,6 +347,7 @@ def get_sp_dist(jax_dist): T(dist.Pareto, 1.0, 2.0), T(dist.Pareto, jnp.array([1.0, 0.5]), jnp.array([0.3, 2.0])), T(dist.Pareto, jnp.array([[1.0], [3.0]]), jnp.array([1.0, 0.5])), + T(dist.RightTruncatedGamma, dist.Gamma(2.0, 2.0), 10.0), T(dist.SoftLaplace, 1.0, 1.0), T(dist.SoftLaplace, jnp.array([-1.0, 50.0]), jnp.array([4.0, 100.0])), T(dist.StudentT, 1.0, 1.0, 0.5), @@ -379,6 +381,7 @@ def get_sp_dist(jax_dist): jnp.array([-2.0, 2.0]), ), T(dist.TwoSidedTruncatedDistribution, dist.Laplace(0.0, 1.0), -2.0, 3.0), + T(dist.TwoSidedTruncatedGamma, dist.Gamma(2.0, 2.0), 0.5, 10.0), T(dist.Uniform, 0.0, 2.0), T(dist.Uniform, 1.0, jnp.array([2.0, 3.0])), T(dist.Uniform, jnp.array([0.0, 0.0]), jnp.array([[2.0], [3.0]])), @@ -903,6 +906,42 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): ) assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5) return + elif isinstance( + jax_dist, + ( + dist.LeftTruncatedGamma, + dist.RightTruncatedGamma, + dist.TwoSidedTruncatedGamma, + ), + ): + # params = [base_gamma[concentration, rate], low, high] + if isinstance(jax_dist, dist.LeftTruncatedGamma): + conc, rate, low = ( + params[0].concentration, + params[0].rate, + params[1], + ) + high = np.inf + elif isinstance(jax_dist, dist.RightTruncatedGamma): + conc, rate, high = ( + params[0].concentration, + params[0].rate, + params[1], + ) + low = -np.inf + else: + conc, rate, low, high = ( + params[0].concentration, + params[0].rate, + params[1], + params[2], + ) + sp_dist = get_sp_dist(dist.Gamma)(conc, rate) + expected = sp_dist.logpdf(samples) - jnp.log( + sp_dist.cdf(high) - sp_dist.cdf(low) + ) + assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5) + return pytest.skip("no corresponding scipy distn.") if _is_batched_multivariate(jax_dist): pytest.skip("batching not allowed in multivariate distns.") @@ -1358,6 +1397,8 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): and dist_args[i] == "base_dist" ): continue + if jax_dist is dist.TwoSidedTruncatedGamma and dist_args[i] == "base_gamma": + continue if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps": continue if (