Commit ad67ecc

Fix the use of tensor as arguments of random variables in unit tests and tutorials (facebookresearch#1800)

Summary: Pull Request resolved: facebookresearch#1800

Differential Revision: D40856513

fbshipit-source-id: bcc04efd8a47b21759a33dc3f3bf5bc342504537
horizon-blue authored and facebook-github-bot committed Nov 2, 2022
1 parent f8f2d04 commit ad67ecc
Showing 10 changed files with 152 additions and 289 deletions.
pyproject.toml (8 changes: 0 additions & 8 deletions)
@@ -18,16 +18,8 @@ filterwarnings = [
"ignore::DeprecationWarning:nbval",
# PyTorch 1.10 warns against creating a tensor from a list of numpy arrays
"default:Creating a tensor from a list of numpy.ndarrays is extremely slow.*:UserWarning",
# xarray uses a module that's deprecated since setuptools 60.0.0. This has been
# fixed in xarray/pull/6096, so we can remove this filter with the next xarray
# release
"default:distutils Version classes are deprecated.*:DeprecationWarning",
# statsmodels imports a module that's deprecated since pandas 1.4.0
"default:pandas.Int64Index is deprecated *:FutureWarning",
# functorch 0.1.0 imports deprecated _stateless module
"default:The `torch.nn.utils._stateless` code is deprecated*:DeprecationWarning",
# BM warns against using torch tensors as arguments of random variables
"default:PyTorch tensors are hashed by memory address instead of value.*:UserWarning",
# Arviz warns against the use of deprecated methods, due to the recent release of matplotlib v3.6.0
"default:The register_cmap function will be deprecated in a future version.*:PendingDeprecationWarning",
# gpytorch < 1.9.0 uses torch.triangular_solve
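For context, the "PyTorch tensors are hashed by memory address" filter above matches a warning that exists because torch.Tensor hashes by object identity rather than by value. A minimal standalone repro of the underlying behavior (plain PyTorch, not code from this commit):

```python
import torch

t1 = torch.tensor(0)
t2 = torch.tensor(0)

# Tensors hash by memory address, so two equal-valued tensors occupy
# different dictionary slots and value-based lookups miss.
d = {t1: "first"}
print(t1 == t2)  # tensor(True): the values are equal
print(t2 in d)   # False: t2 is a distinct object with a distinct hash

# Plain Python scalars hash by value, which is why the test fixes below
# call .item() before using a sampled value as a random-variable index.
print(0 in {0: "first"})  # True
```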
src/beanmachine/ppl/compiler/runtime.py (14 changes: 11 additions & 3 deletions)
@@ -51,6 +51,8 @@ def flip():
 """
 
 import inspect
+
+import warnings
 from types import MethodType
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 
@@ -298,9 +300,15 @@ def _handle_random_variable_call_checked(
             (i for i, arg in enumerate(arguments) if isinstance(arg, BMGNode)), -1
         )
         if index == -1:
-            # There were no graph node arguments. Just make an ordinary
-            # function call
-            rv = function(*arguments)
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    "ignore",
+                    category=UserWarning,
+                    message="PyTorch tensors are hashed by memory address",
+                )
+                # There were no graph node arguments. Just make an ordinary
+                # function call
+                rv = function(*arguments)
             assert isinstance(rv, RVIdentifier)
             return self._rv_to_node(rv)
 
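The runtime.py change wraps the ordinary function call in the standard-library pattern for silencing one specific warning in a narrow scope. A minimal sketch of that pattern in isolation (the noisy() helper is hypothetical):

```python
import warnings

def noisy():
    # Stand-in for a call that triggers the tensor-hashing UserWarning.
    warnings.warn(
        "PyTorch tensors are hashed by memory address instead of value.",
        UserWarning,
    )
    return 42

with warnings.catch_warnings():
    # "message" is a regex matched against the start of the warning text;
    # the whole filter stack is restored when the block exits.
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message="PyTorch tensors are hashed by memory address",
    )
    value = noisy()  # warning suppressed here

noisy()  # outside the block, the warning is emitted as usual
```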
src/beanmachine/ppl/model/rv_identifier.py (3 changes: 2 additions & 1 deletion)
@@ -26,7 +26,8 @@ def __post_init__(self):
         warnings.warn(
             "PyTorch tensors are hashed by memory address instead of value. "
             "Therefore, it is not recommended to use tensors as indices of random variables.",
-            stacklevel=3,
+            # display the warning on where the RVIdentifier is created
+            stacklevel=5,
         )
 
     def __str__(self):
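The stacklevel argument controls which stack frame warnings.warn attributes the message to: 1 points at the warn() call itself, and each increment walks one frame toward the caller, so raising it from 3 to 5 skips two additional layers of Bean Machine internals and lands on the user code that created the RVIdentifier. A small illustration of the mechanism, with hypothetical functions:

```python
import warnings

def library_internal():
    # stacklevel=2 attributes the warning to library_internal()'s caller,
    # i.e. the `library_internal()` line inside user_code() below.
    warnings.warn("tensor used as an RV index", UserWarning, stacklevel=2)

def user_code():
    library_internal()

user_code()
# The UserWarning is reported against the call site in user_code(),
# not against the warnings.warn(...) line above.
```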
tests/ppl/compiler/gaussian_mixture_model_test.py (2 changes: 1 addition & 1 deletion)
@@ -36,7 +36,7 @@ def category(item):
 
 @bm.random_variable
 def mixed(item):
-    return Normal(mean(category(item)), 2)
+    return Normal(mean(category(item).item()), 2)
 
 
 class GaussianMixtureModelTest(unittest.TestCase):
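The .item() fixes in this and the following test files all address the same failure mode: a random variable's arguments become part of its RVIdentifier, and (as the warning above implies) an identifier whose arguments include a tensor hashes by that tensor's address. A hedged sketch of the problem, using an illustrative model rather than one from the repo:

```python
import beanmachine.ppl as bm
import torch.distributions as dist
from torch import tensor

@bm.random_variable
def mean(i):
    # The argument i is part of this random variable's identity.
    return dist.Normal(0.0, 1.0)

# Equal-valued tensor indices produce identifiers with different hashes,
# so a sample stored under one cannot be looked up via the other:
d = {mean(tensor(0)): "sample"}
print(mean(tensor(0)) in d)  # False (and each call emits the UserWarning)

# A Python int hashes by value, so the identifier is stable:
d = {mean(0): "sample"}
print(mean(0) in d)  # True
```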
tests/ppl/compiler/gmm_1d_2comp_test.py (2 changes: 1 addition & 1 deletion)
@@ -43,7 +43,7 @@ def component(self, i):
 
     @bm.random_variable
     def y(self, i):
-        c = self.component(i)
+        c = self.component(i).item()
         return dist.Normal(self.mu(c), self.sigma(c))
 
 
tests/ppl/compiler/support_test.py (4 changes: 2 additions & 2 deletions)
@@ -120,12 +120,12 @@ def cat_or_bern(n):
 
 @bm.functional
 def switch_inf():
-    return normal_or_bern(flip1(0))
+    return normal_or_bern(flip1(0).item())
 
 
 @bm.functional
 def switch_4():
-    return cat_or_bern(flip1(0))
+    return cat_or_bern(flip1(0).item())
 
 
 class NodeSupportTest(unittest.TestCase):
[file name not captured in this view]
@@ -47,7 +47,7 @@ def component(self, i):
 
     @bm.random_variable
     def y(self, i):
-        c = self.component(i)
+        c = self.component(i).item()
         return dist.Normal(self.mu(c), self.sigma(c))
 
 
tests/ppl/compiler/tutorial_Robust_Linear_Regression_test.py (26 changes: 20 additions & 6 deletions)
@@ -16,8 +16,10 @@
 # TODO: Check imports for conistency
 
 import beanmachine.ppl as bm
+import pytest
 import torch  # from torch import manual_seed, tensor
 import torch.distributions as dist  # from torch.distributions import Bernoulli, Normal, Uniform
+from beanmachine.ppl.distributions import Flat
 from beanmachine.ppl.inference.bmg_inference import BMGInference
 from sklearn import model_selection
 from torch import tensor
@@ -65,12 +67,19 @@ def df_nu():
 
 
 @bm.random_variable
-def y_robust(X):
+def X():
+    return Flat()
+
+
+@bm.random_variable
+def y_robust():
     """
     Heavy-Tailed Noise model for regression utilizing StudentT
     Student's T : https://en.wikipedia.org/wiki/Student%27s_t-distribution
     """
-    return dist.StudentT(df=df_nu(), loc=beta() * X + alpha(), scale=sigma_regressor())
+    return dist.StudentT(
+        df=df_nu(), loc=beta() * X() + alpha(), scale=sigma_regressor()
+    )
 
 
 # Creating sample data
@@ -88,8 +97,9 @@ def y_robust(X):
 
 dist_clean = dist.MultivariateNormal(loc=torch.zeros(2), covariance_matrix=cov)
 points = tensor([dist_clean.sample().tolist() for i in range(N)]).view(N, 2)
-X = X_clean = points[:, 0]
-Y = Y_clean = points[:, 1]
+
+X_clean = points[:, 0]
+Y_clean = points[:, 1]
 
 true_beta_1 = 2.0
 true_beta_0 = 5.0
@@ -102,7 +112,7 @@ def y_robust(X):
 X_corr = points_noisy[:, 0]
 Y_corr = points_noisy[:, 1]
 
-X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y)
+X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X_corr, Y_corr)
 
 # Inference parameters
 
@@ -111,7 +121,7 @@ def y_robust(X):
 )
 num_chains = 4
 
-observations = {y_robust(X_train): Y_train}
+observations = {y_robust(): Y_train, X(): X_train}
 
 queries = [beta(), alpha(), sigma_regressor(), df_nu()]
 
@@ -137,6 +147,10 @@ def test_tutorial_Robust_Linear_Regression(self) -> None:
 
         self.assertTrue(True, msg="We just want to check this point is reached")
 
+    # TODO: re-enable once we can compile Flat distribution
+    @pytest.mark.xfail(
+        raises=TypeError, reason="Flat distribution not supported by BMG yet"
+    )
     def test_tutorial_Robust_Linear_Regression_to_dot_cpp_python(
         self,
     ) -> None:
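The tutorial refactor replaces a tensor-valued function argument (y_robust(X_train)) with a dedicated random variable: X() gets a Flat prior (Bean Machine's uninformative flat distribution), its value is pinned by an observation, and y_robust() reads it by calling X(). A minimal sketch of the same pattern, with hypothetical data and coefficients:

```python
import beanmachine.ppl as bm
import torch.distributions as dist
from beanmachine.ppl.distributions import Flat
from torch import tensor

@bm.random_variable
def x():
    # The observation below pins this value, so no tensor ever needs
    # to be passed as a random-variable argument.
    return Flat()

@bm.random_variable
def y():
    return dist.Normal(2.0 * x() + 5.0, 1.0)

x_data = tensor([0.0, 1.0, 2.0])
y_data = tensor([5.2, 6.9, 9.1])

# Data enters through the observations dict, keyed by argument-free
# identifiers, mirroring {y_robust(): Y_train, X(): X_train} above.
observations = {x(): x_data, y(): y_data}
```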
tests/ppl/inference/predictive_test.py (4 changes: 2 additions & 2 deletions)
@@ -103,13 +103,13 @@ def test_predictive_dynamic(self):
     def test_predictive_data(self):
         x = torch.randn(4)
         y = torch.randn(4) + 2.0
-        obs = {self.likelihood_reg(x): y}
+        obs = {self.likelihood_reg(x.item()): y}
         post_samples = bm.SingleSiteAncestralMetropolisHastings().infer(
             [self.prior()], obs, num_samples=10, num_chains=2
         )
         assert post_samples[self.prior()].shape == (2, 10)
         test_x = torch.randn(4, 1, 1)
-        test_query = self.likelihood_reg(test_x)
+        test_query = self.likelihood_reg(test_x.item())
         predictives = bm.simulate([test_query], post_samples, vectorized=True)
         assert predictives[test_query].shape == (4, 2, 10)
 
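A note on the .item() calls used throughout: .item() converts a tensor with exactly one element into a plain Python number, which hashes by value; calling it on a multi-element tensor raises a RuntimeError. For example:

```python
from torch import tensor

t = tensor(3.0)
v = t.item()                          # plain Python float
print(type(v))                        # <class 'float'>
print(v in {3.0: "sample"})           # True: floats hash by value
print(t in {tensor(3.0): "sample"})   # False: tensors hash by address
```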
[remaining file diff not loaded in this view]