Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix posterior with observation noise in batched MTGP models #2782

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator
from torch import Tensor
from torch import broadcast_shapes, Tensor

if TYPE_CHECKING:
from botorch.posteriors.posterior_list import PosteriorList # pragma: no cover
@@ -858,15 +858,23 @@ def _apply_noise(
# get task features for training points
train_task_features = self.train_inputs[0][..., self._task_feature]
train_task_features = self._map_tasks(train_task_features).long()
noise_by_task = torch.zeros(self.num_tasks, dtype=X.dtype, device=X.device)
noise_by_task = torch.zeros(
*self.batch_shape, self.num_tasks, dtype=X.dtype, device=X.device
)
for task_feature in unique_test_task_features:
mask = train_task_features == task_feature
noise_by_task[task_feature] = self.likelihood.noise[mask].mean(
dim=-1, keepdim=True
)
noise_by_task[..., task_feature] = self.likelihood.noise[
..., mask
].mean(dim=-1)
# noise_shape is `broadcast(test_batch_shape, model.batch_shape) x q`
noise_shape = X.shape[:-1]
observation_noise = noise_by_task[test_task_features].expand(noise_shape)
noise_shape = (
broadcast_shapes(X.shape[:-2], self.batch_shape) + X.shape[-2:-1]
)
# Expand and gather ensures we pick correct noise dimensions for
# batch evaluations of batched models.
observation_noise = noise_by_task.expand(*noise_shape[:-1], -1).gather(
dim=-1, index=test_task_features.expand(noise_shape)
)
return self.likelihood(
mvn,
X,
32 changes: 29 additions & 3 deletions test/models/test_fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
@@ -282,9 +282,35 @@ def test_fit_model(
self.assertIsInstance(posterior, GaussianMixturePosterior)
self.assertIsInstance(posterior, GaussianMixturePosterior)

test_X = torch.rand(*batch_shape, d, **tkwargs)
posterior = model.posterior(test_X)
self.assertIsInstance(posterior, GaussianMixturePosterior)
# Test with observation noise.
# Add task index to have variability in added noise.
task_idcs = torch.tensor(
[[i % self.num_tasks] for i in range(batch_shape[-1])],
device=self.device,
)
test_X_w_task = torch.cat(
[test_X, task_idcs.expand(*batch_shape, 1)], dim=-1
)
noise_free_posterior = model.posterior(X=test_X_w_task)
noisy_posterior = model.posterior(X=test_X_w_task, observation_noise=True)
self.assertAllClose(noisy_posterior.mean, noise_free_posterior.mean)
added_noise = noisy_posterior.variance - noise_free_posterior.variance
self.assertTrue(torch.all(added_noise > 0.0))
if infer_noise is False:
# Check that correct noise was added.
train_tasks = train_X[..., 4]
mean_noise_by_task = torch.tensor(
[
train_Yvar[train_tasks == i].mean(dim=0)
for i in train_tasks.unique(sorted=True)
],
device=self.device,
)
expected_noise = mean_noise_by_task[task_idcs]
self.assertAllClose(
added_noise, expected_noise.expand_as(added_noise), atol=1e-4
)

# Mean/variance
expected_shape = (
*batch_shape[: MCMC_DIM + 2],