Skip to content

Commit

Permalink
Add Halley's method via descent API
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Nov 4, 2024
1 parent 037a07c commit 198cd4d
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 6 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Expand Down Expand Up @@ -113,6 +114,7 @@ StaticArrays = "1.9"
StaticArraysCore = "1.4"
Sundials = "4.23.1"
SymbolicIndexingInterface = "0.3.31"
TaylorDiff = "0.3"
Test = "1.10"
Zygote = "0.6.69"
julia = "1.10"
Expand Down Expand Up @@ -146,8 +148,9 @@ SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "TaylorDiff", "Test", "Zygote"]
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"

[extensions]
NonlinearSolveBaseBandedMatricesExt = "BandedMatrices"
Expand All @@ -44,6 +45,7 @@ NonlinearSolveBaseLineSearchExt = "LineSearch"
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings"
NonlinearSolveBaseTaylorDiffExt = "TaylorDiff"

[compat]
ADTypes = "1.9"
Expand Down Expand Up @@ -77,6 +79,7 @@ SparseArrays = "1.10"
SparseMatrixColorings = "0.4.5"
StaticArraysCore = "1.4"
SymbolicIndexingInterface = "0.3.31"
TaylorDiff = "0.3"
Test = "1.10"
TimerOutputs = "0.5.23"
julia = "1.10"
Expand Down
20 changes: 20 additions & 0 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseTaylorDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module NonlinearSolveBaseTaylorDiffExt
using SciMLBase: NonlinearFunction
using NonlinearSolveBase: HalleyDescentCache
import NonlinearSolveBase: evaluate_hvvp
using TaylorDiff: derivative, derivative!
using FastClosures: @closure

# TaylorDiff-backed method of `evaluate_hvvp`, used by the Halley descent to obtain
# the Hessian-vector-vector product of `f` at `u` along `δu`.
#
# - Out-of-place `f` (`iip == false`): `p` is fixed as the second argument of `f` and
#   the result of `derivative(…, Val(2))` is returned; the `hvvp` argument is not
#   touched.
# - In-place `f` (`iip == true`): `derivative!` writes into `hvvp`, with `cache.fu`
#   passed as the output buffer for the wrapped `f(y, x, p)` call, and `hvvp` itself
#   is returned.
function evaluate_hvvp(
        hvvp, cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip}
    if !iip
        return derivative(Base.Fix2(f, p), u, δu, Val(2))
    end
    inplace_f = @closure (y, x) -> f(y, x, p)
    derivative!(hvvp, inplace_f, cache.fu, u, δu, Val(2))
    return hvvp
end

end
1 change: 1 addition & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ include("wrappers.jl")

