diff --git a/pyro/distributions/diag_normal_mixture.py b/pyro/distributions/diag_normal_mixture.py
index c4f96f0551..52b8725448 100644
--- a/pyro/distributions/diag_normal_mixture.py
+++ b/pyro/distributions/diag_normal_mixture.py
@@ -46,13 +46,15 @@ class MixtureOfDiagNormals(TorchDistribution):
         "coord_scale": constraints.positive,
         "component_logits": constraints.real,
     }
+    support = constraints.real_vector
 
     def __init__(self, locs, coord_scale, component_logits):
         self.batch_mode = locs.dim() > 2
         assert coord_scale.shape == locs.shape
-        assert (
-            self.batch_mode or locs.dim() == 2
-        ), "The locs parameter in MixtureOfDiagNormals should be K x D dimensional (or B x K x D if doing batches)"
+        assert self.batch_mode or locs.dim() == 2, (
+            "The locs parameter in MixtureOfDiagNormals should be K x D dimensional "
+            "(or ... x B x K x D if doing batches)"
+        )
         if not self.batch_mode:
             assert (
                 coord_scale.dim() == 2
@@ -65,10 +67,10 @@ def __init__(self, locs, coord_scale, component_logits):
         else:
             assert (
                 coord_scale.dim() > 2
-            ), "The coord_scale parameter in MixtureOfDiagNormals should be B x K x D dimensional"
+            ), "The coord_scale parameter in MixtureOfDiagNormals should be ... x B x K x D dimensional"
             assert (
                 component_logits.dim() > 1
-            ), "The component_logits parameter in MixtureOfDiagNormals should be B x K dimensional"
+            ), "The component_logits parameter in MixtureOfDiagNormals should be ... x B x K dimensional"
             assert component_logits.size(-1) == locs.size(-2)
             batch_shape = tuple(locs.shape[:-2])
         self.locs = locs
@@ -133,7 +135,7 @@ class _MixDiagNormalSample(Function):
     @staticmethod
     def forward(ctx, locs, scales, component_logits, pis, which, noise_shape):
         dim = scales.size(-1)
-        white = locs.new(noise_shape).normal_()
+        white = locs.new_empty(noise_shape).normal_()
         n_unsqueezes = locs.dim() - which.dim()
         for _ in range(n_unsqueezes):
             which = which.unsqueeze(-1)
diff --git a/pyro/distributions/diag_normal_mixture_shared_cov.py b/pyro/distributions/diag_normal_mixture_shared_cov.py
index c9c9d23adb..5e5b9fa9aa 100644
--- a/pyro/distributions/diag_normal_mixture_shared_cov.py
+++ b/pyro/distributions/diag_normal_mixture_shared_cov.py
@@ -45,28 +45,31 @@ class MixtureOfDiagNormalsSharedCovariance(TorchDistribution):
         "coord_scale": constraints.positive,
         "component_logits": constraints.real,
     }
+    support = constraints.real_vector
 
     def __init__(self, locs, coord_scale, component_logits):
         self.batch_mode = locs.dim() > 2
-        assert (
-            self.batch_mode or locs.dim() == 2
-        ), "The locs parameter in MixtureOfDiagNormals should be K x D dimensional (or ... x B x K x D in batch mode)"
+        assert self.batch_mode or locs.dim() == 2, (
+            "The locs parameter in MixtureOfDiagNormalsSharedCovariance should be K x D dimensional "
+            "(or ... x B x K x D in batch mode)"
+        )
         if not self.batch_mode:
             assert (
                 coord_scale.dim() == 1
-            ), "The coord_scale parameter in MixtureOfDiagNormals should be D dimensional"
+            ), "The coord_scale parameter in MixtureOfDiagNormalsSharedCovariance should be D dimensional"
             assert (
                 component_logits.dim() == 1
-            ), "The component_logits parameter in MixtureOfDiagNormals should be K dimensional"
+            ), "The component_logits parameter in MixtureOfDiagNormalsSharedCovariance should be K dimensional"
             assert component_logits.size(0) == locs.size(0)
             batch_shape = ()
         else:
             assert (
                 coord_scale.dim() > 1
-            ), "The coord_scale parameter in MixtureOfDiagNormals should be ... x B x D dimensional"
-            assert (
-                component_logits.dim() > 1
-            ), "The component_logits parameter in MixtureOfDiagNormals should be ... x B x K dimensional"
+            ), "The coord_scale parameter in MixtureOfDiagNormalsSharedCovariance should be ... x B x D dimensional"
+            assert component_logits.dim() > 1, (
+                "The component_logits parameter in MixtureOfDiagNormalsSharedCovariance should be "
+                "... x B x K dimensional"
+            )
             assert component_logits.size(-1) == locs.size(-2)
             batch_shape = tuple(locs.shape[:-2])
         self.locs = locs
@@ -134,7 +137,7 @@ class _MixDiagNormalSharedCovarianceSample(Function):
     @staticmethod
    def forward(ctx, locs, coord_scale, component_logits, pis, which, noise_shape):
         dim = coord_scale.size(-1)
-        white = torch.randn(noise_shape, dtype=locs.dtype, device=locs.device)
+        white = locs.new_empty(noise_shape).normal_()
         n_unsqueezes = locs.dim() - which.dim()
         for _ in range(n_unsqueezes):
             which = which.unsqueeze(-1)
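
Two changes recur across both files. First, each mixture class gains `support = constraints.real_vector`, declaring that samples are real-valued vectors (event dim 1). Second, the noise used for reparameterized sampling is now allocated with `locs.new_empty(noise_shape).normal_()`, which replaces the deprecated `Tensor.new(...)` call in the first file and is a terser equivalent of the explicit `torch.randn(..., dtype=locs.dtype, device=locs.device)` in the second, since `new_empty` inherits dtype and device from `locs`. A minimal sketch of both points (the shapes below are illustrative, not taken from the patch):

    import torch
    from torch.distributions import constraints

    # The added class attribute: a vector-valued real support.
    assert constraints.real_vector.event_dim == 1

    locs = torch.zeros(4, 2, dtype=torch.float64)
    noise_shape = (5, 4, 2)

    # new_empty allocates an uninitialized tensor with locs' dtype/device;
    # normal_() then fills it in place with standard-normal samples.
    white = locs.new_empty(noise_shape).normal_()
    assert white.shape == torch.Size(noise_shape)
    assert white.dtype == locs.dtype and white.device == locs.device

    # Equivalent to the pattern diag_normal_mixture_shared_cov.py used before:
    white_old = torch.randn(noise_shape, dtype=locs.dtype, device=locs.device)
    assert white_old.dtype == white.dtype and white_old.device == white.device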