@@ -42,8 +42,10 @@ class viDKL(ExactGP):
             with ReLU activations by default
         nn_prior:
             Places probabilistic priors over NN weights and biases (Default: True)
-        latent_prior:
-            Optional prior over the latent space (NN embedding); uses none by default
+        latent_mean_fn:
+            Optional mean function over the latent space (NN embedding); uses none by default
+        latent_mean_fn_prior:
+            Optional prior over the latent mean function; uses none by default
         guide:
             Auto-guide option, use 'delta' (default) or 'normal'
 
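A minimal sketch of a compatible pair for the two new arguments, assuming the default z_dim=2; the names linear_mean and linear_mean_prior are illustrative only, not part of the library. When latent_mean_fn_prior is omitted, the mean function is called with the embedding alone.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def linear_mean(z, params):
    # Linear mean over the NN embedding z: m(z) = z @ w + b
    return jnp.dot(z, params["w"]) + params["b"]

def linear_mean_prior():
    # Samples the mean-function parameters; runs inside the NumPyro model
    w = numpyro.sample("w", dist.Normal(jnp.zeros(2), 1.0))
    b = numpyro.sample("b", dist.Normal(0.0, 1.0))
    return {"w": w, "b": b}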
@@ -71,7 +73,8 @@ class viDKL(ExactGP):
     def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, kernel: str = 'RBF',
                  kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                  nn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, nn_prior: bool = True,
-                 latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None,
+                 latent_mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
+                 latent_mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                  guide: str = 'delta', **kwargs
                  ) -> None:
         super(viDKL, self).__init__(input_dim, kernel, None, kernel_prior, **kwargs)
@@ -82,9 +85,10 @@ def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, kernel: str = 'RBF',
         self.nn_prior = nn_prior
         self.kernel_dim = z_dim
         self.data_dim = (input_dim,) if isinstance(input_dim, int) else input_dim
-        self.latent_prior = latent_prior
+        self.latent_mean_fn = latent_mean_fn
+        self.latent_mean_fn_prior = latent_mean_fn_prior
         self.guide_type = AutoNormal if guide == 'normal' else AutoDelta
-        self.kernel_params = None
+        self.gp_params = None
         self.nn_params = None
 
     def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None:
@@ -98,17 +102,20 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None:
         feature_extractor = haiku_module(
             "feature_extractor", self.nn_module, input_shape=(1, *self.data_dim))
         z = feature_extractor(X)
-        if self.latent_prior:  # Sample latent variable
-            z = self.latent_prior(z)
+        # GP's mean function
+        f_loc = jnp.zeros(z.shape[0])
+        if self.latent_mean_fn is not None:
+            args = [z]
+            if self.latent_mean_fn_prior is not None:
+                args += [self.latent_mean_fn_prior()]
+            f_loc += self.latent_mean_fn(*args).squeeze()
         # Sample GP kernel parameters
         if self.kernel_prior:
             kernel_params = self.kernel_prior()
         else:
             kernel_params = self._sample_kernel_params()
         # Sample noise
         noise = self._sample_noise()
-        # GP's mean function
-        f_loc = jnp.zeros(z.shape[0])
         # compute kernel
         k = self.kernel(
             z, z,
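The mean-function block replaces the old latent_prior hook and absorbs the f_loc initialization that previously sat below the noise sampling. Conceptually, the model now encodes (a sketch, not library code):

# z     = feature_extractor(X)             # NN embedding of the inputs
# f_loc = m(z, theta) or zeros(n)          # latent mean; theta ~ latent_mean_fn_prior()
# y     ~ MVN(f_loc, K(z, z) + noise * I)  # GP likelihood in the latent space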
@@ -150,15 +157,15 @@ def single_fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
             # Get NN weights
             nn_params = get_haiku_dict(params_map)
             # Get GP kernel hyperparameters
-            kernel_params = {k: v for (k, v) in params_map.items()
-                             if not k.startswith("feature_extractor")}
+            gp_params = {k: v for (k, v) in params_map.items()
+                         if not k.startswith("feature_extractor")}
         else:  # MLE
             # Get NN weights
             nn_params = params["feature_extractor$params"]
             # Get kernel parameters from the guide
-            kernel_params = svi.guide.median(params)
+            gp_params = svi.guide.median(params)
 
-        return nn_params, kernel_params, losses
+        return nn_params, gp_params, losses
 
     def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
             num_steps: int = 1000, step_size: float = 5e-3,
@@ -187,7 +194,7 @@ def _single_fit(yi):
                 print_summary=False, progress_bar=False, **kwargs)
             # Apply vmap to the wrapper function
             vfit = jax.vmap(_single_fit)
-            self.nn_params, self.kernel_params, self.loss = vfit(y)
+            self.nn_params, self.gp_params, self.loss = vfit(y)
             # Poor man's version of the progress bar
             if progress_bar:
                 avg_bw = [num_steps - num_steps // 20, num_steps]
@@ -196,7 +203,7 @@ def _single_fit(yi):
                     self.loss.mean(0)[avg_bw[0]:avg_bw[1]].mean().round(4)))
 
         else:  # no channel dimension so we use the regular single_fit
-            self.nn_params, self.kernel_params, self.loss = self.single_fit(
+            self.nn_params, self.gp_params, self.loss = self.single_fit(
                 rng_key, X, y, num_steps, step_size, print_summary, progress_bar
             )
         if print_summary:
@@ -206,7 +213,7 @@ def _single_fit(yi):
     def get_mvn_posterior(self,
                           X_new: jnp.ndarray,
                           nn_params: Dict[str, jnp.ndarray],
-                          k_params: Dict[str, jnp.ndarray],
+                          gp_params: Dict[str, jnp.ndarray],
                           noiseless: bool = False,
                           y_residual: jnp.ndarray = None,
                           **kwargs
@@ -217,22 +224,35 @@ def get_mvn_posterior(self,
         given a single set of DKL parameters
         """
         if y_residual is None:
-            y_residual = self.y_train
-        noise = k_params["noise"]
+            y_residual = self.y_train.copy()
+        noise = gp_params["noise"]
         noise_p = noise * (1 - jnp.array(noiseless, int))
+
         # embed data into the latent space
         z_train = self.nn_module.apply(
             nn_params, jax.random.PRNGKey(0), self.X_train)
-        z_test = self.nn_module.apply(
+        z_new = self.nn_module.apply(
             nn_params, jax.random.PRNGKey(0), X_new)
+
+        # Apply latent mean function
+        if self.latent_mean_fn is not None:
+            args = [z_train, gp_params] if self.latent_mean_fn_prior else [z_train]
+            y_residual -= self.latent_mean_fn(*args).squeeze()
+
         # compute kernel matrices for train and test data
-        k_pp = self.kernel(z_test, z_test, k_params, noise_p, **kwargs)
-        k_pX = self.kernel(z_test, z_train, k_params, jitter=0.0)
-        k_XX = self.kernel(z_train, z_train, k_params, noise, **kwargs)
+        k_pp = self.kernel(z_new, z_new, gp_params, noise_p, **kwargs)
+        k_pX = self.kernel(z_new, z_train, gp_params, jitter=0.0)
+        k_XX = self.kernel(z_train, z_train, gp_params, noise, **kwargs)
         # compute the predictive covariance and mean
         K_xx_inv = jnp.linalg.inv(k_XX)
         cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
         mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
+
+        # Apply latent mean function to the predictive mean
+        if self.latent_mean_fn is not None:
+            args = [z_new, gp_params] if self.latent_mean_fn_prior else [z_new]
+            mean += self.latent_mean_fn(*args).squeeze()
+
         return mean, cov
 
     def sample_from_posterior(self, rng_key: jnp.ndarray,
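The posterior handles a nonzero latent mean in the standard way: subtract m(z_train) from the training targets, compute the zero-mean GP posterior on the residuals, then add m(z_new) back at the test points. The new .copy() guards against the in-place -= mutating self.y_train when the stored targets are NumPy arrays (JAX arrays are immutable, so for them it is merely defensive). A self-contained sketch of the same pattern, with hypothetical stand-ins m (mean function) and k (kernel); it uses jnp.linalg.solve where the library code uses an explicit inverse:

import jax.numpy as jnp

def gp_posterior_mean(z_train, z_new, y, m, k, noise):
    # Residuals with respect to the mean function
    y_res = y - m(z_train)
    # Train-train and test-train kernel matrices
    K_XX = k(z_train, z_train) + noise * jnp.eye(len(z_train))
    K_pX = k(z_new, z_train)
    # Zero-mean posterior on the residuals, then restore the mean at the test points
    return m(z_new) + K_pX @ jnp.linalg.solve(K_XX, y_res)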
@@ -246,13 +266,13 @@ def sample_from_posterior(self, rng_key: jnp.ndarray,
         if self.y_train.ndim > 1:
             raise NotImplementedError("Currently does not support a multi-channel regime")
         y_mean, K = self.get_mvn_posterior(
-            X_new, self.nn_params, self.kernel_params, noiseless, **kwargs)
+            X_new, self.nn_params, self.gp_params, noiseless, **kwargs)
         y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
         return y_mean, y_sampled
 
     def get_samples(self) -> Tuple[Dict['str', jnp.ndarray]]:
         """Returns a tuple with trained NN weights and kernel hyperparameters"""
-        return self.nn_params, self.kernel_params
+        return self.nn_params, self.gp_params
 
     def predict_in_batches(self, rng_key: jnp.ndarray,
                            X_new: jnp.ndarray, batch_size: int = 100,
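Downstream, the rename is user-visible through get_samples(), which now returns (nn_params, gp_params). A hedged end-to-end sketch, where X, y, and X_new are user-supplied arrays and linear_mean / linear_mean_prior are the illustrative functions from the docstring note above:

import jax

dkl = viDKL(input_dim=X.shape[-1], z_dim=2, kernel='RBF',
            latent_mean_fn=linear_mean, latent_mean_fn_prior=linear_mean_prior)
dkl.fit(jax.random.PRNGKey(0), X, y, num_steps=1000)
nn_params, gp_params = dkl.get_samples()   # second element was formerly kernel_params
y_mean, y_var = dkl.predict(jax.random.PRNGKey(1), X_new)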
@@ -295,24 +315,24 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
         """
         if params is None:
             nn_params = self.nn_params
-            k_params = self.kernel_params
+            gp_params = self.gp_params
         else:
-            nn_params, k_params = params
+            nn_params, gp_params = params
 
         if self.y_train.ndim == 2:  # y has shape (channels, samples)
             # Define a wrapper to use with vmap
-            def _get_mvn_posterior(nn_params_i, k_params_i, yi):
+            def _get_mvn_posterior(nn_params_i, gp_params_i, yi):
                 mean, cov = self.get_mvn_posterior(
-                    X_new, nn_params_i, k_params_i, noiseless, yi)
+                    X_new, nn_params_i, gp_params_i, noiseless, yi)
                 return mean, cov.diagonal()
             # vectorize posterior predictive computation over the y's channel dimension
             predictive = jax.vmap(_get_mvn_posterior)
-            mean, var = predictive(nn_params, k_params, self.y_train)
+            mean, var = predictive(nn_params, gp_params, self.y_train)
 
         else:  # y has shape (samples,)
             # Standard prediction
             mean, cov = self.get_mvn_posterior(
-                X_new, nn_params, k_params, noiseless)
+                X_new, nn_params, gp_params, noiseless)
             var = cov.diagonal()
 
         return mean, var
@@ -384,14 +404,14 @@ def single_embed(nnpar_i, x_i):
         return z
 
     def _print_summary(self) -> None:
-        if isinstance(self.kernel_params, dict):
+        if isinstance(self.gp_params, dict):
             print('\nInferred GP kernel parameters')
             if self.X_train.ndim == len(self.data_dim) + 1:
-                for (k, vals) in self.kernel_params.items():
+                for (k, vals) in self.gp_params.items():
                     spaces = " " * (15 - len(k))
                     print(k, spaces, jnp.around(vals, 4))
             else:
-                for (k, vals) in self.kernel_params.items():
+                for (k, vals) in self.gp_params.items():
                     for i, v in enumerate(vals):
                         spaces = " " * (15 - len(k))
                         print(k + "[{}]".format(i), spaces, jnp.around(v, 4))