include("descent/common.jl")
include("descent/newton.jl")
include("descent/halley.jl")
include("descent/steepest.jl")
include("descent/damped_newton.jl")
include("descent/dogleg.jl")
Expand Down
100 changes: 100 additions & 0 deletions lib/NonlinearSolveBase/src/descent/halley.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
HalleyDescent(; linsolve = nothing)
Improve the NewtonDescent with higher-order terms. First compute the descent direction as ``J a = -fu``.
Then compute the hessian-vector-vector product and solve for the second-order correction term as ``J b = H a a``.
Finally, compute the descent direction as ``δu = a * a / (b / 2 - a)``.
Note that `import TaylorDiff` is required to use this descent algorithm.
See also [`NewtonDescent`](@ref).
"""
@kwdef @concrete struct HalleyDescent <: AbstractDescentDirection
linsolve = nothing
end

# Trait method: reports `true` for `HalleyDescent`.
supports_line_search(::HalleyDescent) = true

# Mutable workspace used by `InternalAPI.solve!` for `HalleyDescent`.
# All fields are filled in by `InternalAPI.init`.
@concrete mutable struct HalleyDescentCache <: AbstractDescentCache
# nonlinear function, taken from `prob.f`
f
# problem parameters, taken from `prob.p`
p
# primary descent-direction buffer (`similar(u)`)
δu
# extra descent-direction buffers when `shared` is greater than 1, otherwise `nothing`
δus
# buffer for the second-order correction term (`similar(u)`); reassigned by `solve!`
b
# residual-shaped scratch buffer (`similar(fu)`), handed to `evaluate_hvvp` via the cache
fu
# output buffer for the hessian-vector-vector product (`similar(fu)`)
hvvp
# linear solver cache, or `nothing` when the Jacobian is pre-inverted
lincache
# timer passed to `@static_timeit` around the linear solves
timer
# `Val(true)` / `Val(false)` flag copied from the `pre_inverted` keyword of `init`
preinverted_jacobian <: Union{Val{false}, Val{true}}
end

# NOTE(review): presumably registers `lincache` as a nested cache so generic cache
# utilities can reach it — confirm against the `@internal_caches` definition.
@internal_caches HalleyDescentCache :lincache

# Build the `HalleyDescentCache` for `alg`.
#
# Allocates the direction buffer `δu`, the correction buffer `b`, a residual-shaped
# scratch buffer `fu` and the `hvvp` output buffer. When `shared` is greater than 1,
# one additional direction buffer is allocated per extra slot (`δus`); otherwise
# `δus` is `nothing`. A linear solver cache is constructed unless
# `pre_inverted = Val(true)`, in which case `lincache` is `nothing`.
#
# Keyword arguments:
#   - `stats`, `abstol`, `reltol`, `linsolve_kwargs`: forwarded to
#     `construct_linear_solver`.
#   - `shared`: `Val(n)` number of direction buffers to provide.
#   - `pre_inverted`: `Val(true)` skips construction of the linear solver.
#   - `timer`: stored in the cache for timing the linear solves.
#   - `kwargs...`: accepted and ignored.
function InternalAPI.init(
        prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats,
        shared = Val(1), pre_inverted::Val = Val(false),
        linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
        timer = get_timer_output(), kwargs...)
    @bb δu = similar(u)
    @bb b = similar(u)
    # NOTE(review): this rebinds `fu` to a fresh buffer, so the linear solver below is
    # constructed with the new buffer rather than the caller's residual.
    @bb fu = similar(fu)
    @bb hvvp = similar(fu)
    # Extra buffers are only needed when more than one direction slot is requested.
    δus = Utils.unwrap_val(shared) ≤ 1 ? nothing : map(2:Utils.unwrap_val(shared)) do i
        @bb δu_ = similar(u)
    end
    lincache = Utils.unwrap_val(pre_inverted) ? nothing :
               construct_linear_solver(
        alg, alg.linsolve, J, Utils.safe_vec(fu), Utils.safe_vec(u);
        stats, abstol, reltol, linsolve_kwargs...
    )
    return HalleyDescentCache(
        prob.f, prob.p, δu, δus, b, fu, hvvp, lincache, timer, pre_inverted)
end

# Compute the Halley descent direction and store it in the cache slot selected by `idx`.
#
# Steps, as implemented below (writing `a` for the result of the first solve):
#   1. Solve `J a = fu` (or apply `J` directly when the cache is flagged pre-inverted).
#   2. Evaluate the hessian-vector-vector product with `evaluate_hvvp(…, u, a)`.
#   3. Solve `J b = hvvp` (or apply `J` directly), asking to reuse the factorization.
#   4. Combine elementwise: `δu = a * a / (b / 2 - a)`.
#
# Keyword arguments:
#   - `skip_solve`: when `true`, return the currently stored `δu` for `idx` unchanged.
#   - `new_jacobian`: when `false`, or when `idx` is not `Val(1)`, the first linear
#     solve requests reuse of an existing factorization.
#   - `kwargs...`: forwarded to both linear-solver calls.
#
# Returns a `DescentResult`. If either linear solve reports failure, the result carries
# `success = false` and `linsolve_success = false`.
function InternalAPI.solve!(
cache::HalleyDescentCache, J, fu, u, idx::Val = Val(1);
skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...)
δu = SciMLBase.get_du(cache, idx)
skip_solve && return DescentResult(; δu)
if preinverted_jacobian(cache)
# presumably `J` already holds the inverse Jacobian here, so it is applied directly
@assert J!==nothing "`J` must be provided when `pre_inverted = Val(true)`."
@bb δu = J × vec(fu)
else
# first linear solve: `J δu = fu`, written into the `δu` buffer
@static_timeit cache.timer "linear solve 1" begin
linres = cache.lincache(;
A = J, b = Utils.safe_vec(fu),
kwargs..., linu = Utils.safe_vec(δu),
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
δu = Utils.restructure(SciMLBase.get_du(cache, idx), linres.u)
if !linres.success
# store whatever the solver produced and report the failure to the caller
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)
end
end
end
b = cache.b
# compute the hessian-vector-vector product of `cache.f` at `u` along `δu`
hvvp = evaluate_hvvp(cache.hvvp, cache, cache.f, cache.p, u, δu)
# second linear solve `J b = hvvp`; always requests reuse of the factorization
if preinverted_jacobian(cache)
@bb b = J × vec(hvvp)
else
@static_timeit cache.timer "linear solve 2" begin
linres = cache.lincache(;
A = J, b = Utils.safe_vec(hvvp),
kwargs..., linu = Utils.safe_vec(b),
reuse_A_if_factorization = true)
b = Utils.restructure(cache.b, linres.u)
if !linres.success
# on failure the stored direction is the result of the first solve only,
# not a Halley-corrected step
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)
end
end
end
# elementwise Halley combination of the two solves
# NOTE(review): a component with `δu == 0` and `b == 0` evaluates `0 / 0` here —
# confirm callers terminate before the residual is exactly zero.
@bb @. δu = δu * δu / (b / 2 - δu)
set_du!(cache, δu, idx)
# keep the (possibly restructured) correction buffer for the next call
cache.b = b
return DescentResult(; δu)
end

# Generic fallback for the hessian-vector-vector product. It always raises an error
# telling the user to import TaylorDiff; the TaylorDiff package extension adds the
# working method.
function evaluate_hvvp(hvvp, cache, f, p, u, δu)
    return error("not implemented. please import TaylorDiff")
end
7 changes: 4 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
AbstractTrustRegionMethodCache,
Utils, InternalAPI, get_timer_output, @static_timeit,
update_trace!, L2_NORM,
NewtonDescent, DampedNewtonDescent, GeodesicAcceleration,
Dogleg
NewtonDescent, DampedNewtonDescent, HalleyDescent,
GeodesicAcceleration, Dogleg
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode,
NonlinearFunction,
NonlinearLeastSquaresProblem, NonlinearProblem, NoSpecialize
Expand All @@ -31,6 +31,7 @@ using FiniteDiff: FiniteDiff # Default Finite Difference Method
using ForwardDiff: ForwardDiff # Default Forward Mode AD

include("raphson.jl")
include("halley.jl")
include("gauss_newton.jl")
include("levenberg_marquardt.jl")
include("trust_region.jl")
Expand Down Expand Up @@ -93,7 +94,7 @@ end

@reexport using SciMLBase, NonlinearSolveBase

export NewtonRaphson, PseudoTransient
export NewtonRaphson, Halley, PseudoTransient
export GaussNewton, LevenbergMarquardt, TrustRegion

export RadiusUpdateSchemes
Expand Down
15 changes: 15 additions & 0 deletions lib/NonlinearSolveFirstOrder/src/halley.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = missing,
autodiff = nothing)
An experimental Halley's method implementation. Improves the convergence rate of Newton's method by using second-order derivative information to correct the descent direction.
Currently depends on TaylorDiff.jl to handle the correction terms,
might have more general implementation in the future.
"""
function Halley(; concrete_jac = nothing, linsolve = nothing,
linesearch = missing, autodiff = nothing)
return GeneralizedFirstOrderAlgorithm(;
concrete_jac, name = :Halley, linesearch,
descent = HalleyDescent(; linsolve), autodiff)
end
9 changes: 7 additions & 2 deletions test/23_test_problems_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@testsetup module RobustnessTesting
using NonlinearSolve, LinearAlgebra, LinearSolve, NonlinearProblemLibrary, Test
import TaylorDiff

problems = NonlinearProblemLibrary.problems
dicts = NonlinearProblemLibrary.dicts
Expand Down Expand Up @@ -61,10 +62,14 @@ end
end

@testitem "23 Test Problems: Halley" setup=[RobustnessTesting] tags=[:core] begin
    # Algorithms under test: the descent-API `Halley` and the existing `SimpleHalley`.
    alg_ops = (
        Halley(),
        SimpleHalley(; autodiff = AutoForwardDiff())
    )

    # Problem indices expected to fail for each algorithm (empty by default).
    broken_tests = Dict(alg => Int[] for alg in alg_ops)
    broken_tests[alg_ops[1]] = [1, 5, 15, 16]
    broken_tests[alg_ops[2]] = [1, 5, 15, 16, 18]

    test_on_library(problems, dicts, alg_ops, broken_tests)
end
Expand Down

0 comments on commit 198cd4d

Please sign in to comment.