Skip to content

Commit f618231

Browse files
committed
Fix prediction with structured probabilistic model
Addresses breaking change in https://github.com/pyro-ppl/numpyro/releases/tag/0.14.0
1 parent 14d1b50 commit f618231

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

gpax/models/spm.py

+29-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import jax
1717
import jaxlib
1818
import jax.numpy as jnp
19+
import jax.random as jra
20+
from jax import vmap
1921
import numpyro
2022
import numpyro.distributions as dist
2123
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median
@@ -144,19 +146,44 @@ def sample_from_prior(self, rng_key: jnp.ndarray,
144146
prior_predictive = Predictive(self.model, num_samples=num_samples)
145147
samples = prior_predictive(rng_key, X)
146148
return samples['y']
149+
150+
def sample_single_posterior_predictive(self, rng_key, X_new, params, n_draws):
151+
sigma = params["noise"]
152+
loc = self._model(X_new, params)
153+
sample = dist.Normal(loc, sigma).sample(rng_key, (n_draws,)).mean(0)
154+
return loc, sample
155+
156+
def _vmap_predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
157+
samples: Optional[Dict[str, jnp.ndarray]] = None,
158+
n_draws: int = 1,
159+
) -> Tuple[jnp.ndarray, jnp.ndarray]:
160+
"""
161+
Helper method to vectorize predictions over posterior samples
162+
"""
163+
if samples is None:
164+
samples = self.get_samples(chain_dim=False)
165+
num_samples = len(next(iter(samples.values())))
166+
vmap_args = (jra.split(rng_key, num_samples), samples)
167+
168+
predictive = lambda p1, p2: self.sample_single_posterior_predictive(p1, X_new, p2, n_draws)
169+
loc, f_samples = vmap(predictive)(*vmap_args)
170+
171+
return loc, f_samples
147172

148173
def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
149174
samples: Optional[Dict[str, jnp.ndarray]] = None,
175+
n: int = 1,
150176
filter_nans: bool = False, take_point_predictions_mean: bool = True,
151177
device: Type[jaxlib.xla_extension.Device] = None
152178
) -> Tuple[jnp.ndarray, jnp.ndarray]:
153179
"""
154-
Make prediction at X_new points using sampled GP hyperparameters
180+
Make prediction at X_new points using posterior model parameters
155181
156182
Args:
157183
rng_key: random number generator key
158184
X_new: 2D vector with new/'test' data of :math:`n x num_features` dimensionality
159185
samples: optional posterior samples
186+
n: number of samples to draw from normal distribution per single HMC sample
160187
filter_nans: filter out samples containing NaN values (if any)
161188
take_point_predictions_mean: take a mean of point predictions (without sampling from the normal distribution)
162189
device:
@@ -172,10 +199,7 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
172199
if device:
173200
X_new = jax.device_put(X_new, device)
174201
samples = jax.device_put(samples, device)
175-
predictive = Predictive(
176-
self.model, posterior_samples=samples, parallel=True)
177-
y_pred = predictive(rng_key, X_new)
178-
y_pred, y_sampled = y_pred["mu"], y_pred["y"]
202+
y_pred, y_sampled = self._vmap_predict(rng_key, X_new, samples, n)
179203
if filter_nans:
180204
y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()]
181205
y_sampled = jnp.array(y_sampled_)

tests/test_spm.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,24 @@ def test_get_samples():
5555
def test_prediction():
5656
rng_keys = get_keys()
5757
X, y = get_dummy_data()
58-
X_test, _ = get_dummy_data()
59-
samples = {"a": jax.random.normal(rng_keys[0], shape=(100, 1)),
60-
"b": jax.random.normal(rng_keys[0], shape=(100,))}
58+
X_test = onp.linspace(X.min(), X.max(), 200)
59+
samples = {"a": jax.random.normal(rng_keys[0], shape=(100,)),
60+
"b": jax.random.normal(rng_keys[0], shape=(100,)),
61+
"noise": jax.random.normal(rng_keys[0], shape=(100,))}
6162
m =sPM(model, model_priors)
6263
y_mean, y_sampled = m.predict(rng_keys[1], X_test, samples)
6364
assert isinstance(y_mean, jnp.ndarray)
6465
assert isinstance(y_sampled, jnp.ndarray)
6566
assert_equal(y_mean.shape, X_test.squeeze().shape)
66-
assert_equal(y_sampled.shape, (100, X_test.shape[0]))
67+
assert_equal(y_sampled.shape, (100, X_test.shape[0]))
68+
69+
70+
def test_fit_predict():
71+
key1, key2 = get_keys()
72+
X, y = get_dummy_data()
73+
X_test = onp.linspace(X.min(), X.max(), 200)
74+
m = sPM(model, model_priors)
75+
m.fit(key1, X, y, num_warmup=100, num_samples=100)
76+
y_mean, y_sampled = m.predict(key2, X_test)
77+
assert_equal(y_mean.shape, X_test.squeeze().shape)
78+
assert_equal(y_sampled.shape, (100, X_test.shape[0]))

0 commit comments

Comments
 (0)