Skip to content

Commit

Permalink
feat!: make systems and functions infallible
Browse files Browse the repository at this point in the history
  • Loading branch information
pnevyk committed Nov 12, 2023
1 parent db104b5 commit 8b6c271
Show file tree
Hide file tree
Showing 15 changed files with 141 additions and 298 deletions.
5 changes: 1 addition & 4 deletions examples/rosenbrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,12 @@ impl System for Rosenbrock {
&self,
x: &na::Vector<Self::Scalar, Dynamic, Sx>,
fx: &mut na::Vector<Self::Scalar, Dynamic, Sfx>,
) -> Result<(), ProblemError>
where
) where
Sx: na::storage::Storage<Self::Scalar, Dynamic> + IsContiguous,
Sfx: na::storage::StorageMut<Self::Scalar, Dynamic>,
{
fx[0] = (self.a - x[0]).powi(2);
fx[1] = self.b * (x[1] - x[0].powi(2)).powi(2);

Ok(())
}
}

Expand Down
6 changes: 2 additions & 4 deletions gomez-bench/benches/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,8 @@ impl<F: TestSystem<Scalar = f64>> GslFunction for GslFunctionWrapper<F> {
na::U1::name(),
);

match self.f.eval(&x, &mut fx) {
Ok(_) => GslStatus::ok(),
Err(_) => GslStatus::err(GslError::BadFunc),
}
self.f.eval(&x, &mut fx);
GslStatus::ok()
}

