Skip to content

Commit 248fe84

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Move unsqueeze application to posterior call of Fully Bayesian MTGP (pytorch#2781)
Summary: This is consistent with the fully Bayesian STGP models and ensures that any other module calls within `posterior` also receive the expanded inputs. Differential Revision: D71643889
1 parent bc4b0c6 commit 248fe84

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

botorch/models/fully_bayesian_multitask.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def posterior(
361361
"""
362362
self._check_if_fitted()
363363
posterior = super().posterior(
364-
X=X,
364+
X=X.unsqueeze(MCMC_DIM),
365365
output_indices=output_indices,
366366
observation_noise=observation_noise,
367367
posterior_transform=posterior_transform,
@@ -372,14 +372,14 @@ def posterior(
372372

373373
def forward(self, X: Tensor) -> MultivariateNormal:
374374
self._check_if_fitted()
375-
X = X.unsqueeze(MCMC_DIM)
376-
377375
x_basic, task_idcs = self._split_inputs(X)
378376

379377
mean_x = self.mean_module(x_basic)
380378
covar_x = self.covar_module(x_basic)
381379

382-
tsub_idcs = task_idcs.squeeze(-3).squeeze(-1)
380+
tsub_idcs = task_idcs.squeeze(-1)
381+
if tsub_idcs.ndim > 1:
382+
tsub_idcs = tsub_idcs.squeeze(-2)
383383
latent_features = self.latent_features[:, tsub_idcs, :]
384384

385385
if X.ndim > 3:

0 commit comments

Comments
 (0)