From d3f2509e99cebf8b7e8cf256f7250a0ff12efbc1 Mon Sep 17 00:00:00 2001 From: JP Chen Date: Tue, 12 Oct 2021 12:05:35 -0700 Subject: [PATCH] Add unit distribution for adding to log prob Summary: Addresses [#1041](https://github.com/facebookresearch/beanmachine/issues/1041). Imported the Unit Dist (with some minor modifications) from pyro. This will allow users to add terms to the model density: ``` @bm.random_variable def increment_log_prob(): val = Normal(0., 1.).log_prob(1.) return Unit(val) ``` In the future we can wrap this with a `factor` statement. Reviewed By: neerajprad Differential Revision: D31516303 fbshipit-source-id: 171e08797cd0367dd985f13b149c9c766c6fddef --- src/beanmachine/ppl/distribution/__init__.py | 3 +- src/beanmachine/ppl/distribution/unit.py | 66 ++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 src/beanmachine/ppl/distribution/unit.py diff --git a/src/beanmachine/ppl/distribution/__init__.py b/src/beanmachine/ppl/distribution/__init__.py index 1570def183..567085ae2e 100644 --- a/src/beanmachine/ppl/distribution/__init__.py +++ b/src/beanmachine/ppl/distribution/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. from beanmachine.ppl.distribution.flat import Flat +from beanmachine.ppl.distribution.unit import Unit -__all__ = ["Flat"] +__all__ = ["Flat", "Unit"] diff --git a/src/beanmachine/ppl/distribution/unit.py b/src/beanmachine/ppl/distribution/unit.py new file mode 100644 index 0000000000..449ad25f70 --- /dev/null +++ b/src/beanmachine/ppl/distribution/unit.py @@ -0,0 +1,66 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch.distributions import constraints + + +def broadcast_shape(*shapes, **kwargs): + """ + Similar to ``np.broadcast()`` but for shapes. + Equivalent to ``np.broadcast(*map(np.empty, shapes)).shape``. + :param tuple shapes: shapes of tensors. 
+ :param bool strict: whether to use extend-but-not-resize broadcasting. + :returns: broadcasted shape + :rtype: tuple + :raises: ValueError + """ + strict = kwargs.pop("strict", False) + reversed_shape = [] + for shape in shapes: + for i, size in enumerate(reversed(shape)): + if i >= len(reversed_shape): + reversed_shape.append(size) + elif reversed_shape[i] == 1 and not strict: + reversed_shape[i] = size + elif reversed_shape[i] != size and (size != 1 or strict): + raise ValueError( + "shape mismatch: objects cannot be broadcast to a single shape: {}".format( + " vs ".join(map(str, shapes)) + ) + ) + return tuple(reversed(reversed_shape)) + + +class Unit(torch.distributions.Distribution): + """ + Trivial nonnormalized distribution representing the unit type. + + The unit type has a single value with no data, i.e. ``value.numel() == 0``. + + This is used for :func:`pyro.factor` statements. + """ + + arg_constraints = {"log_factor": constraints.real} + support = constraints.real + + def __init__(self, log_factor, validate_args=None): + log_factor = torch.as_tensor(log_factor) + batch_shape = log_factor.shape + event_shape = torch.Size((0,)) # This satisfies .numel() == 0. + self.log_factor = log_factor + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Unit, _instance) + new.log_factor = self.log_factor.expand(batch_shape) + super(Unit, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def sample(self, sample_shape=torch.Size()): # noqa: B008 + return self.log_factor.new_empty(sample_shape) + + def log_prob(self, value): + shape = broadcast_shape(self.batch_shape, value.shape[:-1]) + return self.log_factor.expand(shape)