@@ -468,6 +468,90 @@ def gen_batch_initial_conditions(
    return batch_initial_conditions


+ def gen_optimal_input_initial_conditions(
+     acq_function: AcquisitionFunction,
+     bounds: Tensor,
+     q: int,
+     num_restarts: int,
+     raw_samples: int,
+     fixed_features: dict[int, float] | None = None,
+     options: dict[str, bool | float | int] | None = None,
+     inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+     equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+ ):
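+     r"""Generate initial conditions from samples of the optimal inputs.
+
+     Candidate points are drawn by perturbing the acquisition function's
+     `optimal_inputs`, optionally mixed with a `frac_random` fraction of
+     random samples from the (constrained) search space; `num_restarts`
+     candidate sets are then selected by Boltzmann sampling on the
+     acquisition values.
+     """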
+     device = bounds.device
+     if not hasattr(acq_function, "optimal_inputs"):
+         raise AttributeError(
+             "gen_optimal_input_initial_conditions can only be used with "
+             "an AcquisitionFunction that has an optimal_inputs attribute."
+         )
+     # Guard against `options=None` (the default) before calling `.get`.
+     options = options or {}
+     frac_random: float = options.get("frac_random", 0.0)
+     if not 0 <= frac_random <= 1:
+         raise ValueError(
+             f"frac_random must take on values in [0, 1]. Value: {frac_random}"
+         )
+
+     batch_limit = options.get("batch_limit")
+     num_optima = acq_function.optimal_inputs.shape[:-1].numel()
+     suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
+     X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
+     num_random = round(raw_samples * frac_random)
+     if num_random > 0:
+         X_rnd = sample_q_batches_from_polytope(
+             n=num_random,
+             q=q,
+             bounds=bounds,
+             n_burnin=options.get("n_burnin", 10000),
+             n_thinning=options.get("n_thinning", 32),
+             equality_constraints=equality_constraints,
+             inequality_constraints=inequality_constraints,
+         )
+         X = torch.cat((X, X_rnd))
+
+     if num_random < raw_samples:
+         X_perturbed = sample_points_around_best(
+             acq_function=acq_function,
+             n_discrete_points=q * (raw_samples - num_random),
+             sigma=options.get("sample_around_best_sigma", 1e-2),
+             bounds=bounds,
+             best_X=suggestions,
+         )
+         X_perturbed = X_perturbed.view(
+             raw_samples - num_random, q, bounds.shape[-1]
+         ).cpu()
+         X = torch.cat((X, X_perturbed))
+
+     if options.get("sample_around_best", False):
+         X_best = sample_points_around_best(
+             acq_function=acq_function,
+             n_discrete_points=q * raw_samples,
+             sigma=options.get("sample_around_best_sigma", 1e-2),
+             bounds=bounds,
+         )
+         X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
+         X = torch.cat((X, X_best))
+
+     with torch.no_grad():
+         if batch_limit is None:
+             batch_limit = X.shape[0]
+         # Evaluate the acquisition function on `X` using `batch_limit`
+         # sized chunks.
+         acq_vals = torch.cat(
+             [
+                 acq_function(x_.to(device=device)).cpu()
+                 for x_ in X.split(split_size=batch_limit, dim=0)
+             ],
+             dim=0,
+         )
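+     # Draw `num_restarts` indices with probabilities given by a softmax
+     # of the acquisition values at temperature `eta`.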
+     idx = boltzmann_sample(
+         function_values=acq_vals,
+         num_samples=num_restarts,
+         eta=options.get("eta", 2.0),
+     )
+     # return the sampled candidate sets as the initial conditions
+     return X[idx]
+
+
def gen_one_shot_kg_initial_conditions(
    acq_function: qKnowledgeGradient,
    bounds: Tensor,
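For orientation, a minimal usage sketch of the new initializer (editorial, not part of the diff). `jes_acqf` is a hypothetical, already-constructed acquisition function that exposes an `optimal_inputs` attribute, e.g. a JES-style acquisition function built from sampled optima:

```python
import torch

from botorch.optim.initializers import gen_optimal_input_initial_conditions

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])  # `2 x d` search-space bounds
# `jes_acqf` is assumed to carry an `optimal_inputs` tensor of sampled optima.
ics = gen_optimal_input_initial_conditions(
    acq_function=jes_acqf,
    bounds=bounds,
    q=2,
    num_restarts=8,
    raw_samples=128,
    options={"frac_random": 0.25, "eta": 2.0},
)
# `ics` has shape `num_restarts x q x d` and can be passed to
# `optimize_acqf(..., batch_initial_conditions=ics)`.
```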
@@ -602,59 +686,59 @@ def gen_one_shot_hvkg_initial_conditions(
) -> Tensor | None:
    r"""Generate a batch of smart initializations for qHypervolumeKnowledgeGradient.

-     This function generates initial conditions for optimizing one-shot HVKG using
-     the hypervolume maximizing set (of fixed size) under the posterior mean.
-     Intutively, the hypervolume maximizing set of the fantasized posterior mean
-     will often be close to a hypervolume maximizing set under the current posterior
-     mean. This function uses that fact to generate the initial conditions
-     for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
-     options) of the restarts are generated by learning the hypervolume maximizing sets
-     under the current posterior mean, where each hypervolume maximizing set is
-     obtained from maximizing the hypervolume from a different starting point. Given
-     a hypervolume maximizing set, the `q` candidate points are selected using to the
-     standard initialization strategy in `gen_batch_initial_conditions`, with the fixed
-     hypervolume maximizing set. The remaining `frac_random` restarts fantasy points
-     as well as all `q` candidate points are chosen according to the standard
-     initialization strategy in `gen_batch_initial_conditions`.
-
-     Args:
-         acq_function: The qKnowledgeGradient instance to be optimized.
-         bounds: A `2 x d` tensor of lower and upper bounds for each column of
-             task features.
-         q: The number of candidates to consider.
-         num_restarts: The number of starting points for multistart acquisition
-             function optimization.
-         raw_samples: The number of raw samples to consider in the initialization
-             heuristic.
-         fixed_features: A map `{feature_index: value}` for features that
-             should be fixed to a particular value during generation.
-         options: Options for initial condition generation. These contain all
-             settings for the standard heuristic initialization from
-             `gen_batch_initial_conditions`. In addition, they contain
-             `frac_random` (the fraction of fully random fantasy points),
-             `num_inner_restarts` and `raw_inner_samples` (the number of random
-             restarts and raw samples for solving the posterior objective
-             maximization problem, respectively) and `eta` (temperature parameter
-             for sampling heuristic from posterior objective maximizers).
-         inequality constraints: A list of tuples (indices, coefficients, rhs),
-             with each tuple encoding an inequality constraint of the form
-             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
-         equality constraints: A list of tuples (indices, coefficients, rhs),
-             with each tuple encoding an inequality constraint of the form
-             `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
-
-     Returns:
-         A `num_restarts x q' x d` tensor that can be used as initial conditions
-         for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
-         of points (candidate points plus fantasy points).
-
-     Example:
-         >>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
-         >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
-         >>> Xinit = gen_one_shot_hvkg_initial_conditions(
-         >>>     qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
-         >>>     options={"frac_random": 0.25},
-         >>> )
+     This function generates initial conditions for optimizing one-shot HVKG using
+     the hypervolume maximizing set (of fixed size) under the posterior mean.
+     Intuitively, the hypervolume maximizing set of the fantasized posterior mean
+     will often be close to a hypervolume maximizing set under the current posterior
+     mean. This function uses that fact to generate the initial conditions
+     for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
+     options) of the restarts are generated by learning the hypervolume maximizing sets
+     under the current posterior mean, where each hypervolume maximizing set is
+     obtained from maximizing the hypervolume from a different starting point. Given
+     a hypervolume maximizing set, the `q` candidate points are selected according to
+     the standard initialization strategy in `gen_batch_initial_conditions`, with the
+     fixed hypervolume maximizing set. For the remaining `frac_random` restarts, the
+     fantasy points as well as all `q` candidate points are chosen according to the
+     standard initialization strategy in `gen_batch_initial_conditions`.
+
+     Args:
+         acq_function: The qHypervolumeKnowledgeGradient instance to be optimized.
+         bounds: A `2 x d` tensor of lower and upper bounds for each column of
+             `X`.
+         q: The number of candidates to consider.
+         num_restarts: The number of starting points for multistart acquisition
+             function optimization.
+         raw_samples: The number of raw samples to consider in the initialization
+             heuristic.
+         fixed_features: A map `{feature_index: value}` for features that
+             should be fixed to a particular value during generation.
+         options: Options for initial condition generation. These contain all
+             settings for the standard heuristic initialization from
+             `gen_batch_initial_conditions`. In addition, they contain
+             `frac_random` (the fraction of fully random fantasy points),
+             `num_inner_restarts` and `raw_inner_samples` (the number of random
+             restarts and raw samples for solving the posterior objective
+             maximization problem, respectively) and `eta` (temperature parameter
+             for sampling heuristic from posterior objective maximizers).
+         inequality_constraints: A list of tuples (indices, coefficients, rhs),
+             with each tuple encoding an inequality constraint of the form
+             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
+         equality_constraints: A list of tuples (indices, coefficients, rhs),
+             with each tuple encoding an equality constraint of the form
+             `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
+
+     Returns:
+         A `num_restarts x q' x d` tensor that can be used as initial conditions
+         for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
+         of points (candidate points plus fantasy points).
+
+     Example:
+         >>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
+         >>> bounds = torch.tensor([[0., 0.], [1., 1.]])
+         >>> Xinit = gen_one_shot_hvkg_initial_conditions(
+         >>>     qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
+         >>>     options={"frac_random": 0.25},
+         >>> )
    """
    from botorch.optim.optimize import optimize_acqf

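The `frac_random` split described in the docstring mirrors the arithmetic used in `gen_optimal_input_initial_conditions` above; a small sketch with the example's numbers:

```python
# 25% of the raw starting points are fully random; the remainder are
# derived from learned maximizer sets (or `optimal_inputs` suggestions).
raw_samples, frac_random = 512, 0.25
num_random = round(raw_samples * frac_random)
num_smart = raw_samples - num_random
assert (num_random, num_smart) == (128, 384)
```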
@@ -1136,6 +1220,7 @@ def sample_points_around_best(
    best_pct: float = 5.0,
    subset_sigma: float = 1e-1,
    prob_perturb: float | None = None,
+     best_X: Tensor | None = None,
) -> Tensor | None:
    r"""Find best points and sample nearby points.

@@ -1154,60 +1239,62 @@ def sample_points_around_best(
        An optional `n_discrete_points x d`-dim tensor containing the
        sampled points. This is None if no baseline points are found.
    """
-     X = get_X_baseline(acq_function=acq_function)
-     if X is None:
-         return
-     with torch.no_grad():
-         try:
-             posterior = acq_function.model.posterior(X)
-         except AttributeError:
-             warnings.warn(
-                 "Failed to sample around previous best points.",
-                 BotorchWarning,
-                 stacklevel=3,
-             )
+     if best_X is None:
+         X = get_X_baseline(acq_function=acq_function)
+         if X is None:
            return
-         mean = posterior.mean
-         while mean.ndim > 2:
-             # take average over batch dims
-             mean = mean.mean(dim=0)
-         try:
-             f_pred = acq_function.objective(mean)
-         # Some acquisition functions do not have an objective
-         # and for some acquisition functions the objective is None
-         except (AttributeError, TypeError):
-             f_pred = mean
-         if hasattr(acq_function, "maximize"):
-             # make sure that the optimiztaion direction is set properly
-             if not acq_function.maximize:
-                 f_pred = -f_pred
-         try:
-             # handle constraints for EHVI-based acquisition functions
-             constraints = acq_function.constraints
-             if constraints is not None:
-                 neg_violation = -torch.stack(
-                     [c(mean).clamp_min(0.0) for c in constraints], dim=-1
-                 ).sum(dim=-1)
-                 feas = neg_violation == 0
-                 if feas.any():
-                     f_pred[~feas] = float("-inf")
-                 else:
-                     # set objective equal to negative violation
-                     f_pred = neg_violation
-         except AttributeError:
-             pass
-         if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
-             # multi-objective
-             # find pareto set
-             is_pareto = is_non_dominated(f_pred)
-             best_X = X[is_pareto]
-         else:
-             if f_pred.shape[-1] == 1:
-                 f_pred = f_pred.squeeze(-1)
-             n_best = max(1, round(X.shape[0] * best_pct / 100))
-             # the view() is to ensure that best_idcs is not a scalar tensor
-             best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
-             best_X = X[best_idcs]
+         with torch.no_grad():
+             try:
+                 posterior = acq_function.model.posterior(X)
+             except AttributeError:
+                 warnings.warn(
+                     "Failed to sample around previous best points.",
+                     BotorchWarning,
+                     stacklevel=3,
+                 )
+                 return
+             mean = posterior.mean
+             while mean.ndim > 2:
+                 # take average over batch dims
+                 mean = mean.mean(dim=0)
+             try:
+                 f_pred = acq_function.objective(mean)
+             # Some acquisition functions do not have an objective
+             # and for some acquisition functions the objective is None
+             except (AttributeError, TypeError):
+                 f_pred = mean
+             if hasattr(acq_function, "maximize"):
+                 # make sure that the optimization direction is set properly
+                 if not acq_function.maximize:
+                     f_pred = -f_pred
+             try:
+                 # handle constraints for EHVI-based acquisition functions
+                 constraints = acq_function.constraints
+                 if constraints is not None:
+                     neg_violation = -torch.stack(
+                         [c(mean).clamp_min(0.0) for c in constraints], dim=-1
+                     ).sum(dim=-1)
+                     feas = neg_violation == 0
+                     if feas.any():
+                         f_pred[~feas] = float("-inf")
+                     else:
+                         # set objective equal to negative violation
+                         f_pred = neg_violation
+             except AttributeError:
+                 pass
+             if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
+                 # multi-objective
+                 # find pareto set
+                 is_pareto = is_non_dominated(f_pred)
+                 best_X = X[is_pareto]
+             else:
+                 if f_pred.shape[-1] == 1:
+                     f_pred = f_pred.squeeze(-1)
+                 n_best = max(1, round(X.shape[0] * best_pct / 100))
+                 # the view() is to ensure that best_idcs is not a scalar tensor
+                 best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
+                 best_X = X[best_idcs]
+
    use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
    n_trunc_normal_points = (
        n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
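A short sketch of the behavioral change: when the new `best_X` argument is supplied, the posterior-mean ranking above is skipped entirely and points are perturbed directly around the provided suggestions, which is how `gen_optimal_input_initial_conditions` uses it. `acqf`, `bounds`, and `suggestions` are assumed to exist:

```python
# `suggestions`: a `k x d` tensor of promising inputs (e.g. sampled optima).
X_near = sample_points_around_best(
    acq_function=acqf,     # its model is not queried when `best_X` is given
    n_discrete_points=64,  # number of perturbed points to draw
    sigma=1e-2,            # scale of the perturbations
    bounds=bounds,         # `2 x d` bounds tensor
    best_X=suggestions,    # bypasses the model-based selection of best points
)
```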