This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Add unit distribution for adding to log prob #1051

Closed
wants to merge 1 commit

3 changes: 2 additions & 1 deletion src/beanmachine/ppl/distribution/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 from beanmachine.ppl.distribution.flat import Flat
+from beanmachine.ppl.distribution.unit import Unit


-__all__ = ["Flat"]
+__all__ = ["Flat", "Unit"]
66 changes: 66 additions & 0 deletions src/beanmachine/ppl/distribution/unit.py
@@ -0,0 +1,66 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints


def broadcast_shape(*shapes, **kwargs):
    """
    Similar to ``np.broadcast()`` but for shapes.
    Equivalent to ``np.broadcast(*map(np.empty, shapes)).shape``.
    :param tuple shapes: shapes of tensors.
    :param bool strict: whether to use extend-but-not-resize broadcasting.
    :returns: broadcasted shape
    :rtype: tuple
    :raises: ValueError
    """
    strict = kwargs.pop("strict", False)
    reversed_shape = []
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            if i >= len(reversed_shape):
                reversed_shape.append(size)
            elif reversed_shape[i] == 1 and not strict:
                reversed_shape[i] = size
            elif reversed_shape[i] != size and (size != 1 or strict):
                raise ValueError(
                    "shape mismatch: objects cannot be broadcast to a single shape: {}".format(
                        " vs ".join(map(str, shapes))
                    )
                )
    return tuple(reversed(reversed_shape))


class Unit(torch.distributions.Distribution):
    """
    Trivial nonnormalized distribution representing the unit type.

    The unit type has a single value with no data, i.e. ``value.numel() == 0``.

    In Pyro this backs :func:`pyro.factor` statements; here it serves the same
    role of adding an arbitrary term to the joint log probability.
    """

    arg_constraints = {"log_factor": constraints.real}
    support = constraints.real

    def __init__(self, log_factor, validate_args=None):
        log_factor = torch.as_tensor(log_factor)
        batch_shape = log_factor.shape
        event_shape = torch.Size((0,))  # This satisfies .numel() == 0.
        self.log_factor = log_factor
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Unit, _instance)
        new.log_factor = self.log_factor.expand(batch_shape)
        super(Unit, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):  # noqa: B008
        # The returned value carries no data: the trailing event dimension has
        # size 0, so ``value.numel() == 0`` as promised in the class docstring.
        return self.log_factor.new_empty(torch.Size(sample_shape) + self.event_shape)

    def log_prob(self, value):
        # The value itself is ignored; the log probability is just
        # ``log_factor``, broadcast against the value's batch shape.
        shape = broadcast_shape(self.batch_shape, value.shape[:-1])
        return self.log_factor.expand(shape)
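
For readers unfamiliar with the helper, a small illustration of broadcast_shape (not part of this diff; the input shapes are made up): it merges shapes using NumPy-style broadcasting, letting size-1 dimensions stretch and raising ValueError on a genuine mismatch.

from beanmachine.ppl.distribution.unit import broadcast_shape

print(broadcast_shape((3, 1), (4,)))  # (3, 4): the size-1 dimension stretches to 4
print(broadcast_shape((2, 3), ()))    # (2, 3): an empty shape broadcasts trivially
# broadcast_shape((2,), (3,))         # raises ValueError: shape mismatch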
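
And a sketch of how the new Unit distribution behaves once it is exported from beanmachine.ppl.distribution (illustrative only; the tensor values are made up): samples are empty placeholders, and log_prob ignores the value and returns the stored log_factor, which is what lets a model fold an arbitrary term into its joint log probability.

import torch

from beanmachine.ppl.distribution import Unit

# An arbitrary term to be added to the joint log probability.
d = Unit(log_factor=torch.tensor([-0.5, -1.5]))

value = d.sample()        # an empty placeholder value carrying no data
print(d.log_prob(value))  # tensor([-0.5000, -1.5000]) -- just the stored factor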