Performance & runtime improvements to info-theoretic acquisition functions (2/N) - AcqOpt initializer #2751
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

@@            Coverage Diff            @@
##              main    #2751    +/-   ##
=========================================
  Coverage    99.99%   99.99%
=========================================
  Files          203      203
  Lines        18685    18726      +41
=========================================
+ Hits         18684    18725      +41
  Misses           1        1

View full report in Codecov by Sentry.
@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
botorch/optim/initializers.py
    options: dict[str, bool | float | int] | None = None,
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
):
    options = options or {}
    device = bounds.device
    if not hasattr(acq_function, "optimal_inputs"):
        raise AttributeError(
            "gen_optimal_input_initial_conditions can only be used with "
            "an AcquisitionFunction that has an optimal_inputs attribute."
        )
    frac_random: float = options.get("frac_random", 0.0)
    if not 0 <= frac_random <= 1:
        raise ValueError(
            f"frac_random must take on values in (0,1). Value: {frac_random}"
        )

    batch_limit = options.get("batch_limit")
    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
    num_random = round(raw_samples * frac_random)
    if num_random > 0:
        X_rnd = sample_q_batches_from_polytope(
            n=num_random,
            q=q,
            bounds=bounds,
            n_burnin=options.get("n_burnin", 10000),
            n_thinning=options.get("n_thinning", 32),
            equality_constraints=equality_constraints,
            inequality_constraints=inequality_constraints,
        )
        X = torch.cat((X, X_rnd))

    if num_random < raw_samples:
        X_perturbed = sample_points_around_best(
            acq_function=acq_function,
            n_discrete_points=q * (raw_samples - num_random),
            sigma=options.get("sample_around_best_sigma", 1e-2),
            bounds=bounds,
            best_X=suggestions,
        )
        X_perturbed = X_perturbed.view(
            raw_samples - num_random, q, bounds.shape[-1]
        ).cpu()
        X = torch.cat((X, X_perturbed))

    if options.get("sample_around_best", False):
        X_best = sample_points_around_best(
            acq_function=acq_function,
            n_discrete_points=q * raw_samples,
            sigma=options.get("sample_around_best_sigma", 1e-2),
            bounds=bounds,
        )
        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
        X = torch.cat((X, X_best))

    with torch.no_grad():
        if batch_limit is None:
            batch_limit = X.shape[0]
        # Evaluate the acquisition function on `X_rnd` using `batch_limit`
        # sized chunks.
        acq_vals = torch.cat(
            [
                acq_function(x_.to(device=device)).cpu()
                for x_ in X.split(split_size=batch_limit, dim=0)
            ],
            dim=0,
        )
        idx = boltzmann_sample(
            function_values=acq_vals,
            num_samples=num_restarts,
            eta=options.get("eta", 2.0),
By passing these individually rather than as a dict, we help static analysis tools (and people) see that the code isn't obviously wrong, and prevent unused options from being passed and silently dropped. That can be especially helpful in guarding against typos or when refactoring. You could then update the call sites to pass **options instead of options -- personally I'd pass them individually everywhere, but it may be a matter of taste.
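For illustration, a minimal standalone sketch of the difference (the function name and option values here are made up; only the option names mirror the snippet above):

# Hypothetical sketch: options as explicit keyword arguments rather than a dict.
def gen_initial_conditions_sketch(
    raw_samples: int,
    frac_random: float = 0.0,
    batch_limit: int | None = None,
    sample_around_best_sigma: float = 1e-2,
    eta: float = 2.0,
) -> dict[str, float | int | None]:
    # A misspelled keyword (e.g. frac_randm=...) now raises a TypeError
    # instead of being silently dropped from an options dict.
    num_random = round(raw_samples * frac_random)
    return {"num_random": num_random, "batch_limit": batch_limit, "eta": eta}

# A call site that previously assembled an options dict can unpack it ...
options = {"frac_random": 0.25, "eta": 2.0}
print(gen_initial_conditions_sketch(raw_samples=128, **options))

# ... or, as suggested, pass each keyword explicitly.
print(gen_initial_conditions_sketch(raw_samples=128, frac_random=0.25, eta=2.0))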
Suggested change (the rest of the function body is repeated verbatim in the suggestion and is omitted here):

-    options: dict[str, bool | float | int] | None = None,
+    frac_random: float = 0.0,
+    batch_limit: int | None = None,
+    n_burnin: int = 10000,
+    n_thinning: int = 32,
+    sample_around_best: bool = False,
+    sample_around_best_sigma: float = 1e-2,
+    eta: float = 2.0,
     inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
     equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
 ):
Sure, I'm okay doing it this way! It just seems that changing ic_generator alone wouldn't suffice if one (for some reason) wanted to switch between them, since they would be inconsistent?
        X = torch.cat((X, X_rnd))

    if num_random < raw_samples:
        X_perturbed = sample_points_around_best(
It's a bit nonintuitive that we do this even when sample_around_best is False, no?
Possibly! My thought was: since it is not actually sampling around the incumbent but around the sampled optima, I could keep it and re-use its logic. I tried to mimic the KG logic for it, and that uses frac_random for a similar reason.
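To make the split concrete, here is a small numeric sketch of how frac_random divides the raw samples in the initializer above (the specific numbers are illustrative):

# frac_random controls how many of the raw samples are drawn uniformly from the
# feasible polytope; the remainder are perturbations around the sampled optima
# (acq_function.optimal_inputs), mirroring the KG-style frac_random option.
raw_samples = 128
frac_random = 0.25

num_random = round(raw_samples * frac_random)   # 32 random q-batches
num_perturbed = raw_samples - num_random        # 96 perturbed around the optima
print(num_random, num_perturbed)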
Thanks for this! I'm looking forward to seeing the plots.
I have not quite figured out why the test coverage is not there, since I thought I addressed it today. I will also figure out the conflicts ASAP!
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
):
    r"""Generate a batch of initial conditions for random-restart optimziation of
r"""Generate a batch of initial conditions for random-restart optimziation of | |
r"""Generate a batch of initial conditions for random-restart optimization of |
    fraction of initial samples (by default: 100%) are drawn as perturbations around
    `acq.optimal_inputs`. On average, this drastically decreases the runtime of
    acquisition function optimization and yields higher-valued candidates by acquisition
    function value. See https://github.com/pytorch/botorch/pull/2751 for more info.
Suggested change:
-    function value. See https://github.com/pytorch/botorch/pull/2751 for more info.
+    function value. See https://github.com/pytorch/botorch/pull/2751 for more info, relative to [...].
        best_X = X[is_pareto]
    else:
        if f_pred.shape[-1] == 1:
            f_pred = f_pred.squeeze(-1)
        n_best = max(1, round(X.shape[0] * best_pct / 100))
        # the view() is to ensure that best_idcs is not a scalar tensor
        best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
        best_X = X[best_idcs]
best_X needs to be on the same device as bounds. I'm not sure which branch it's coming from, but I'm seeing a failure when test_gen_optimal_input_initial_conditions is run with tensors on CUDA.
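One possible fix, sketched below (not necessarily how the PR addresses it), is to move the selected points onto the device (and dtype) of bounds before they are used:

import torch

# Illustrative tensors: in the failing CUDA test, `bounds` lives on the GPU
# while `best_X` may come back on the CPU.
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
best_X = torch.rand(4, 2)

# Align device and dtype with `bounds` before concatenating with other points.
best_X = best_X.to(device=bounds.device, dtype=bounds.dtype)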
@@ -576,12 +578,20 @@ def get_optimal_samples(
        sample_transform = None

    paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
    suggested_points = prune_inferior_points(
        model=model,
        X=model.train_inputs[0],
This doesn't work if the model is a ModelListGP, where train_inputs is a list of tuples of tensors rather than a tuple of tensors: https://github.com/cornellius-gp/gpytorch/blob/b017b9c3fe4de526f7a2243ce12ce2305862c90b/gpytorch/models/model_list.py#L83-L86
Any thoughts on what we should do there?
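One possibility, sketched below, is to normalize the two cases before calling prune_inferior_points; it assumes that all submodels of a ModelListGP were fit on the same training inputs, which would need to be verified:

from botorch.models import ModelListGP

def _training_inputs(model):
    # For a ModelListGP, train_inputs is a list (one tuple per submodel); fall
    # back to the first submodel's inputs under the shared-inputs assumption.
    if isinstance(model, ModelListGP):
        return model.models[0].train_inputs[0]
    return model.train_inputs[0]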
A series of improvements aimed at the performance of PES & JES, as well as their multi-objective counterparts.
This PR adds an initializer for acquisition function optimization, which drastically reduces the number of required forward passes (from ~150-250 to ~25) by providing suggestions close to the sampled optima obtained during acquisition function construction.
@esantorella
[figure]
Moreover, better acquisition function values are found (PR 1's BO loop, but both acquisition optimizations are run in parallel), and it is a lot faster:
[figure]
This does not always improve performance, however (PR 1 is more local due to sample_around_best dominating candidate generation, which is generally good):
[figure]
Lastly, a nice comparison to LogNEI with the introduced mods:
[figure]
Moreover, they are now much closer in terms of runtime:
[figure]
And here is the allocation between posterior sampling time and acquisition optimization time:
[figure]
So apart from Michalewicz, it does pretty well now!
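For reference, a rough end-to-end usage sketch (the toy data and model are illustrative, and the import path for the new initializer assumes it lands in botorch/optim/initializers.py as in this PR):

import torch
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.utils import get_optimal_samples
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.optim.initializers import gen_optimal_input_initial_conditions

# Illustrative toy problem (not from the PR).
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

# Sampled optima are computed during acquisition construction and re-used
# by the initializer to seed the restarts.
optimal_inputs, optimal_outputs = get_optimal_samples(
    model=model, bounds=bounds, num_optima=16
)
acq = qJointEntropySearch(
    model=model, optimal_inputs=optimal_inputs, optimal_outputs=optimal_outputs
)

candidate, value = optimize_acqf(
    acq_function=acq,
    bounds=bounds,
    q=1,
    num_restarts=8,
    raw_samples=128,
    ic_generator=gen_optimal_input_initial_conditions,
)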
Related PRs
Previous one