Skip to content

Commit

Permalink
feat: Improve robustness of Nelder-Mead algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
pnevyk committed Feb 7, 2022
1 parent 2870944 commit 34a7fca
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 8 deletions.
17 changes: 17 additions & 0 deletions src/core/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,20 @@ where
Ok(norm)
}
}

/// Extension trait for `Result<F::Scalar, Error>`.
pub trait FunctionResultExt<T> {
/// If the result is [`Error::InvalidValue`], `Ok(default)` is returned
/// instead. The original result is returned otherwise.
fn ignore_invalid_value(self, replace_with: T) -> Self;
}

impl<T> FunctionResultExt<T> for Result<T, Error> {
fn ignore_invalid_value(self, replace_with: T) -> Self {
match self {
Ok(value) => Ok(value),
Err(Error::InvalidValue) => Ok(replace_with),
Err(error) => Err(error),
}
}
}
59 changes: 51 additions & 8 deletions src/solver/nelder_mead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ use num_traits::{One, Zero};
use thiserror::Error;

use crate::{
core::{Domain, Error, Function, Optimizer, Problem, Solver, System, VectorDomainExt},
core::{
Domain, Error, Function, FunctionResultExt, Optimizer, Problem, Solver, System,
VectorDomainExt,
},
derivatives::EPSILON_SQRT,
};

Expand Down Expand Up @@ -211,20 +214,51 @@ where
} = self;

let n = f.dim().value();
let inf = convert(f64::INFINITY);

if simplex.is_empty() {
// Simplex initialization.
simplex.push(x.clone_owned());

// It is important to return early on error before the point is
// added to the simplex.
errors.push(f.apply(x)?);
simplex.push(x.clone_owned());

for j in 0..n {
let mut xi = x.clone_owned();
xi[j] = dom.vars()[j].clamp(xi[j] + scale[j]);

errors.push(f.apply(&xi)?);
let error = match f.apply(&xi) {
Ok(error) => error,
// Do not fail when invalid value is encountered during
// building the simplex. Instead, treat it as infinity so
// that it is worse than (or "equal" to) any other error and
// hope that it gets replaced. The exception is the very
// point provided by the caller as it is expected to be
// valid and it is erroneous situation if it is not.
Err(Error::InvalidValue) => inf,
Err(error) => {
// Clear the simplex so the solver is not in invalid
// state.
simplex.clear();
errors.clear();
return Err(error.into());
}
};

errors.push(error);
simplex.push(xi);
}

let inf_error_count = errors.iter().filter(|e| **e == inf).count();

if inf_error_count >= n / 2 {
// The simplex is too degenerate.
simplex.clear();
errors.clear();
return Err(NelderMeadError::Problem(Error::InvalidValue));
}

sort_perm.extend(0..=n);
sort_perm.sort_unstable_by(|a, b| {
errors[*a]
Expand Down Expand Up @@ -264,7 +298,9 @@ where
// Perform one of possible simplex transformations.
reflection.on_line2_mut(centroid, &simplex[sort_perm[n]], reflection_coeff);
let reflection_not_feasible = reflection.project(dom);
let reflection_error = f.apply_eval(reflection, &mut self.fx)?;
let reflection_error = f
.apply_eval(reflection, &mut self.fx)
.ignore_invalid_value(inf)?;

#[allow(clippy::suspicious_else_formatting)]
let (transformation, not_feasible) = if errors[sort_perm[0]] <= reflection_error
Expand All @@ -280,7 +316,9 @@ where
// farther along this direction.
expansion.on_line2_mut(centroid, &simplex[sort_perm[n]], expansion_coeff);
let expansion_not_feasible = expansion.project(dom);
let expansion_error = f.apply_eval(expansion, &mut self.fx)?;
let expansion_error = f
.apply_eval(expansion, &mut self.fx)
.ignore_invalid_value(inf)?;

if expansion_error < reflection_error {
// Expansion indeed help, replace the worst point.
Expand All @@ -305,7 +343,9 @@ where
// Try to perform outer contraction.
contraction.on_line2_mut(centroid, &simplex[sort_perm[n]], outer_contraction_coeff);
let contraction_not_feasible = contraction.project(dom);
let contraction_error = f.apply_eval(contraction, &mut self.fx)?;
let contraction_error = f
.apply_eval(contraction, &mut self.fx)
.ignore_invalid_value(inf)?;

if contraction_error <= reflection_error {
// Use the contracted point instead of the reflected point
Expand All @@ -323,7 +363,9 @@ where
// Try to perform inner contraction.
contraction.on_line2_mut(centroid, &simplex[sort_perm[n]], inner_contraction_coeff);
let contraction_not_feasible = contraction.project(dom);
let contraction_error = f.apply_eval(contraction, &mut self.fx)?;
let contraction_error = f
.apply_eval(contraction, &mut self.fx)
.ignore_invalid_value(inf)?;

if contraction_error <= errors[sort_perm[n]] {
// The contracted point is better than the worst point.
Expand All @@ -348,7 +390,8 @@ where
for i in 1..=n {
let xi = &mut simplex[sort_perm[i]];
xi.on_line_mut(contraction, shrink_coeff);
errors[sort_perm[i]] = f.apply_eval(xi, &mut self.fx)?;
errors[sort_perm[i]] =
f.apply_eval(xi, &mut self.fx).ignore_invalid_value(inf)?;
}

(Transformation::Shrinkage, false)
Expand Down

0 comments on commit 34a7fca

Please sign in to comment.