
Commit c5b3007

Add latent mean function to viDKL
1 parent fba9cde commit c5b3007

File tree

1 file changed: +53 -33

gpax/models/vidkl.py

@@ -42,8 +42,10 @@ class viDKL(ExactGP):
             with ReLU activations by default
         nn_prior:
             Places probabilistic priors over NN weights and biases (Default: True)
-        latent_prior:
-            Optional prior over the latent space (NN embedding); uses none by default
+        latent_mean_fn:
+            Optional mean function over the latent space (NN embedding); uses none by default
+        latent_mean_fn_prior:
+            Optional latent mean function prior; uses none by default
         guide:
             Auto-guide option, use 'delta' (default) or 'normal'
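As context for the new arguments: `latent_mean_fn` receives the NN embedding (and, if `latent_mean_fn_prior` is supplied, a dictionary of sampled parameters) and returns the GP mean over that embedding. A minimal sketch of such a function, assuming a linear trend with illustrative parameter names not taken from the commit:

    import jax.numpy as jnp

    def latent_mean_fn(z: jnp.ndarray, params: dict = None) -> jnp.ndarray:
        # Hypothetical linear trend over the NN embedding z of shape (n, z_dim)
        if params is None:  # deterministic variant (no latent_mean_fn_prior supplied)
            return z.sum(-1)
        return params["slope"] * z.sum(-1) + params["intercept"]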
@@ -71,7 +73,8 @@ class viDKL(ExactGP):
     def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, kernel: str = 'RBF',
                  kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                  nn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, nn_prior: bool = True,
-                 latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None,
+                 latent_mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
+                 latent_mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                  guide: str = 'delta', **kwargs
                  ) -> None:
         super(viDKL, self).__init__(input_dim, kernel, None, kernel_prior, **kwargs)
@@ -82,9 +85,10 @@ def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, kernel: str = 'RBF',
         self.nn_prior = nn_prior
         self.kernel_dim = z_dim
         self.data_dim = (input_dim,) if isinstance(input_dim, int) else input_dim
-        self.latent_prior = latent_prior
+        self.latent_mean_fn = latent_mean_fn
+        self.latent_mean_fn_prior = latent_mean_fn_prior
         self.guide_type = AutoNormal if guide == 'normal' else AutoDelta
-        self.kernel_params = None
+        self.gp_params = None
         self.nn_params = None

     def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None:
@@ -98,17 +102,20 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None:
         feature_extractor = haiku_module(
             "feature_extractor", self.nn_module, input_shape=(1, *self.data_dim))
         z = feature_extractor(X)
-        if self.latent_prior:  # Sample latent variable
-            z = self.latent_prior(z)
+        # GP's mean function
+        f_loc = jnp.zeros(z.shape[0])
+        if self.latent_mean_fn is not None:
+            args = [z]
+            if self.latent_mean_fn_prior is not None:
+                args += [self.latent_mean_fn_prior()]
+            f_loc += self.latent_mean_fn(*args).squeeze()
         # Sample GP kernel parameters
         if self.kernel_prior:
             kernel_params = self.kernel_prior()
         else:
             kernel_params = self._sample_kernel_params()
         # Sample noise
         noise = self._sample_noise()
-        # GP's mean function
-        f_loc = jnp.zeros(z.shape[0])
         # compute kernel
         k = self.kernel(
             z, z,
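Because `latent_mean_fn_prior` is invoked inside the numpyro model, it can draw the mean function's parameters as latent variables that are then inferred alongside the kernel hyperparameters. A minimal sketch of such a prior, with site names matching the hypothetical mean function above:

    import numpyro
    import numpyro.distributions as dist

    def latent_mean_fn_prior() -> dict:
        # Hypothetical priors over the mean function's parameters
        slope = numpyro.sample("slope", dist.Normal(0.0, 1.0))
        intercept = numpyro.sample("intercept", dist.Normal(0.0, 1.0))
        return {"slope": slope, "intercept": intercept}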
@@ -150,15 +157,15 @@ def single_fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
             # Get NN weights
             nn_params = get_haiku_dict(params_map)
             # Get GP kernel hyperparameters
-            kernel_params = {k: v for (k, v) in params_map.items()
-                             if not k.startswith("feature_extractor")}
+            gp_params = {k: v for (k, v) in params_map.items()
+                         if not k.startswith("feature_extractor")}
         else:  # MLE
             # Get NN weights
             nn_params = params["feature_extractor$params"]
             # Get kernel parameters from the guide
-            kernel_params = svi.guide.median(params)
+            gp_params = svi.guide.median(params)

-        return nn_params, kernel_params, losses
+        return nn_params, gp_params, losses

     def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
             num_steps: int = 1000, step_size: float = 5e-3,
@@ -187,7 +194,7 @@ def _single_fit(yi):
                 print_summary=False, progress_bar=False, **kwargs)
             # Apply vmap to the wrapper function
             vfit = jax.vmap(_single_fit)
-            self.nn_params, self.kernel_params, self.loss = vfit(y)
+            self.nn_params, self.gp_params, self.loss = vfit(y)
             # Poor man's version of the progress bar
             if progress_bar:
                 avg_bw = [num_steps - num_steps // 20, num_steps]
@@ -196,7 +203,7 @@ def _single_fit(yi):
                     self.loss.mean(0)[avg_bw[0]:avg_bw[1]].mean().round(4)))

         else:  # no channel dimension so we use the regular single_fit
-            self.nn_params, self.kernel_params, self.loss = self.single_fit(
+            self.nn_params, self.gp_params, self.loss = self.single_fit(
                 rng_key, X, y, num_steps, step_size, print_summary, progress_bar
             )
         if print_summary:
@@ -206,7 +213,7 @@ def _single_fit(yi):
     def get_mvn_posterior(self,
                           X_new: jnp.ndarray,
                           nn_params: Dict[str, jnp.ndarray],
-                          k_params: Dict[str, jnp.ndarray],
+                          gp_params: Dict[str, jnp.ndarray],
                           noiseless: bool = False,
                           y_residual: jnp.ndarray = None,
                           **kwargs
@@ -217,22 +224,35 @@ def get_mvn_posterior(self,
         given a single set of DKL parameters
         """
         if y_residual is None:
-            y_residual = self.y_train
-        noise = k_params["noise"]
+            y_residual = self.y_train.copy()
+        noise = gp_params["noise"]
         noise_p = noise * (1 - jnp.array(noiseless, int))
+
         # embed data into the latent space
         z_train = self.nn_module.apply(
             nn_params, jax.random.PRNGKey(0), self.X_train)
-        z_test = self.nn_module.apply(
+        z_new = self.nn_module.apply(
             nn_params, jax.random.PRNGKey(0), X_new)
+
+        # Apply latent mean function
+        if self.latent_mean_fn is not None:
+            args = [z_train, gp_params] if self.latent_mean_fn_prior else [z_train]
+            y_residual -= self.latent_mean_fn(*args).squeeze()
+
         # compute kernel matrices for train and test data
-        k_pp = self.kernel(z_test, z_test, k_params, noise_p, **kwargs)
-        k_pX = self.kernel(z_test, z_train, k_params, jitter=0.0)
-        k_XX = self.kernel(z_train, z_train, k_params, noise, **kwargs)
+        k_pp = self.kernel(z_new, z_new, gp_params, noise_p, **kwargs)
+        k_pX = self.kernel(z_new, z_train, gp_params, jitter=0.0)
+        k_XX = self.kernel(z_train, z_train, gp_params, noise, **kwargs)
         # compute the predictive covariance and mean
         K_xx_inv = jnp.linalg.inv(k_XX)
         cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
         mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
+
+        # Apply latent mean function
+        if self.latent_mean_fn is not None:
+            args = [z_new, gp_params] if self.latent_mean_fn_prior else [z_new]
+            mean += self.latent_mean_fn(*args).squeeze()
+
         return mean, cov

     def sample_from_posterior(self, rng_key: jnp.ndarray,
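The two mean-function blocks in the hunk above implement standard GP regression with a non-zero mean: the latent mean is subtracted from the training targets before the solve and added back at the new points. With m the latent mean function and z, z_* the train/test embeddings, the predictive moments are

    \mu_* = m(z_*) + K_{*X} K_{XX}^{-1} \left( y - m(z) \right), \qquad \Sigma_* = K_{**} - K_{*X} K_{XX}^{-1} K_{X*}

which is exactly what the mean and cov lines compute.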
@@ -246,13 +266,13 @@ def sample_from_posterior(self, rng_key: jnp.ndarray,
         if self.y_train.ndim > 1:
             raise NotImplementedError("Currently does not support a multi-channel regime")
         y_mean, K = self.get_mvn_posterior(
-            X_new, self.nn_params, self.kernel_params, noiseless, **kwargs)
+            X_new, self.nn_params, self.gp_params, noiseless, **kwargs)
         y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
         return y_mean, y_sampled

     def get_samples(self) -> Tuple[Dict['str', jnp.ndarray]]:
         """Returns a tuple with trained NN weights and kernel hyperparameters"""
-        return self.nn_params, self.kernel_params
+        return self.nn_params, self.gp_params

     def predict_in_batches(self, rng_key: jnp.ndarray,
                            X_new: jnp.ndarray, batch_size: int = 100,
@@ -295,24 +315,24 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
         """
         if params is None:
             nn_params = self.nn_params
-            k_params = self.kernel_params
+            gp_params = self.gp_params
         else:
-            nn_params, k_params = params
+            nn_params, gp_params = params

         if self.y_train.ndim == 2:  # y has shape (channels, samples)
             # Define a wrapper to use with vmap
-            def _get_mvn_posterior(nn_params_i, k_params_i, yi):
+            def _get_mvn_posterior(nn_params_i, gp_params_i, yi):
                 mean, cov = self.get_mvn_posterior(
-                    X_new, nn_params_i, k_params_i, noiseless, yi)
+                    X_new, nn_params_i, gp_params_i, noiseless, yi)
                 return mean, cov.diagonal()
             # vectorize posterior predictive computation over the y's channel dimension
             predictive = jax.vmap(_get_mvn_posterior)
-            mean, var = predictive(nn_params, k_params, self.y_train)
+            mean, var = predictive(nn_params, gp_params, self.y_train)

         else:  # y has shape (samples,)
             # Standard prediction
             mean, cov = self.get_mvn_posterior(
-                X_new, nn_params, k_params, noiseless)
+                X_new, nn_params, gp_params, noiseless)
             var = cov.diagonal()

         return mean, var
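The multi-channel branch relies on jax.vmap mapping over the leading (channel) axis of every leaf in the parameter pytrees as well as of y_train. A self-contained toy illustration of that pattern (not gpax code; names are made up):

    import jax
    import jax.numpy as jnp

    def channel_stat(params, y):
        # Stand-in for a per-channel posterior computation
        return params["scale"] * y.mean()

    batched = jax.vmap(channel_stat)
    params = {"scale": jnp.array([1.0, 2.0, 3.0])}  # one leading entry per channel
    y = jnp.ones((3, 10))                           # (channels, samples)
    print(batched(params, y))                       # [1. 2. 3.]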
@@ -384,14 +404,14 @@ def single_embed(nnpar_i, x_i):
         return z

     def _print_summary(self) -> None:
-        if isinstance(self.kernel_params, dict):
+        if isinstance(self.gp_params, dict):
             print('\nInferred GP kernel parameters')
             if self.X_train.ndim == len(self.data_dim) + 1:
-                for (k, vals) in self.kernel_params.items():
+                for (k, vals) in self.gp_params.items():
                     spaces = " " * (15 - len(k))
                     print(k, spaces, jnp.around(vals, 4))
             else:
-                for (k, vals) in self.kernel_params.items():
+                for (k, vals) in self.gp_params.items():
                     for i, v in enumerate(vals):
                         spaces = " " * (15 - len(k))
                         print(k+"[{}]".format(i), spaces, jnp.around(v, 4))
