Skip to content

Commit

Permalink
feat!: change the interoperability between System and Function traits
Browse files Browse the repository at this point in the history
  • Loading branch information
pnevyk committed Nov 12, 2023
1 parent 4f8f2ce commit db104b5
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 130 deletions.
66 changes: 13 additions & 53 deletions src/core/function.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use nalgebra::{
allocator::Allocator, storage::Storage, storage::StorageMut, DefaultAllocator, Dynamic,
IsContiguous, Vector,
};
use num_traits::Zero;
use nalgebra::{storage::Storage, Dynamic, IsContiguous, Vector};

use super::{
base::{Problem, ProblemError},
Expand Down Expand Up @@ -61,54 +57,6 @@ pub trait Function: Problem {
) -> Result<Self::Scalar, ProblemError>
where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous;

/// Calculate the norm of residuals of the system given values of the
/// variable for cases when the function is actually a system of equations.
///
/// The optimizers should prefer calling this function because the
/// implementation for systems reuse `fx` for calculating the residuals and
/// do not make an unnecessary allocation for it.
fn apply_eval<Sx, Sfx>(
&self,
x: &Vector<Self::Scalar, Dynamic, Sx>,
fx: &mut Vector<Self::Scalar, Dynamic, Sfx>,
) -> Result<Self::Scalar, ProblemError>
where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous,
Sfx: StorageMut<Self::Scalar, Dynamic>,
{
let norm = self.apply(x)?;
fx.fill(Self::Scalar::zero());
fx[0] = norm;
Ok(norm)
}
}

impl<F: System> Function for F
where
DefaultAllocator: Allocator<F::Scalar, Dynamic>,
{
fn apply<Sx>(&self, x: &Vector<Self::Scalar, Dynamic, Sx>) -> Result<Self::Scalar, ProblemError>
where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous,
{
let mut fx = x.clone_owned();
self.apply_eval(x, &mut fx)
}

fn apply_eval<Sx, Sfx>(
&self,
x: &Vector<Self::Scalar, Dynamic, Sx>,
fx: &mut Vector<Self::Scalar, Dynamic, Sfx>,
) -> Result<Self::Scalar, ProblemError>
where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous,
Sfx: StorageMut<Self::Scalar, Dynamic>,
{
self.eval(x, fx)?;
let norm = fx.norm();
Ok(norm)
}
}

/// Extension trait for `Result<F::Scalar, Error>`.
Expand All @@ -127,3 +75,15 @@ impl<T> FunctionResultExt<T> for Result<T, ProblemError> {
}
}
}

impl<F> Function for F
where
F: System,
{
fn apply<Sx>(&self, x: &Vector<Self::Scalar, Dynamic, Sx>) -> Result<Self::Scalar, ProblemError>
where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous,
{
self.norm(x)
}
}
14 changes: 14 additions & 0 deletions src/core/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ pub trait System: Problem {
where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous,
Sfx: StorageMut<Self::Scalar, Dynamic>;

/// Calculate the residuals vector norm.
///
/// The default implementation allocates a temporary vector for the
/// residuals on every call. If you plan to solve the system by an
/// optimizer, consider overriding the default implementation.
fn norm<Sx>(&self, x: &Vector<Self::Scalar, Dynamic, Sx>) -> Result<Self::Scalar, ProblemError>
where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous,
{
let mut fx = x.clone_owned();
self.eval(x, &mut fx)?;
Ok(fx.norm())
}
}

