Skip to content

Commit

Permalink
feat: Avoid unnecessary function evaluation in Nelder-Mead
Browse files Browse the repository at this point in the history
  • Loading branch information
pnevyk committed Feb 7, 2022
1 parent 34a7fca commit fb8e90e
Showing 1 changed file with 48 additions and 24 deletions.
72 changes: 48 additions & 24 deletions src/solver/nelder_mead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ where
{
options: NelderMeadOptions<F>,
fx: OVector<F::Scalar, F::Dim>,
fx_best: OVector<F::Scalar, F::Dim>,
scale: OVector<F::Scalar, F::Dim>,
centroid: OVector<F::Scalar, F::Dim>,
reflection: OVector<F::Scalar, F::Dim>,
Expand Down Expand Up @@ -155,6 +156,7 @@ where
Self {
options,
fx: OVector::zeros_generic(f.dim(), U1::name()),
fx_best: OVector::zeros_generic(f.dim(), U1::name()),
scale: OVector::from_iterator_generic(f.dim(), U1::name(), scale_iter),
centroid: OVector::zeros_generic(f.dim(), U1::name()),
reflection: OVector::zeros_generic(f.dim(), U1::name()),
Expand Down Expand Up @@ -202,6 +204,8 @@ where
} = self.options;

let Self {
fx,
fx_best,
scale,
simplex,
errors,
Expand All @@ -221,7 +225,8 @@ where

// It is important to return early on error before the point is
// added to the simplex.
errors.push(f.apply(x)?);
let mut error_best = f.apply(x)?;
errors.push(error_best);
simplex.push(x.clone_owned());

for j in 0..n {
Expand All @@ -246,6 +251,11 @@ where
}
};

if error < error_best {
fx_best.copy_from(fx);
error_best = error;
}

errors.push(error);
simplex.push(xi);
}
Expand All @@ -260,7 +270,9 @@ where
}

sort_perm.extend(0..=n);
sort_perm.sort_unstable_by(|a, b| {
// Stable sort is important for sort_perm[0] being consistent with
// fx_best.
sort_perm.sort_by(|a, b| {
errors[*a]
.partial_cmp(&errors[*b])
.unwrap_or(std::cmp::Ordering::Equal)
Expand Down Expand Up @@ -298,30 +310,30 @@ where
// Perform one of possible simplex transformations.
reflection.on_line2_mut(centroid, &simplex[sort_perm[n]], reflection_coeff);
let reflection_not_feasible = reflection.project(dom);
let reflection_error = f
.apply_eval(reflection, &mut self.fx)
.ignore_invalid_value(inf)?;
let reflection_error = f.apply_eval(reflection, fx).ignore_invalid_value(inf)?;

#[allow(clippy::suspicious_else_formatting)]
let (transformation, not_feasible) = if errors[sort_perm[0]] <= reflection_error
&& reflection_error < errors[sort_perm[n - 1]]
{
// Reflected point is neither best nor worst in te new simplex. Just
// replace the worst point.
// Reflected point is neither best nor worst in the new simplex.
// Just replace the worst point.
simplex[sort_perm[n]].copy_from(reflection);
errors[sort_perm[n]] = reflection_error;
(Transformation::Reflection, reflection_not_feasible)
} else if reflection_error < errors[sort_perm[0]] {
fx_best.copy_from(fx);

// Reflected point is better than the current best. Try to go
// farther along this direction.
expansion.on_line2_mut(centroid, &simplex[sort_perm[n]], expansion_coeff);
let expansion_not_feasible = expansion.project(dom);
let expansion_error = f
.apply_eval(expansion, &mut self.fx)
.ignore_invalid_value(inf)?;
let expansion_error = f.apply_eval(expansion, fx).ignore_invalid_value(inf)?;

if expansion_error < reflection_error {
// Expansion indeed help, replace the worst point.
fx_best.copy_from(fx);

// Expansion indeed helped, replace the worst point.
simplex[sort_perm[n]].copy_from(expansion);
errors[sort_perm[n]] = expansion_error;
(Transformation::Expansion, expansion_not_feasible)
Expand All @@ -343,11 +355,13 @@ where
// Try to perform outer contraction.
contraction.on_line2_mut(centroid, &simplex[sort_perm[n]], outer_contraction_coeff);
let contraction_not_feasible = contraction.project(dom);
let contraction_error = f
.apply_eval(contraction, &mut self.fx)
.ignore_invalid_value(inf)?;
let contraction_error = f.apply_eval(contraction, fx).ignore_invalid_value(inf)?;

if contraction_error <= reflection_error {
if contraction_error < errors[sort_perm[0]] {
fx_best.copy_from(fx);
}

// Use the contracted point instead of the reflected point
// because it's better.
simplex[sort_perm[n]].copy_from(contraction);
Expand All @@ -363,11 +377,13 @@ where
// Try to perform inner contraction.
contraction.on_line2_mut(centroid, &simplex[sort_perm[n]], inner_contraction_coeff);
let contraction_not_feasible = contraction.project(dom);
let contraction_error = f
.apply_eval(contraction, &mut self.fx)
.ignore_invalid_value(inf)?;
let contraction_error = f.apply_eval(contraction, fx).ignore_invalid_value(inf)?;

if contraction_error <= errors[sort_perm[n]] {
if contraction_error < errors[sort_perm[0]] {
fx_best.copy_from(fx);
}

// The contracted point is better than the worst point.
simplex[sort_perm[n]].copy_from(contraction);
errors[sort_perm[n]] = contraction_error;
Expand All @@ -387,20 +403,28 @@ where
// Shrink the simplex towards the best point.

contraction.copy_from(&simplex[sort_perm[0]]);
let mut error_best = errors[sort_perm[0]];

for i in 1..=n {
let xi = &mut simplex[sort_perm[i]];
xi.on_line_mut(contraction, shrink_coeff);
errors[sort_perm[i]] =
f.apply_eval(xi, &mut self.fx).ignore_invalid_value(inf)?;
let error = f.apply_eval(xi, fx).ignore_invalid_value(inf)?;
errors[sort_perm[i]] = error;

if error < error_best {
fx_best.copy_from(fx);
error_best = error;
}
}

(Transformation::Shrinkage, false)
}
}
};

// Establish the ordering of simplex points.
sort_perm.sort_unstable_by(|a, b| {
// Establish the ordering of simplex points. Stable sort is important
// for sort_perm[0] being consistent with fx_best.
sort_perm.sort_by(|a, b| {
errors[*a]
.partial_cmp(&errors[*b])
.unwrap_or(std::cmp::Ordering::Equal)
Expand All @@ -416,7 +440,7 @@ where

// Return the best simplex point.
x.copy_from(&simplex[sort_perm[0]]);
let error = f.apply_eval(x, &mut self.fx)?;
// fx corresponding to x is stored in `self.fx_best`.

if transformation == Transformation::Shrinkage || not_feasible {
// Check whether the simplex collapsed or not. It can happen only
Expand All @@ -436,7 +460,7 @@ where
}
}

Ok(error)
Ok(errors[sort_perm[0]])
}
}

Expand Down Expand Up @@ -483,7 +507,7 @@ where
Sfx: StorageMut<<F>::Scalar, <F>::Dim>,
{
self.next_inner(f, dom, x).map(|_| {
fx.copy_from(&self.fx);
fx.copy_from(&self.fx_best);
})
}
}
Expand Down

0 comments on commit fb8e90e

Please sign in to comment.