Skip to content

Commit

Permalink
Address the issues in #3274 (#3277)
Browse files Browse the repository at this point in the history
* Added `support = constraints.real_vector` to `MixtureOfDiagNormals`
  and `MixtureOfDiagNormalsSharedCovariance`
* Fixed the white noise sampling bug in the `forward` method of
  `_MixDiagNormalSample`
* Same call in `_MixDiagNormalSample` and
  `_MixDiagNormalSharedCovarianceSample` to generate white noise
* Harmonized tensor shape error messages between `MixtureOfDiagNormals`
  and `MixtureOfDiagNormalsSharedCovariance`
* Added the correct class name in tensor shape errors for
  `MixtureOfDiagNormalsSharedCovariance`
  • Loading branch information
cyianor authored Oct 5, 2023
1 parent c00bcc3 commit 41ac46f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
14 changes: 8 additions & 6 deletions pyro/distributions/diag_normal_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ class MixtureOfDiagNormals(TorchDistribution):
"coord_scale": constraints.positive,
"component_logits": constraints.real,
}
support = constraints.real_vector

def __init__(self, locs, coord_scale, component_logits):
self.batch_mode = locs.dim() > 2
assert coord_scale.shape == locs.shape
assert (
self.batch_mode or locs.dim() == 2
), "The locs parameter in MixtureOfDiagNormals should be K x D dimensional (or B x K x D if doing batches)"
assert self.batch_mode or locs.dim() == 2, (
"The locs parameter in MixtureOfDiagNormals should be K x D dimensional "
"(or ... x B x K x D if doing batches)"
)
if not self.batch_mode:
assert (
coord_scale.dim() == 2
Expand All @@ -65,10 +67,10 @@ def __init__(self, locs, coord_scale, component_logits):
else:
assert (
coord_scale.dim() > 2
), "The coord_scale parameter in MixtureOfDiagNormals should be B x K x D dimensional"
), "The coord_scale parameter in MixtureOfDiagNormals should be ... x B x K x D dimensional"
assert (
component_logits.dim() > 1
), "The component_logits parameter in MixtureOfDiagNormals should be B x K dimensional"
), "The component_logits parameter in MixtureOfDiagNormals should be ... x B x K dimensional"
assert component_logits.size(-1) == locs.size(-2)
batch_shape = tuple(locs.shape[:-2])

Expand Down Expand Up @@ -133,7 +135,7 @@ class _MixDiagNormalSample(Function):
@staticmethod
def forward(ctx, locs, scales, component_logits, pis, which, noise_shape):
dim = scales.size(-1)
white = locs.new(noise_shape).normal_()
white = locs.new_empty(noise_shape).normal_()
n_unsqueezes = locs.dim() - which.dim()
for _ in range(n_unsqueezes):
which = which.unsqueeze(-1)
Expand Down
23 changes: 13 additions & 10 deletions pyro/distributions/diag_normal_mixture_shared_cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,31 @@ class MixtureOfDiagNormalsSharedCovariance(TorchDistribution):
"coord_scale": constraints.positive,
"component_logits": constraints.real,
}
support = constraints.real_vector

def __init__(self, locs, coord_scale, component_logits):
self.batch_mode = locs.dim() > 2
assert (
self.batch_mode or locs.dim() == 2
), "The locs parameter in MixtureOfDiagNormals should be K x D dimensional (or ... x B x K x D in batch mode)"
assert self.batch_mode or locs.dim() == 2, (
"The locs parameter in MixtureOfDiagNormalsSharedCovariance should be K x D dimensional "
"(or ... x B x K x D in batch mode)"
)
if not self.batch_mode:
assert (
coord_scale.dim() == 1
), "The coord_scale parameter in MixtureOfDiagNormals should be D dimensional"
), "The coord_scale parameter in MixtureOfDiagNormalsSharedCovariance should be D dimensional"
assert (
component_logits.dim() == 1
), "The component_logits parameter in MixtureOfDiagNormals should be K dimensional"
), "The component_logits parameter in MixtureOfDiagNormalsSharedCovariance should be K dimensional"
assert component_logits.size(0) == locs.size(0)
batch_shape = ()
else:
assert (
coord_scale.dim() > 1
), "The coord_scale parameter in MixtureOfDiagNormals should be ... x B x D dimensional"
assert (
component_logits.dim() > 1
), "The component_logits parameter in MixtureOfDiagNormals should be ... x B x K dimensional"
), "The coord_scale parameter in MixtureOfDiagNormalsSharedCovariance should be ... x B x D dimensional"
assert component_logits.dim() > 1, (
"The component_logits parameter in MixtureOfDiagNormalsSharedCovariance should be "
"... x B x K dimensional"
)
assert component_logits.size(-1) == locs.size(-2)
batch_shape = tuple(locs.shape[:-2])
self.locs = locs
Expand Down Expand Up @@ -134,7 +137,7 @@ class _MixDiagNormalSharedCovarianceSample(Function):
@staticmethod
def forward(ctx, locs, coord_scale, component_logits, pis, which, noise_shape):
dim = coord_scale.size(-1)
white = torch.randn(noise_shape, dtype=locs.dtype, device=locs.device)
white = locs.new_empty(noise_shape).normal_()
n_unsqueezes = locs.dim() - which.dim()
for _ in range(n_unsqueezes):
which = which.unsqueeze(-1)
Expand Down

0 comments on commit 41ac46f

Please sign in to comment.