/// A wrapper type for systems that implements a standard mechanism for
Expand Down
40 changes: 10 additions & 30 deletions src/solver/nelder_mead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,6 @@ impl<F: Problem> NelderMeadOptions<F> {
/// Nelder-Mead solver. See [module](self) documentation for more details.
pub struct NelderMead<F: Problem> {
options: NelderMeadOptions<F>,
fx: OVector<F::Scalar, Dynamic>,
fx_best: OVector<F::Scalar, Dynamic>,
scale: OVector<F::Scalar, Dynamic>,
centroid: OVector<F::Scalar, Dynamic>,
reflection: OVector<F::Scalar, Dynamic>,
Expand Down Expand Up @@ -163,8 +161,6 @@ impl<F: Problem> NelderMead<F> {

Self {
options,
fx: OVector::zeros_generic(dim, U1::name()),
fx_best: OVector::zeros_generic(dim, U1::name()),
scale,
centroid: OVector::zeros_generic(dim, U1::name()),
reflection: OVector::zeros_generic(dim, U1::name()),
Expand Down Expand Up @@ -216,8 +212,6 @@ impl<F: Function> NelderMead<F> {
} = self.options;

let Self {
fx,
fx_best,
scale,
simplex,
errors,
Expand All @@ -237,8 +231,7 @@ impl<F: Function> NelderMead<F> {

// It is important to return early on error before the point is
// added to the simplex.
let mut error_best = f.apply_eval(x, fx)?;
fx_best.copy_from(fx);
let mut error_best = f.apply(x)?;
errors.push(error_best);
simplex.push(x.clone_owned());

Expand All @@ -247,7 +240,7 @@ impl<F: Function> NelderMead<F> {
xi[j] += scale[j];
dom.project_in(&mut xi, j);

let error = match f.apply_eval(&xi, fx) {
let error = match f.apply(&xi) {
Ok(error) => error,
// Do not fail when invalid value is encountered during
// building the simplex. Instead, treat it as infinity so
Expand All @@ -267,7 +260,6 @@ impl<F: Function> NelderMead<F> {
};

if error < error_best {
fx_best.copy_from(fx);
error_best = error;
}

Expand Down Expand Up @@ -332,7 +324,7 @@ impl<F: Function> NelderMead<F> {
// Perform one of possible simplex transformations.
reflection.on_line2_mut(centroid, &simplex[sort_perm[n]], reflection_coeff);
let reflection_not_feasible = dom.project(reflection);
let reflection_error = f.apply_eval(reflection, fx).ignore_invalid_value(inf)?;
let reflection_error = f.apply(reflection).ignore_invalid_value(inf)?;

#[allow(clippy::suspicious_else_formatting)]
let (transformation, not_feasible) = if errors[sort_perm[0]] <= reflection_error
Expand All @@ -344,17 +336,13 @@ impl<F: Function> NelderMead<F> {
errors[sort_perm[n]] = reflection_error;
(Transformation::Reflection, reflection_not_feasible)
} else if reflection_error < errors[sort_perm[0]] {
fx_best.copy_from(fx);

// Reflected point is better than the current best. Try to go
// farther along this direction.
expansion.on_line2_mut(centroid, &simplex[sort_perm[n]], expansion_coeff);
let expansion_not_feasible = dom.project(expansion);
let expansion_error = f.apply_eval(expansion, fx).ignore_invalid_value(inf)?;
let expansion_error = f.apply(expansion).ignore_invalid_value(inf)?;

if expansion_error < reflection_error {
fx_best.copy_from(fx);

// Expansion indeed helped, replace the worst point.
simplex[sort_perm[n]].copy_from(expansion);
errors[sort_perm[n]] = expansion_error;
Expand All @@ -377,13 +365,9 @@ impl<F: Function> NelderMead<F> {
// Try to perform outer contraction.
contraction.on_line2_mut(centroid, &simplex[sort_perm[n]], outer_contraction_coeff);
let contraction_not_feasible = dom.project(contraction);
let contraction_error = f.apply_eval(contraction, fx).ignore_invalid_value(inf)?;
let contraction_error = f.apply(contraction).ignore_invalid_value(inf)?;

if contraction_error <= reflection_error {
if contraction_error < errors[sort_perm[0]] {
fx_best.copy_from(fx);
}

// Use the contracted point instead of the reflected point
// because it's better.
simplex[sort_perm[n]].copy_from(contraction);
Expand All @@ -399,13 +383,9 @@ impl<F: Function> NelderMead<F> {
// Try to perform inner contraction.
contraction.on_line2_mut(centroid, &simplex[sort_perm[n]], inner_contraction_coeff);
let contraction_not_feasible = dom.project(contraction);
let contraction_error = f.apply_eval(contraction, fx).ignore_invalid_value(inf)?;
let contraction_error = f.apply(contraction).ignore_invalid_value(inf)?;

if contraction_error <= errors[sort_perm[n]] {
if contraction_error < errors[sort_perm[0]] {
fx_best.copy_from(fx);
}

// The contracted point is better than the worst point.
simplex[sort_perm[n]].copy_from(contraction);
errors[sort_perm[n]] = contraction_error;
Expand All @@ -430,11 +410,10 @@ impl<F: Function> NelderMead<F> {
for i in 1..=n {
let xi = &mut simplex[sort_perm[i]];
xi.on_line_mut(contraction, shrink_coeff);
let error = f.apply_eval(xi, fx).ignore_invalid_value(inf)?;
let error = f.apply(xi).ignore_invalid_value(inf)?;
errors[sort_perm[i]] = error;

if error < error_best {
fx_best.copy_from(fx);
error_best = error;
}
}
Expand Down Expand Up @@ -525,8 +504,9 @@ impl<F: System + Function> Solver<F> for NelderMead<F> {
Sx: StorageMut<F::Scalar, Dynamic> + IsContiguous,
Sfx: StorageMut<F::Scalar, Dynamic>,
{
self.next_inner(f, dom, x)
.map(|_| fx.copy_from(&self.fx_best))
self.next_inner(f, dom, x)?;
f.eval(x, fx)?;
Ok(())
}
}

Expand Down
Loading

0 comments on commit db104b5

Please sign in to comment.