16
16
import jax
17
17
import jaxlib
18
18
import jax .numpy as jnp
19
+ import jax .random as jra
20
+ from jax import vmap
19
21
import numpyro
20
22
import numpyro .distributions as dist
21
23
from numpyro .infer import MCMC , NUTS , Predictive , init_to_median
@@ -144,19 +146,44 @@ def sample_from_prior(self, rng_key: jnp.ndarray,
144
146
prior_predictive = Predictive (self .model , num_samples = num_samples )
145
147
samples = prior_predictive (rng_key , X )
146
148
return samples ['y' ]
149
+
150
+ def sample_single_posterior_predictive (self , rng_key , X_new , params , n_draws ):
151
+ sigma = params ["noise" ]
152
+ loc = self ._model (X_new , params )
153
+ sample = dist .Normal (loc , sigma ).sample (rng_key , (n_draws ,)).mean (0 )
154
+ return loc , sample
155
+
156
+ def _vmap_predict (self , rng_key : jnp .ndarray , X_new : jnp .ndarray ,
157
+ samples : Optional [Dict [str , jnp .ndarray ]] = None ,
158
+ n_draws : int = 1 ,
159
+ ) -> Tuple [jnp .ndarray , jnp .ndarray ]:
160
+ """
161
+ Helper method to vectorize predictions over posterior samples
162
+ """
163
+ if samples is None :
164
+ samples = self .get_samples (chain_dim = False )
165
+ num_samples = len (next (iter (samples .values ())))
166
+ vmap_args = (jra .split (rng_key , num_samples ), samples )
167
+
168
+ predictive = lambda p1 , p2 : self .sample_single_posterior_predictive (p1 , X_new , p2 , n_draws )
169
+ loc , f_samples = vmap (predictive )(* vmap_args )
170
+
171
+ return loc , f_samples
147
172
148
173
def predict (self , rng_key : jnp .ndarray , X_new : jnp .ndarray ,
149
174
samples : Optional [Dict [str , jnp .ndarray ]] = None ,
175
+ n : int = 1 ,
150
176
filter_nans : bool = False , take_point_predictions_mean : bool = True ,
151
177
device : Type [jaxlib .xla_extension .Device ] = None
152
178
) -> Tuple [jnp .ndarray , jnp .ndarray ]:
153
179
"""
154
- Make prediction at X_new points using sampled GP hyperparameters
180
+ Make prediction at X_new points using posterior model parameters
155
181
156
182
Args:
157
183
rng_key: random number generator key
158
184
X_new: 2D vector with new/'test' data of :math:`n x num_features` dimensionality
159
185
samples: optional posterior samples
186
+ n: number of samples to draw from normal distribution per single HMC sample
160
187
filter_nans: filter out samples containing NaN values (if any)
161
188
take_point_predictions_mean: take a mean of point predictions (without sampling from the normal distribution)
162
189
device:
@@ -172,10 +199,7 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
172
199
if device :
173
200
X_new = jax .device_put (X_new , device )
174
201
samples = jax .device_put (samples , device )
175
- predictive = Predictive (
176
- self .model , posterior_samples = samples , parallel = True )
177
- y_pred = predictive (rng_key , X_new )
178
- y_pred , y_sampled = y_pred ["mu" ], y_pred ["y" ]
202
+ y_pred , y_sampled = self ._vmap_predict (rng_key , X_new , samples , n )
179
203
if filter_nans :
180
204
y_sampled_ = [y_i for y_i in y_sampled if not jnp .isnan (y_i ).any ()]
181
205
y_sampled = jnp .array (y_sampled_ )
0 commit comments