
Commit 24781e9

Added ES optimization initializer
1 parent f4f01bf commit 24781e9

4 files changed: +288 −108 lines changed

botorch/optim/initializers.py (+193 −106)
@@ -468,6 +468,90 @@ def gen_batch_initial_conditions(
     return batch_initial_conditions
 
 
+def gen_optimal_input_initial_conditions(
+    acq_function: AcquisitionFunction,
+    bounds: Tensor,
+    q: int,
+    num_restarts: int,
+    raw_samples: int,
+    fixed_features: dict[int, float] | None = None,
+    options: dict[str, bool | float | int] | None = None,
+    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+):
+    device = bounds.device
+    if not hasattr(acq_function, "optimal_inputs"):
+        raise AttributeError(
+            "gen_optimal_input_initial_conditions can only be used with "
+            "an AcquisitionFunction that has an optimal_inputs attribute."
+        )
+    options = options or {}
+    frac_random: float = options.get("frac_random", 0.0)
+    if not 0 <= frac_random <= 1:
+        raise ValueError(
+            f"frac_random must take on values in [0, 1]. Value: {frac_random}"
+        )
+
+    batch_limit = options.get("batch_limit")
+    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
+    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
+    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
+    num_random = round(raw_samples * frac_random)
+    if num_random > 0:
+        X_rnd = sample_q_batches_from_polytope(
+            n=num_random,
+            q=q,
+            bounds=bounds,
+            n_burnin=options.get("n_burnin", 10000),
+            n_thinning=options.get("n_thinning", 32),
+            equality_constraints=equality_constraints,
+            inequality_constraints=inequality_constraints,
+        )
+        X = torch.cat((X, X_rnd))
+
+    if num_random < raw_samples:
+        # perturb the sampled optimizers and use them as initial conditions
+        X_perturbed = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * (raw_samples - num_random),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+            best_X=suggestions,
+        )
+        X_perturbed = X_perturbed.view(
+            raw_samples - num_random, q, bounds.shape[-1]
+        ).cpu()
+        X = torch.cat((X, X_perturbed))
+
+    if options.get("sample_around_best", False):
+        X_best = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * raw_samples,
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
+        X = torch.cat((X, X_best))
+
+    with torch.no_grad():
+        if batch_limit is None:
+            batch_limit = X.shape[0]
+        # Evaluate the acquisition function on `X` using `batch_limit`-sized
+        # chunks.
+        acq_vals = torch.cat(
+            [
+                acq_function(x_.to(device=device)).cpu()
+                for x_ in X.split(split_size=batch_limit, dim=0)
+            ],
+            dim=0,
+        )
+    idx = boltzmann_sample(
+        function_values=acq_vals,
+        num_samples=num_restarts,
+        eta=options.get("eta", 2.0),
+    )
+    # select the initial conditions via Boltzmann sampling of the acquisition values
+    return X[idx]
+
+
 def gen_one_shot_kg_initial_conditions(
     acq_function: qKnowledgeGradient,
     bounds: Tensor,
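
For context, a minimal usage sketch of the new initializer (not part of the commit; `model`, `optimal_inputs`, and `optimal_outputs` are hypothetical and assumed to come from elsewhere, e.g. `get_optimal_samples`):

    # Hypothetical sketch: invoke the new initializer directly for a JES
    # acquisition function that carries sampled `optimal_inputs`.
    import torch
    from botorch.acquisition.joint_entropy_search import qJointEntropySearch
    from botorch.optim.initializers import gen_optimal_input_initial_conditions

    bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
    # `model`, `optimal_inputs`, `optimal_outputs` assumed to exist, e.g. via
    # get_optimal_samples(model, bounds=bounds, num_optima=8).
    jes = qJointEntropySearch(
        model=model,
        optimal_inputs=optimal_inputs,
        optimal_outputs=optimal_outputs,
    )
    X_init = gen_optimal_input_initial_conditions(
        acq_function=jes,
        bounds=bounds,
        q=2,
        num_restarts=10,
        raw_samples=128,
        # 25% random polytope samples, 75% perturbations of the sampled optima
        options={"frac_random": 0.25},
    )
    # X_init is a `num_restarts x q x d` tensor usable by optimize_acqf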
@@ -602,59 +686,59 @@ def gen_one_shot_hvkg_initial_conditions(
 ) -> Tensor | None:
     r"""Generate a batch of smart initializations for qHypervolumeKnowledgeGradient.
 
     This function generates initial conditions for optimizing one-shot HVKG using
     the hypervolume maximizing set (of fixed size) under the posterior mean.
     Intuitively, the hypervolume maximizing set of the fantasized posterior mean
     will often be close to a hypervolume maximizing set under the current posterior
     mean. This function uses that fact to generate the initial conditions
     for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
     options) of the restarts are generated by learning the hypervolume maximizing
     sets under the current posterior mean, where each hypervolume maximizing set
     is obtained from maximizing the hypervolume from a different starting point.
     Given a hypervolume maximizing set, the `q` candidate points are selected
     according to the standard initialization strategy in
     `gen_batch_initial_conditions`, with the hypervolume maximizing set held
     fixed. For the remaining `frac_random` restarts, the fantasy points as well
     as all `q` candidate points are chosen according to the standard
     initialization strategy in `gen_batch_initial_conditions`.
 
     Args:
         acq_function: The qHypervolumeKnowledgeGradient instance to be optimized.
         bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
         q: The number of candidates to consider.
         num_restarts: The number of starting points for multistart acquisition
             function optimization.
         raw_samples: The number of raw samples to consider in the initialization
             heuristic.
         fixed_features: A map `{feature_index: value}` for features that
             should be fixed to a particular value during generation.
         options: Options for initial condition generation. These contain all
             settings for the standard heuristic initialization from
             `gen_batch_initial_conditions`. In addition, they contain
             `frac_random` (the fraction of fully random fantasy points),
             `num_inner_restarts` and `raw_inner_samples` (the number of random
             restarts and raw samples for solving the posterior objective
             maximization problem, respectively) and `eta` (temperature parameter
             for sampling heuristic from posterior objective maximizers).
         inequality_constraints: A list of tuples (indices, coefficients, rhs),
             with each tuple encoding an inequality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
         equality_constraints: A list of tuples (indices, coefficients, rhs),
             with each tuple encoding an equality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
 
     Returns:
         A `num_restarts x q' x d` tensor that can be used as initial conditions
         for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
         of points (candidate points plus fantasy points).
 
     Example:
         >>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
         >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
         >>> Xinit = gen_one_shot_hvkg_initial_conditions(
         >>>     qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
         >>>     options={"frac_random": 0.25},
         >>> )
     """
     from botorch.optim.optimize import optimize_acqf

@@ -1136,6 +1220,7 @@ def sample_points_around_best(
     best_pct: float = 5.0,
     subset_sigma: float = 1e-1,
     prob_perturb: float | None = None,
+    best_X: Tensor | None = None,
 ) -> Tensor | None:
     r"""Find best points and sample nearby points.
@@ -1154,60 +1239,62 @@ def sample_points_around_best(
         An optional `n_discrete_points x d`-dim tensor containing the
         sampled points. This is None if no baseline points are found.
     """
-    X = get_X_baseline(acq_function=acq_function)
-    if X is None:
-        return
-    with torch.no_grad():
-        try:
-            posterior = acq_function.model.posterior(X)
-        except AttributeError:
-            warnings.warn(
-                "Failed to sample around previous best points.",
-                BotorchWarning,
-                stacklevel=3,
-            )
-            return
-        mean = posterior.mean
-        while mean.ndim > 2:
-            # take average over batch dims
-            mean = mean.mean(dim=0)
-        try:
-            f_pred = acq_function.objective(mean)
-        # Some acquisition functions do not have an objective
-        # and for some acquisition functions the objective is None
-        except (AttributeError, TypeError):
-            f_pred = mean
-        if hasattr(acq_function, "maximize"):
-            # make sure that the optimization direction is set properly
-            if not acq_function.maximize:
-                f_pred = -f_pred
-        try:
-            # handle constraints for EHVI-based acquisition functions
-            constraints = acq_function.constraints
-            if constraints is not None:
-                neg_violation = -torch.stack(
-                    [c(mean).clamp_min(0.0) for c in constraints], dim=-1
-                ).sum(dim=-1)
-                feas = neg_violation == 0
-                if feas.any():
-                    f_pred[~feas] = float("-inf")
-                else:
-                    # set objective equal to negative violation
-                    f_pred = neg_violation
-        except AttributeError:
-            pass
-        if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
-            # multi-objective
-            # find pareto set
-            is_pareto = is_non_dominated(f_pred)
-            best_X = X[is_pareto]
-        else:
-            if f_pred.shape[-1] == 1:
-                f_pred = f_pred.squeeze(-1)
-            n_best = max(1, round(X.shape[0] * best_pct / 100))
-            # the view() is to ensure that best_idcs is not a scalar tensor
-            best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
-            best_X = X[best_idcs]
+    if best_X is None:
+        X = get_X_baseline(acq_function=acq_function)
+        if X is None:
+            return
+        with torch.no_grad():
+            try:
+                posterior = acq_function.model.posterior(X)
+            except AttributeError:
+                warnings.warn(
+                    "Failed to sample around previous best points.",
+                    BotorchWarning,
+                    stacklevel=3,
+                )
+                return
+            mean = posterior.mean
+            while mean.ndim > 2:
+                # take average over batch dims
+                mean = mean.mean(dim=0)
+            try:
+                f_pred = acq_function.objective(mean)
+            # Some acquisition functions do not have an objective
+            # and for some acquisition functions the objective is None
+            except (AttributeError, TypeError):
+                f_pred = mean
+            if hasattr(acq_function, "maximize"):
+                # make sure that the optimization direction is set properly
+                if not acq_function.maximize:
+                    f_pred = -f_pred
+            try:
+                # handle constraints for EHVI-based acquisition functions
+                constraints = acq_function.constraints
+                if constraints is not None:
+                    neg_violation = -torch.stack(
+                        [c(mean).clamp_min(0.0) for c in constraints], dim=-1
+                    ).sum(dim=-1)
+                    feas = neg_violation == 0
+                    if feas.any():
+                        f_pred[~feas] = float("-inf")
+                    else:
+                        # set objective equal to negative violation
+                        f_pred = neg_violation
+            except AttributeError:
+                pass
+            if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
+                # multi-objective
+                # find pareto set
+                is_pareto = is_non_dominated(f_pred)
+                best_X = X[is_pareto]
+            else:
+                if f_pred.shape[-1] == 1:
+                    f_pred = f_pred.squeeze(-1)
+                n_best = max(1, round(X.shape[0] * best_pct / 100))
+                # the view() is to ensure that best_idcs is not a scalar tensor
+                best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
+                best_X = X[best_idcs]
+
     use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
     n_trunc_normal_points = (
         n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
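
The new `best_X` argument lets a caller bypass the posterior-mean ranking above and perturb a supplied set of points directly, which is how `gen_optimal_input_initial_conditions` reuses this helper. A hedged sketch of that code path (hypothetical tensors; not part of the commit):

    # Hypothetical sketch: perturb given points (e.g. sampled optima) instead
    # of ranked baseline points. `acq_function`, `bounds`, and `suggestions`
    # are assumed to exist.
    from botorch.optim.initializers import sample_points_around_best

    X_near_optima = sample_points_around_best(
        acq_function=acq_function,
        n_discrete_points=64,
        sigma=1e-2,
        bounds=bounds,
        best_X=suggestions,  # `n x d` points to perturb directly
    )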

botorch/optim/optimize.py (+7)
@@ -20,10 +20,12 @@
     AcquisitionFunction,
     OneShotAcquisitionFunction,
 )
+from botorch.acquisition.joint_entropy_search import qJointEntropySearch
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
     qHypervolumeKnowledgeGradient,
 )
+from botorch.acquisition.predictive_entropy_search import qPredictiveEntropySearch
 from botorch.exceptions import InputDataError, UnsupportedError
 from botorch.exceptions.errors import CandidateGenerationError
 from botorch.exceptions.warnings import OptimizationWarning
@@ -33,6 +35,7 @@
     gen_batch_initial_conditions,
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    gen_optimal_input_initial_conditions,
     TGenInitialConditions,
 )
 from botorch.optim.stopping import ExpMAStoppingCriterion
@@ -174,6 +177,10 @@ def get_ic_generator(self) -> TGenInitialConditions:
             return gen_one_shot_kg_initial_conditions
         elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
             return gen_one_shot_hvkg_initial_conditions
+        elif isinstance(
+            self.acq_function, (qJointEntropySearch, qPredictiveEntropySearch)
+        ):
+            return gen_optimal_input_initial_conditions
         return gen_batch_initial_conditions
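
With this routing in place, `optimize_acqf` selects the new initializer automatically for JES and PES acquisition functions, so no explicit `ic_generator` is needed. A hedged sketch reusing the hypothetical `jes` and `bounds` from above (not part of the commit):

    # Hypothetical sketch: optimize_acqf now dispatches qJointEntropySearch
    # (and qPredictiveEntropySearch) to gen_optimal_input_initial_conditions.
    from botorch.optim import optimize_acqf

    candidates, acq_value = optimize_acqf(
        acq_function=jes,
        bounds=bounds,
        q=2,
        num_restarts=10,
        raw_samples=128,
    )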

botorch/utils/sampling.py (+3 −2)
@@ -35,10 +35,11 @@
     BotorchTensorDimensionError,
     InfeasibilityError,
 )
+from botorch.utils.transforms import standardize
 from botorch.exceptions.warnings import UserInputWarning
 from botorch.sampling.qmc import NormalQMCEngine
 
-from botorch.utils.transforms import normalize, standardize, unnormalize
+from botorch.utils.transforms import normalize, unnormalize
 from scipy.spatial import Delaunay, HalfspaceIntersection
 from torch import LongTensor, Tensor
 from torch.distributions import Normal
@@ -1123,7 +1124,7 @@ def boltzmann_sample(
     while torch.isinf(weights).any():
         eta *= temp_decrease
         weights = torch.exp(eta * norm_weights)
-
+
     # squeeze in case of m = 1 (mono-output provided as batch_size x N x 1)
     return batched_multinomial(
         weights=weights.squeeze(-1), num_samples=num_samples, replacement=replacement
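
For intuition, `boltzmann_sample` (used by the new initializer to pick restarts) draws indices with probability proportional to `exp(eta * standardized value)`. A standalone sketch of that behavior (illustrative only; not the library implementation):

    # Illustrative sketch: higher eta concentrates the sampled restarts on
    # the top-scoring raw samples; eta -> 0 approaches uniform sampling.
    import torch

    acq_vals = torch.tensor([0.1, 0.5, 0.2, 0.9])
    eta = 2.0
    norm_vals = (acq_vals - acq_vals.mean()) / acq_vals.std()  # standardize
    weights = torch.exp(eta * norm_vals)
    idx = torch.multinomial(weights, num_samples=2, replacement=False)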
