Skip to content

Commit

Permalink
feat!: return x as part of the result of next and find
Browse files Browse the repository at this point in the history
  • Loading branch information
pnevyk committed Nov 15, 2023
1 parent d2c229b commit b392536
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 21 deletions.
4 changes: 2 additions & 2 deletions examples/rosenbrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fn main() -> Result<(), String> {

let tolerance = 1e-6;

let result = solver
let (_, fx) = solver
.find(|state| {
println!(
"iter = {}\t|| fx || = {}\tx = {:?}",
Expand All @@ -50,7 +50,7 @@ fn main() -> Result<(), String> {
})
.map_err(|error| format!("{error}"))?;

if result <= tolerance {
if fx <= tolerance {
Ok(())
} else {
Err("did not converge".to_string())
Expand Down
34 changes: 18 additions & 16 deletions src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,21 +286,21 @@ impl<'a, F: System, A: Solver<F>> SolverDriver<'a, F, A> {
/// Does one iteration of the process, returning the norm of the residuals
/// in case of no error.
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> Result<F::Field, A::Error> {
pub fn next(&mut self) -> Result<(&[F::Field], F::Field), A::Error> {
self.algo
.solve_next(self.f, &self.dom, &mut self.x, &mut self.fx)?;
Ok(self.fx.norm())
Ok((self.x.as_slice(), self.fx.norm()))
}

/// Runs the iterative process until given stopping criterion is satisfied.
pub fn find<C>(&mut self, stop: C) -> Result<F::Field, A::Error>
pub fn find<C>(&mut self, stop: C) -> Result<(&[F::Field], F::Field), A::Error>
where
C: Fn(SolverIterState<'_, F>) -> bool,
{
let mut iter = 0;

loop {
let norm = self.next()?;
let norm = self.next()?.1;

let state = SolverIterState {
x: &self.x,
Expand All @@ -309,7 +309,7 @@ impl<'a, F: System, A: Solver<F>> SolverDriver<'a, F, A> {
};

if stop(state) {
return Ok(norm);
return Ok((self.x.as_slice(), norm));
}

iter += 1;
Expand Down Expand Up @@ -427,19 +427,21 @@ impl<'a, F: Function, A: Optimizer<F>> OptimizerDriver<'a, F, A> {
/// Does one iteration of the process, returning the function value in case
/// of no error.
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> Result<F::Field, A::Error> {
self.algo.opt_next(self.f, &self.dom, &mut self.x)
pub fn next(&mut self) -> Result<(&[F::Field], F::Field), A::Error> {
self.algo
.opt_next(self.f, &self.dom, &mut self.x)
.map(|fx| (self.x.as_slice(), fx))
}

/// Runs the iterative process until given stopping criterion is satisfied.
pub fn find<C>(&mut self, stop: C) -> Result<F::Field, A::Error>
pub fn find<C>(&mut self, stop: C) -> Result<(&[F::Field], F::Field), A::Error>
where
C: Fn(OptimizerIterState<'_, F>) -> bool,
{
let mut iter = 0;

loop {
self.fx = self.next()?;
self.fx = self.next()?.1;

let state = OptimizerIterState {
x: &self.x,
Expand All @@ -448,7 +450,7 @@ impl<'a, F: Function, A: Optimizer<F>> OptimizerDriver<'a, F, A> {
};

if stop(state) {
return Ok(self.fx);
return Ok((self.x.as_slice(), self.fx));
}

iter += 1;
Expand Down Expand Up @@ -514,7 +516,7 @@ mod tests {
.build();

let tolerance = 1e-6;
let norm = solver
let (_, norm) = solver
.find(|state| state.iter() >= 100 || state.norm() < tolerance)
.unwrap();

Expand All @@ -532,7 +534,7 @@ mod tests {
.build();

let tolerance = 1e-6;
let norm = solver
let (_, norm) = solver
.find(|state| state.iter() >= 100 || state.norm() < tolerance)
.unwrap();

Expand Down Expand Up @@ -569,11 +571,11 @@ mod tests {
.build();

let tolerance = 1e-6;
let norm = optimizer
let (_, value) = optimizer
.find(|state| state.iter() >= 100 || state.fx() < tolerance)
.unwrap();

assert!(norm <= tolerance);
assert!(value <= tolerance);
}

#[test]
Expand All @@ -587,11 +589,11 @@ mod tests {
.build();

let tolerance = 1e-6;
let norm = optimizer
let (_, value) = optimizer
.find(|state| state.iter() >= 100 || state.fx() < tolerance)
.unwrap();

assert!(norm <= tolerance);
assert!(value <= tolerance);
}

#[test]
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@
//!
//! let tolerance = 1e-6;
//!
//! let result = solver
//! let (x, fx) = solver
//! .find(|state| {
//! println!(
//! "iter = {}\t|| fx || = {}\tx = {:?}",
Expand All @@ -190,8 +190,8 @@
//! })
//! .expect("solver encountered an error");
//!
//! if result <= tolerance {
//! println!("solved");
//! if fx <= tolerance {
//! println!("solved: {x:?}");
//! } else {
//! println!("maximum number of iteration exceeded");
//! }
Expand Down

0 comments on commit b392536

Please sign in to comment.