fn init(&self) -> GslVec {
Expand Down
30 changes: 9 additions & 21 deletions src/analysis/initial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,12 @@ use std::marker::PhantomData;
use nalgebra::{
convert, storage::StorageMut, ComplexField, DimName, Dynamic, IsContiguous, OVector, Vector, U1,
};
use thiserror::Error;

use crate::{
core::{Domain, Problem, ProblemError, System},
derivatives::{Jacobian, JacobianError, EPSILON_SQRT},
core::{Domain, Problem, System},
derivatives::{Jacobian, EPSILON_SQRT},
};

/// Error returned from [`InitialGuessAnalysis`] solver.
#[derive(Debug, Error)]
pub enum InitialGuessAnalysisError {
/// Error that occurred when evaluating the system.
#[error("{0}")]
System(#[from] ProblemError),
/// Error that occurred when computing the Jacobian matrix.
#[error("{0}")]
Jacobian(#[from] JacobianError),
}

/// Initial guesses analyzer. See [module](self) documentation for more details.
pub struct InitialGuessAnalysis<F: Problem> {
nonlinear: Vec<usize>,
Expand All @@ -43,7 +31,7 @@ impl<F: System> InitialGuessAnalysis<F> {
dom: &Domain<F::Scalar>,
x: &mut Vector<F::Scalar, Dynamic, Sx>,
fx: &mut Vector<F::Scalar, Dynamic, Sfx>,
) -> Result<Self, InitialGuessAnalysisError>
) -> Self
where
Sx: StorageMut<F::Scalar, Dynamic> + IsContiguous,
Sfx: StorageMut<F::Scalar, Dynamic>,
Expand All @@ -55,8 +43,8 @@ impl<F: System> InitialGuessAnalysis<F> {
.unwrap_or_else(|| OVector::from_element_generic(dim, U1::name(), convert(1.0)));

// Compute F'(x) in the initial point.
f.eval(x, fx)?;
let jac1 = Jacobian::new(f, x, &scale, fx)?;
f.eval(x, fx);
let jac1 = Jacobian::new(f, x, &scale, fx);

// Compute Newton step.
let mut p = fx.clone_owned();
Expand All @@ -70,8 +58,8 @@ impl<F: System> InitialGuessAnalysis<F> {
*x += p;

// Compute F'(x) after one Newton step.
f.eval(x, fx)?;
let jac2 = Jacobian::new(f, x, &scale, fx)?;
f.eval(x, fx);
let jac2 = Jacobian::new(f, x, &scale, fx);

// Linear variables have no effect on the Jacobian matrix. They can be
// recognized by observing no change in corresponding columns (i.e.,
Expand All @@ -92,10 +80,10 @@ impl<F: System> InitialGuessAnalysis<F> {
.map(|(col, _)| col)
.collect();

Ok(Self {
Self {
nonlinear,
ty: PhantomData,
})
}
}

/// Returns indices of variables that have influence on the Jacobian matrix
Expand Down
17 changes: 0 additions & 17 deletions src/core/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use nalgebra::RealField;
use thiserror::Error;

use super::domain::Domain;

Expand All @@ -13,19 +12,3 @@ pub trait Problem {
/// system is unconstrained.
fn domain(&self) -> Domain<Self::Scalar>;
}

/// Error encountered while applying variables to the problem.
#[derive(Debug, Error)]
pub enum ProblemError {
/// The number of variables does not match the dimensionality of the problem
/// domain.
#[error("invalid dimensionality")]
InvalidDimensionality,
/// An invalid value (NaN, positive or negative infinity) of a residual or
/// the function value occurred.
#[error("invalid value encountered")]
InvalidValue,
/// A custom error specific to the system or function.
#[error("{0}")]
Custom(Box<dyn std::error::Error>),
}
36 changes: 5 additions & 31 deletions src/core/function.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use nalgebra::{storage::Storage, Dynamic, IsContiguous, Vector};

use super::{
base::{Problem, ProblemError},
system::System,
};
use super::{base::Problem, system::System};

/// The trait for defining functions.
///
Expand Down Expand Up @@ -37,50 +34,27 @@ use super::{
///
/// impl Function for Rosenbrock {
/// // Apply trial values of variables to the function.
/// fn apply<Sx>(
/// &self,
/// x: &na::Vector<Self::Scalar, Dynamic, Sx>,
/// ) -> Result<Self::Scalar, ProblemError>
/// fn apply<Sx>(&self, x: &na::Vector<Self::Scalar, Dynamic, Sx>) -> Self::Scalar
/// where
/// Sx: na::storage::Storage<Self::Scalar, Dynamic> + IsContiguous,
/// {
/// // Compute the function value.
/// Ok((self.a - x[0]).powi(2) + self.b * (x[1] - x[0].powi(2)).powi(2))
/// (self.a - x[0]).powi(2) + self.b * (x[1] - x[0].powi(2)).powi(2)
/// }
/// }
/// ```
pub trait Function: Problem {
/// Calculate the function value given values of the variables.
fn apply<Sx>(
&self,
x: &Vector<Self::Scalar, Dynamic, Sx>,
) -> Result<Self::Scalar, ProblemError>
fn apply<Sx>(&self, x: &Vector<Self::Scalar, Dynamic, Sx>) -> Self::Scalar
where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous;
}

/// Extension trait for `Result<F::Scalar, Error>`.
pub trait FunctionResultExt<T> {
/// If the result is [`ProblemError::InvalidValue`], `Ok(default)` is
/// returned instead. The original result is returned otherwise.
fn ignore_invalid_value(self, replace_with: T) -> Self;
}

impl<T> FunctionResultExt<T> for Result<T, ProblemError> {
fn ignore_invalid_value(self, replace_with: T) -> Self {
match self {
Ok(value) => Ok(value),
Err(ProblemError::InvalidValue) => Ok(replace_with),
Err(error) => Err(error),
}
}
}

impl<F> Function for F
where
F: System,
{
fn apply<Sx>(&self, x: &Vector<Self::Scalar, Dynamic, Sx>) -> Result<Self::Scalar, ProblemError>
fn apply<Sx>(&self, x: &Vector<Self::Scalar, Dynamic, Sx>) -> Self::Scalar
where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous,
{
Expand Down
9 changes: 3 additions & 6 deletions src/core/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use super::{domain::Domain, function::Function};
/// Standard: Distribution<F::Scalar>,
/// {
/// const NAME: &'static str = "Random";
/// type Error = ProblemError;
/// type Error = std::convert::Infallible;
///
/// fn next<Sx>(
/// &mut self,
Expand All @@ -57,8 +57,7 @@ use super::{domain::Domain, function::Function};
/// dom.sample(x, &mut self.rng);
///
/// // We must compute the value.
/// let value = f.apply(x)?;
/// Ok(value)
/// Ok(f.apply(x))
/// }
/// }
/// ```
Expand All @@ -67,9 +66,7 @@ pub trait Optimizer<F: Function> {
const NAME: &'static str;

/// Error type of the iteration. Represents an invalid operation during
/// computing the next step. It is usual that one of the error kinds is
/// propagation of the [`ProblemError`](super::ProblemError) from the
/// function.
/// computing the next step.
type Error;

/// Computes the next step in the optimization process.
Expand Down
8 changes: 3 additions & 5 deletions src/core/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use super::{domain::Domain, system::System};
/// Standard: Distribution<F::Scalar>,
/// {
/// const NAME: &'static str = "Random";
/// type Error = ProblemError;
/// type Error = std::convert::Infallible;
///
/// fn next<Sx, Sfx>(
/// &mut self,
Expand All @@ -59,7 +59,7 @@ use super::{domain::Domain, system::System};
/// dom.sample(x, &mut self.rng);
///
/// // We must compute the residuals.
/// f.eval(x, fx)?;
/// f.eval(x, fx);
///
/// Ok(())
/// }
Expand All @@ -70,9 +70,7 @@ pub trait Solver<F: System> {
const NAME: &'static str;

/// Error type of the iteration. Represents an invalid operation during
/// computing the next step. It is usual that one of the error kinds is
/// propagation of the [`ProblemError`](super::ProblemError) from the
/// system.
/// computing the next step.
type Error;

/// Computes the next step in the solving process.
Expand Down
22 changes: 7 additions & 15 deletions src/core/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ use nalgebra::{
DefaultAllocator, Dynamic, IsContiguous, OVector, Vector,
};

use super::{
base::{Problem, ProblemError},
domain::Domain,
};
use super::{base::Problem, domain::Domain};

/// The trait for defining equations systems.
///
Expand Down Expand Up @@ -44,16 +41,13 @@ use super::{
/// &self,
/// x: &na::Vector<Self::Scalar, Dynamic, Sx>,
/// fx: &mut na::Vector<Self::Scalar, Dynamic, Sfx>,
/// ) -> Result<(), ProblemError>
/// where
/// ) where
/// Sx: na::storage::Storage<Self::Scalar, Dynamic> + IsContiguous,
/// Sfx: na::storage::StorageMut<Self::Scalar, Dynamic>,
/// {
/// // Compute the residuals of all equations.
/// fx[0] = (self.a - x[0]).powi(2);
/// fx[1] = self.b * (x[1] - x[0].powi(2)).powi(2);
///
/// Ok(())
/// }
/// }
/// ```
Expand All @@ -63,8 +57,7 @@ pub trait System: Problem {
&self,
x: &Vector<Self::Scalar, Dynamic, Sx>,
fx: &mut Vector<Self::Scalar, Dynamic, Sfx>,
) -> Result<(), ProblemError>
where
) where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous,
Sfx: StorageMut<Self::Scalar, Dynamic>;

Expand All @@ -73,13 +66,13 @@ pub trait System: Problem {
/// The default implementation allocates a temporary vector for the
/// residuals on every call. If you plan to solve the system by an
/// optimizer, consider overriding the default implementation.
fn norm<Sx>(&self, x: &Vector<Self::Scalar, Dynamic, Sx>) -> Result<Self::Scalar, ProblemError>
fn norm<Sx>(&self, x: &Vector<Self::Scalar, Dynamic, Sx>) -> Self::Scalar
where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous,
{
let mut fx = x.clone_owned();
self.eval(x, &mut fx)?;
Ok(fx.norm())
self.eval(x, &mut fx);
fx.norm()
}
}

Expand Down Expand Up @@ -146,8 +139,7 @@ where
&self,
x: &Vector<Self::Scalar, Dynamic, Sx>,
fx: &mut Vector<Self::Scalar, Dynamic, Sfx>,
) -> Result<(), ProblemError>
where
) where
Sx: Storage<Self::Scalar, Dynamic> + IsContiguous,
Sfx: StorageMut<Self::Scalar, Dynamic>,
{
Expand Down
Loading

0 comments on commit 8b6c271

Please sign in to comment.