Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Fix the use of tensor as arguments of random variables in unit tests …
Browse files Browse the repository at this point in the history
…and tutorials (#1800)

Summary: Pull Request resolved: #1800

Differential Revision: D40856513

fbshipit-source-id: 38f8a60f4dd701cd091208d127b88a54451c8669
  • Loading branch information
horizon-blue authored and facebook-github-bot committed Nov 2, 2022
1 parent f8f2d04 commit 9ef9d44
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 285 deletions.
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ filterwarnings = [
"default:distutils Version classes are deprecated.*:DeprecationWarning",
# statsmodels imports a module that's deprecated since pandas 1.14.0
"default:pandas.Int64Index is deprecated *:FutureWarning",
# functorch 0.1.0 imports deprecated _stateless module
"default:The `torch.nn.utils._stateless` code is deprecated*:DeprecationWarning",
# BM warns against using torch tensors as arguments of random variables
"default:PyTorch tensors are hashed by memory address instead of value.*:UserWarning",
# Arviz warns against the use of deprecated methods, due to the recent release of matplotlib v3.6.0
"default:The register_cmap function will be deprecated in a future version.*:PendingDeprecationWarning",
# gpytorch < 1.9.0 uses torch.triangular_solve
Expand Down
3 changes: 2 additions & 1 deletion src/beanmachine/ppl/model/rv_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def __post_init__(self):
warnings.warn(
"PyTorch tensors are hashed by memory address instead of value. "
"Therefore, it is not recommended to use tensors as indices of random variables.",
stacklevel=3,
# display the warning on where the RVIdentifier is created
stacklevel=5,
)

def __str__(self):
Expand Down
12 changes: 12 additions & 0 deletions tests/ppl/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pytest

# Ignore all warnings in this module against using tensor as arguments of random
# variables
pytestmark = pytest.mark.filterwarnings(
"ignore:PyTorch tensors are hashed by memory address*:UserWarning"
)
2 changes: 1 addition & 1 deletion tests/ppl/compiler/gaussian_mixture_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def category(item):

@bm.random_variable
def mixed(item):
return Normal(mean(category(item)), 2)
return Normal(mean(category(item).item()), 2)


class GaussianMixtureModelTest(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/ppl/compiler/gmm_1d_2comp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def component(self, i):

@bm.random_variable
def y(self, i):
c = self.component(i)
c = self.component(i).item()
return dist.Normal(self.mu(c), self.sigma(c))


Expand Down
4 changes: 2 additions & 2 deletions tests/ppl/compiler/support_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ def cat_or_bern(n):

@bm.functional
def switch_inf():
return normal_or_bern(flip1(0))
return normal_or_bern(flip1(0).item())


@bm.functional
def switch_4():
return cat_or_bern(flip1(0))
return cat_or_bern(flip1(0).item())


class NodeSupportTest(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def component(self, i):

@bm.random_variable
def y(self, i):
c = self.component(i)
c = self.component(i).item()
return dist.Normal(self.mu(c), self.sigma(c))


Expand Down
26 changes: 20 additions & 6 deletions tests/ppl/compiler/tutorial_Robust_Linear_Regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
# TODO: Check imports for conistency

import beanmachine.ppl as bm
import pytest
import torch # from torch import manual_seed, tensor
import torch.distributions as dist # from torch.distributions import Bernoulli, Normal, Uniform
from beanmachine.ppl.distributions import Flat
from beanmachine.ppl.inference.bmg_inference import BMGInference
from sklearn import model_selection
from torch import tensor
Expand Down Expand Up @@ -65,12 +67,19 @@ def df_nu():


@bm.random_variable
def y_robust(X):
def X():
return Flat()


@bm.random_variable
def y_robust():
"""
Heavy-Tailed Noise model for regression utilizing StudentT
Student's T : https://en.wikipedia.org/wiki/Student%27s_t-distribution
"""
return dist.StudentT(df=df_nu(), loc=beta() * X + alpha(), scale=sigma_regressor())
return dist.StudentT(
df=df_nu(), loc=beta() * X() + alpha(), scale=sigma_regressor()
)


# Creating sample data
Expand All @@ -88,8 +97,9 @@ def y_robust(X):

dist_clean = dist.MultivariateNormal(loc=torch.zeros(2), covariance_matrix=cov)
points = tensor([dist_clean.sample().tolist() for i in range(N)]).view(N, 2)
X = X_clean = points[:, 0]
Y = Y_clean = points[:, 1]

X_clean = points[:, 0]
Y_clean = points[:, 1]

true_beta_1 = 2.0
true_beta_0 = 5.0
Expand All @@ -102,7 +112,7 @@ def y_robust(X):
X_corr = points_noisy[:, 0]
Y_corr = points_noisy[:, 1]

X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y)
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X_corr, Y_corr)

# Inference parameters

Expand All @@ -111,7 +121,7 @@ def y_robust(X):
)
num_chains = 4

observations = {y_robust(X_train): Y_train}
observations = {y_robust(): Y_train, X(): X_train}

queries = [beta(), alpha(), sigma_regressor(), df_nu()]

Expand All @@ -137,6 +147,10 @@ def test_tutorial_Robust_Linear_Regression(self) -> None:

self.assertTrue(True, msg="We just want to check this point is reached")

# TODO: re-enable once we can compile Flat distribution
@pytest.mark.xfail(
raises=TypeError, reason="Flat distribution not supported by BMG yet"
)
def test_tutorial_Robust_Linear_Regression_to_dot_cpp_python(
self,
) -> None:
Expand Down
11 changes: 6 additions & 5 deletions tests/ppl/inference/predictive_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def likelihood_2_vec(self, i):
return dist.Bernoulli(self.prior_2())

@bm.random_variable
def likelihood_reg(self, x):
return dist.Normal(self.prior() * x, torch.tensor(1.0))
def likelihood_reg(self):
return dist.Normal(self.prior() * self.x, torch.tensor(1.0))

def test_prior_predictive(self):
queries = [self.prior(), self.likelihood()]
Expand Down Expand Up @@ -103,13 +103,14 @@ def test_predictive_dynamic(self):
def test_predictive_data(self):
x = torch.randn(4)
y = torch.randn(4) + 2.0
obs = {self.likelihood_reg(x): y}
self.x = x
obs = {self.likelihood_reg(): y}
post_samples = bm.SingleSiteAncestralMetropolisHastings().infer(
[self.prior()], obs, num_samples=10, num_chains=2
)
assert post_samples[self.prior()].shape == (2, 10)
test_x = torch.randn(4, 1, 1)
test_query = self.likelihood_reg(test_x)
self.x = torch.randn(4, 1, 1)
test_query = self.likelihood_reg()
predictives = bm.simulate([test_query], post_samples, vectorized=True)
assert predictives[test_query].shape == (4, 2, 10)

Expand Down
Loading

0 comments on commit 9ef9d44

Please sign in to comment.