Skip to content

Commit ed81a46

Browse files
committed
Added ES optimization initializer
1 parent 9a7c517 commit ed81a46

File tree

3 files changed

+249
-54
lines changed

3 files changed

+249
-54
lines changed

botorch/optim/initializers.py

+145-54
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from botorch.acquisition import analytic, monte_carlo, multi_objective
2525
from botorch.acquisition.acquisition import AcquisitionFunction
2626
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
27+
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
2728
from botorch.acquisition.knowledge_gradient import (
2829
_get_value_function,
2930
qKnowledgeGradient,
@@ -468,6 +469,89 @@ def gen_batch_initial_conditions(
468469
return batch_initial_conditions
469470

470471

472+
def gen_optimal_input_initial_conditions(
473+
acq_function: AcquisitionFunction,
474+
bounds: Tensor,
475+
q: int,
476+
num_restarts: int,
477+
raw_samples: int,
478+
fixed_features: dict[int, float] | None = None,
479+
options: dict[str, bool | float | int] | None = None,
480+
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
481+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
482+
):
483+
device = bounds.device
484+
if not hasattr(acq_function, "optimal_inputs"):
485+
raise AttributeError(
486+
"gen_optimal_input_initial_conditions can only be used with "
487+
"an AcquisitionFunction that has an optimal_inputs attribute."
488+
)
489+
frac_random: float = options.get("frac_random", 0.0)
490+
if not 0 <= frac_random <= 1:
491+
raise ValueError(
492+
f"frac_random must take on values in (0,1). Value: {frac_random}"
493+
)
494+
495+
batch_limit = options.get("batch_limit")
496+
num_optima = acq_function.optimal_inputs.shape[:-1].numel()
497+
suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
498+
X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
499+
num_random = round(raw_samples * frac_random)
500+
if num_random > 0:
501+
X_rnd = sample_q_batches_from_polytope(
502+
n=num_random,
503+
q=q,
504+
bounds=bounds,
505+
n_burnin=options.get("n_burnin", 10000),
506+
n_thinning=options.get("n_thinning", 32),
507+
equality_constraints=equality_constraints,
508+
inequality_constraints=inequality_constraints,
509+
)
510+
X = torch.cat((X, X_rnd))
511+
512+
if num_random < raw_samples:
513+
X_perturbed = sample_points_around_best(
514+
acq_function=acq_function,
515+
n_discrete_points=q * (raw_samples - num_random),
516+
sigma=options.get("sample_around_best_sigma", 1e-2),
517+
bounds=bounds,
518+
best_X=suggestions,
519+
)
520+
X_perturbed = X_perturbed.view(
521+
raw_samples - num_random, q, bounds.shape[-1]
522+
).cpu()
523+
X = torch.cat((X, X_perturbed))
524+
525+
if options.get("sample_around_best", False):
526+
X_best = sample_points_around_best(
527+
acq_function=acq_function,
528+
n_discrete_points=q * raw_samples,
529+
sigma=options.get("sample_around_best_sigma", 1e-2),
530+
bounds=bounds,
531+
)
532+
X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
533+
X = torch.cat((X, X_best))
534+
535+
with torch.no_grad():
536+
if batch_limit is None:
537+
batch_limit = X.shape[0]
538+
# Evaluate the acquisition function on `X_rnd` using `batch_limit`
539+
# sized chunks.
540+
acq_vals = torch.cat(
541+
[
542+
acq_function(x_.to(device=device)).cpu()
543+
for x_ in X.split(split_size=batch_limit, dim=0)
544+
],
545+
dim=0,
546+
)
547+
548+
eta = options.get("eta", 2.0)
549+
weights = torch.exp(eta * standardize(acq_vals))
550+
idx = torch.multinomial(weights, num_restarts, replacement=True)
551+
552+
return X[idx]
553+
554+
471555
def gen_one_shot_kg_initial_conditions(
472556
acq_function: qKnowledgeGradient,
473557
bounds: Tensor,
@@ -1141,6 +1225,7 @@ def sample_points_around_best(
11411225
best_pct: float = 5.0,
11421226
subset_sigma: float = 1e-1,
11431227
prob_perturb: float | None = None,
1228+
best_X: Tensor | None = None,
11441229
) -> Tensor | None:
11451230
r"""Find best points and sample nearby points.
11461231
@@ -1154,65 +1239,71 @@ def sample_points_around_best(
11541239
subset_sigma: The standard deviation of the additive gaussian
11551240
noise for perturbing a subset of dimensions of the best points.
11561241
prob_perturb: The probability of perturbing each dimension.
1242+
best_X: A provided set of best points to sample around. If None, the
1243+
set is instead inferred. Used for e.g. info-theoretic acquisition
1244+
functions, where the sampled optima serve as suggestions for
1245+
acquisition function optimization.
11571246
11581247
Returns:
11591248
An optional `n_discrete_points x d`-dim tensor containing the
11601249
sampled points. This is None if no baseline points are found.
11611250
"""
1162-
X = get_X_baseline(acq_function=acq_function)
1163-
if X is None:
1164-
return
1165-
with torch.no_grad():
1166-
try:
1167-
posterior = acq_function.model.posterior(X)
1168-
except AttributeError:
1169-
warnings.warn(
1170-
"Failed to sample around previous best points.",
1171-
BotorchWarning,
1172-
stacklevel=3,
1173-
)
1251+
if best_X is None:
1252+
X = get_X_baseline(acq_function=acq_function)
1253+
if X is None:
11741254
return
1175-
mean = posterior.mean
1176-
while mean.ndim > 2:
1177-
# take average over batch dims
1178-
mean = mean.mean(dim=0)
1179-
try:
1180-
f_pred = acq_function.objective(mean)
1181-
# Some acquisition functions do not have an objective
1182-
# and for some acquisition functions the objective is None
1183-
except (AttributeError, TypeError):
1184-
f_pred = mean
1185-
if hasattr(acq_function, "maximize"):
1186-
# make sure that the optimiztaion direction is set properly
1187-
if not acq_function.maximize:
1188-
f_pred = -f_pred
1189-
try:
1190-
# handle constraints for EHVI-based acquisition functions
1191-
constraints = acq_function.constraints
1192-
if constraints is not None:
1193-
neg_violation = -torch.stack(
1194-
[c(mean).clamp_min(0.0) for c in constraints], dim=-1
1195-
).sum(dim=-1)
1196-
feas = neg_violation == 0
1197-
if feas.any():
1198-
f_pred[~feas] = float("-inf")
1199-
else:
1200-
# set objective equal to negative violation
1201-
f_pred = neg_violation
1202-
except AttributeError:
1203-
pass
1204-
if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
1205-
# multi-objective
1206-
# find pareto set
1207-
is_pareto = is_non_dominated(f_pred)
1208-
best_X = X[is_pareto]
1209-
else:
1210-
if f_pred.shape[-1] == 1:
1211-
f_pred = f_pred.squeeze(-1)
1212-
n_best = max(1, round(X.shape[0] * best_pct / 100))
1213-
# the view() is to ensure that best_idcs is not a scalar tensor
1214-
best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
1215-
best_X = X[best_idcs]
1255+
with torch.no_grad():
1256+
try:
1257+
posterior = acq_function.model.posterior(X)
1258+
except AttributeError:
1259+
warnings.warn(
1260+
"Failed to sample around previous best points.",
1261+
BotorchWarning,
1262+
stacklevel=3,
1263+
)
1264+
return
1265+
mean = posterior.mean
1266+
while mean.ndim > 2:
1267+
# take average over batch dims
1268+
mean = mean.mean(dim=0)
1269+
try:
1270+
f_pred = acq_function.objective(mean)
1271+
# Some acquisition functions do not have an objective
1272+
# and for some acquisition functions the objective is None
1273+
except (AttributeError, TypeError):
1274+
f_pred = mean
1275+
if hasattr(acq_function, "maximize"):
1276+
# make sure that the optimiztaion direction is set properly
1277+
if not acq_function.maximize:
1278+
f_pred = -f_pred
1279+
try:
1280+
# handle constraints for EHVI-based acquisition functions
1281+
constraints = acq_function.constraints
1282+
if constraints is not None:
1283+
neg_violation = -torch.stack(
1284+
[c(mean).clamp_min(0.0) for c in constraints], dim=-1
1285+
).sum(dim=-1)
1286+
feas = neg_violation == 0
1287+
if feas.any():
1288+
f_pred[~feas] = float("-inf")
1289+
else:
1290+
# set objective equal to negative violation
1291+
f_pred = neg_violation
1292+
except AttributeError:
1293+
pass
1294+
if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
1295+
# multi-objective
1296+
# find pareto set
1297+
is_pareto = is_non_dominated(f_pred)
1298+
best_X = X[is_pareto]
1299+
else:
1300+
if f_pred.shape[-1] == 1:
1301+
f_pred = f_pred.squeeze(-1)
1302+
n_best = max(1, round(X.shape[0] * best_pct / 100))
1303+
# the view() is to ensure that best_idcs is not a scalar tensor
1304+
best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
1305+
best_X = X[best_idcs]
1306+
12161307
use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
12171308
n_trunc_normal_points = (
12181309
n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
@@ -1234,7 +1325,7 @@ def sample_points_around_best(
12341325
)
12351326
perturbed_X = torch.cat([perturbed_X, perturbed_subset_dims_X], dim=0)
12361327
# shuffle points
1237-
perm = torch.randperm(perturbed_X.shape[0], device=X.device)
1328+
perm = torch.randperm(perturbed_X.shape[0], device=best_X.device)
12381329
perturbed_X = perturbed_X[perm]
12391330
return perturbed_X
12401331

botorch/optim/optimize.py

+4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
AcquisitionFunction,
2121
OneShotAcquisitionFunction,
2222
)
23+
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
2324
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
2425
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
2526
qHypervolumeKnowledgeGradient,
@@ -33,6 +34,7 @@
3334
gen_batch_initial_conditions,
3435
gen_one_shot_hvkg_initial_conditions,
3536
gen_one_shot_kg_initial_conditions,
37+
gen_optimal_input_initial_conditions,
3638
TGenInitialConditions,
3739
)
3840
from botorch.optim.stopping import ExpMAStoppingCriterion
@@ -174,6 +176,8 @@ def get_ic_generator(self) -> TGenInitialConditions:
174176
return gen_one_shot_kg_initial_conditions
175177
elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
176178
return gen_one_shot_hvkg_initial_conditions
179+
elif isinstance(self.acq_function, qJointEntropySearch):
180+
return gen_optimal_input_initial_conditions
177181
return gen_batch_initial_conditions
178182

179183

test/optim/test_initializers.py

+100
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch
1414
from botorch.acquisition.analytic import PosteriorMean
1515
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
16+
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
1617
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
1718
from botorch.acquisition.monte_carlo import (
1819
qExpectedImprovement,
@@ -34,6 +35,7 @@
3435
gen_batch_initial_conditions,
3536
gen_one_shot_hvkg_initial_conditions,
3637
gen_one_shot_kg_initial_conditions,
38+
gen_optimal_input_initial_conditions,
3739
gen_value_function_initial_conditions,
3840
initialize_q_batch,
3941
initialize_q_batch_nonneg,
@@ -48,6 +50,7 @@
4850
)
4951
from botorch.sampling.normal import IIDNormalSampler
5052
from botorch.utils.sampling import draw_sobol_samples, manual_seed, unnormalize
53+
from botorch.utils.test_helpers import get_model
5154
from botorch.utils.testing import (
5255
_get_max_violation_of_bounds,
5356
_get_max_violation_of_constraints,
@@ -1075,6 +1078,88 @@ def test_gen_one_shot_kg_initial_conditions(self):
10751078
)
10761079
self.assertTrue(torch.all(ics[..., -n_value:, :] == 1))
10771080

1081+
def test_gen_optimal_input_initial_conditions(self):
1082+
num_restarts = 10
1083+
raw_samples = 16
1084+
for dtype in (torch.float, torch.double):
1085+
model = get_model(
1086+
torch.rand(4, 2, dtype=dtype), torch.rand(4, 1, dtype=dtype)
1087+
)
1088+
optimal_inputs = torch.rand(5, 2, dtype=dtype)
1089+
optimal_outputs = torch.rand(5, 1, dtype=dtype)
1090+
jes = qJointEntropySearch(
1091+
model=model,
1092+
optimal_inputs=optimal_inputs,
1093+
optimal_outputs=optimal_outputs,
1094+
)
1095+
bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
1096+
# test option error
1097+
with self.assertRaises(ValueError):
1098+
gen_optimal_input_initial_conditions(
1099+
acq_function=jes,
1100+
bounds=bounds,
1101+
q=1,
1102+
num_restarts=num_restarts,
1103+
raw_samples=raw_samples,
1104+
options={"frac_random": 2.0},
1105+
)
1106+
1107+
ei = qExpectedImprovement(model, 99.9)
1108+
with self.assertRaisesRegex(
1109+
AttributeError,
1110+
"gen_optimal_input_initial_conditions can only be used with "
1111+
"an AcquisitionFunction that has an optimal_inputs attribute.",
1112+
):
1113+
gen_optimal_input_initial_conditions(
1114+
acq_function=ei,
1115+
bounds=bounds,
1116+
q=1,
1117+
num_restarts=num_restarts,
1118+
raw_samples=raw_samples,
1119+
options={"frac_random": 2.0},
1120+
)
1121+
# test generation logic
1122+
q = 3
1123+
random_ics = torch.rand(raw_samples // 2, q, 2)
1124+
suggested_ics = torch.rand(raw_samples // 2 * q, 2)
1125+
with ExitStack() as es:
1126+
mock_random_ics = es.enter_context(
1127+
mock.patch(
1128+
"botorch.optim.initializers.sample_q_batches_from_polytope",
1129+
return_value=random_ics,
1130+
)
1131+
)
1132+
mock_suggested_ics = es.enter_context(
1133+
mock.patch(
1134+
"botorch.optim.initializers.sample_points_around_best",
1135+
return_value=suggested_ics,
1136+
)
1137+
)
1138+
mock_choose = es.enter_context(
1139+
mock.patch(
1140+
"torch.multinomial",
1141+
return_value=torch.arange(0, 10),
1142+
)
1143+
)
1144+
1145+
ics = gen_optimal_input_initial_conditions(
1146+
acq_function=jes,
1147+
bounds=bounds,
1148+
q=q,
1149+
num_restarts=num_restarts,
1150+
raw_samples=raw_samples,
1151+
options={"frac_random": 0.5},
1152+
)
1153+
1154+
mock_suggested_ics.assert_called_once()
1155+
mock_random_ics.assert_called_once()
1156+
mock_choose.assert_called_once()
1157+
1158+
expected_result = torch.cat(
1159+
(random_ics, suggested_ics.view(raw_samples // 2, q, 2)[0:2])
1160+
)
1161+
self.assertTrue(torch.equal(ics, expected_result))
1162+
10781163

10791164
class TestGenOneShotHVKGInitialConditions(BotorchTestCase):
10801165
def test_gen_one_shot_hvkg_initial_conditions(self):
@@ -1556,3 +1641,18 @@ def test_sample_points_around_best(self):
15561641
self.assertTrue(
15571642
((X_rnd.unsqueeze(0) == X_train.unsqueeze(1)).all(dim=-1)).sum() == 0
15581643
)
1644+
1645+
# providing suggestions of points to sample_around
1646+
suggestions = 1 + torch.rand(3, 20, **tkwargs)
1647+
X_rnd = sample_points_around_best(
1648+
acq_function=acqf,
1649+
n_discrete_points=5,
1650+
sigma=1e-3,
1651+
bounds=bounds,
1652+
prob_perturb=1e-8,
1653+
best_X=suggestions,
1654+
)
1655+
self.assertTrue(
1656+
((X_rnd.unsqueeze(0) == suggestions.unsqueeze(1)).all(dim=-1)).sum()
1657+
== 0
1658+
)

0 commit comments

Comments
 (0)