From 2002bd4fe880de2f7b976ed091b8b4c4a471142a Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Wed, 11 Dec 2024 13:26:40 +0800 Subject: [PATCH] refactor: Move dual nonlinear solving to NonlinearSolveBase (#513) --- docs/src/basics/faq.md | 4 +- .../ext/NonlinearSolveBaseForwardDiffExt.jl | 98 ++++++++++++++++++- .../src/NonlinearSolveBase.jl | 4 + lib/NonlinearSolveBase/src/forward_diff.jl | 8 ++ lib/NonlinearSolveBase/src/public.jl | 9 ++ lib/NonlinearSolveFirstOrder/Project.toml | 3 +- .../src/NonlinearSolveFirstOrder.jl | 6 +- .../src/forward_diff.jl | 34 +++++++ .../test/misc_tests.jl | 10 ++ lib/NonlinearSolveQuasiNewton/Project.toml | 6 ++ ...NonlinearSolveQuasiNewtonForwardDiffExt.jl | 46 +++++++++ .../Project.toml | 7 ++ ...inearSolveSpectralMethodsForwardDiffExt.jl | 46 +++++++++ src/NonlinearSolve.jl | 11 +-- src/forward_diff.jl | 95 ++++-------------- 15 files changed, 294 insertions(+), 93 deletions(-) create mode 100644 lib/NonlinearSolveBase/src/forward_diff.jl create mode 100644 lib/NonlinearSolveFirstOrder/src/forward_diff.jl create mode 100644 lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl create mode 100644 lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl diff --git a/docs/src/basics/faq.md b/docs/src/basics/faq.md index 4d428250a..9aabc203b 100644 --- a/docs/src/basics/faq.md +++ b/docs/src/basics/faq.md @@ -152,10 +152,10 @@ nothing # hide ``` And boom! Type stable again. We always recommend picking the chunksize via -[`NonlinearSolve.pickchunksize`](@ref), however, if you manually specify the chunksize, it +[`NonlinearSolveBase.pickchunksize`](@ref), however, if you manually specify the chunksize, it must be `≤ length of input`. However, a very large chunksize can lead to excessive compilation times and slowdown. ```@docs -NonlinearSolve.pickchunksize +NonlinearSolveBase.pickchunksize ``` diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index bb3165396..717daa8e4 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -2,17 +2,35 @@ module NonlinearSolveBaseForwardDiffExt using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff using ArrayInterface: ArrayInterface -using CommonSolve: solve +using CommonSolve: CommonSolve, solve, solve!, init +using ConcreteStructs: @concrete using DifferentiationInterface: DifferentiationInterface using FastClosures: @closure -using ForwardDiff: ForwardDiff, Dual +using ForwardDiff: ForwardDiff, Dual, pickchunksize using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem, remake -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI, + NonlinearSolvePolyAlgorithm, NonlinearSolveForwardDiffCache const DI = DifferentiationInterface +const GENERAL_SOLVER_TYPES = [ + Nothing, NonlinearSolvePolyAlgorithm +] + +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + function NonlinearSolveBase.additional_incompatible_backend_check( prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff}) return !ForwardDiff.can_dual(eltype(prob.u0)) @@ -102,4 +120,78 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution( return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials))) end +for algType in GENERAL_SOLVER_TYPES + @eval function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... + ) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) + end +end + +function InternalAPI.reinit!( + cache::NonlinearSolveForwardDiffCache, args...; + p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs... +) + InternalAPI.reinit!( + cache.cache; p = NonlinearSolveBase.nodual_value(p), + u0 = NonlinearSolveBase.nodual_value(u0), kwargs... + ) + cache.p = p + cache.values_p = NonlinearSolveBase.nodual_value(p) + cache.partials_p = ForwardDiff.partials(p) + return cache +end + +for algType in GENERAL_SOLVER_TYPES + @eval function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... + ) + p = NonlinearSolveBase.nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) + end +end + +function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache) + sol = solve!(cache.cache) + prob = cache.prob + uu = sol.u + + fn = prob isa NonlinearLeastSquaresProblem ? + NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f + + Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p) + Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p) + + z_arr = -Jᵤ \ Jₚ + + sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z) + if cache.p isa Number + partials = sumfun((z_arr, cache.p)) + else + partials = sum(sumfun, zip(eachcol(z_arr), cache.p)) + end + + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p) + return SciMLBase.build_solution( + prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end + +NonlinearSolveBase.nodual_value(x) = x +NonlinearSolveBase.nodual_value(x::Dual) = ForwardDiff.value(x) +NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) + +@inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x)) +@inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) + end diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 858137bff..f45ba9242 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -57,6 +57,8 @@ include("descent/geodesic_acceleration.jl") include("solve.jl") +include("forward_diff.jl") + # Unexported Public API @compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance)) @compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution)) @@ -83,4 +85,6 @@ export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogle export NonlinearSolvePolyAlgorithm +export pickchunksize + end diff --git a/lib/NonlinearSolveBase/src/forward_diff.jl b/lib/NonlinearSolveBase/src/forward_diff.jl new file mode 100644 index 000000000..a588aa52d --- /dev/null +++ b/lib/NonlinearSolveBase/src/forward_diff.jl @@ -0,0 +1,8 @@ +@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache + cache + prob + alg + p + values_p + partials_p +end diff --git a/lib/NonlinearSolveBase/src/public.jl b/lib/NonlinearSolveBase/src/public.jl index d076f7873..a9bae2a5e 100644 --- a/lib/NonlinearSolveBase/src/public.jl +++ b/lib/NonlinearSolveBase/src/public.jl @@ -11,6 +11,15 @@ function nonlinearsolve_dual_solution end function nonlinearsolve_∂f_∂p end function nonlinearsolve_∂f_∂u end function nlls_generate_vjp_function end +function nodual_value end + +""" + pickchunksize(x) = pickchunksize(length(x)) + pickchunksize(x::Int) + +Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. +""" +function pickchunksize end # Nonlinear Solve Termination Conditions abstract type AbstractNonlinearTerminationMode end diff --git a/lib/NonlinearSolveFirstOrder/Project.toml b/lib/NonlinearSolveFirstOrder/Project.toml index ee2d2c9de..c299b6dc1 100644 --- a/lib/NonlinearSolveFirstOrder/Project.toml +++ b/lib/NonlinearSolveFirstOrder/Project.toml @@ -67,6 +67,7 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" @@ -86,4 +87,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "BandedMatrices", "BenchmarkTools", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"] +test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"] diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 1f480fb4b..666cc7435 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -22,14 +22,14 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, Utils, InternalAPI, get_timer_output, @static_timeit, update_trace!, L2_NORM, NonlinearSolvePolyAlgorithm, NewtonDescent, DampedNewtonDescent, GeodesicAcceleration, - Dogleg + Dogleg, NonlinearSolveForwardDiffCache using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode, NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, NoSpecialize using SciMLJacobianOperators: VecJacOperator, JacVecOperator, StatefulJacobianOperator using FiniteDiff: FiniteDiff # Default Finite Difference Method -using ForwardDiff: ForwardDiff # Default Forward Mode AD +using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD include("raphson.jl") include("gauss_newton.jl") @@ -41,6 +41,8 @@ include("poly_algs.jl") include("solve.jl") +include("forward_diff.jl") + @setup_workload begin nonlinear_functions = ( (NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1), diff --git a/lib/NonlinearSolveFirstOrder/src/forward_diff.jl b/lib/NonlinearSolveFirstOrder/src/forward_diff.jl new file mode 100644 index 000000000..86f4b072a --- /dev/null +++ b/lib/NonlinearSolveFirstOrder/src/forward_diff.jl @@ -0,0 +1,34 @@ +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs... +) + p = NonlinearSolveBase.nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) +end + +function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs... +) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end diff --git a/lib/NonlinearSolveFirstOrder/test/misc_tests.jl b/lib/NonlinearSolveFirstOrder/test/misc_tests.jl index 40fcb2c55..79c63f37c 100644 --- a/lib/NonlinearSolveFirstOrder/test/misc_tests.jl +++ b/lib/NonlinearSolveFirstOrder/test/misc_tests.jl @@ -20,3 +20,13 @@ @test sol.retcode == ReturnCode.Success @test jac_calls == 0 end + +@testitem "Dual of BigFloat: Issue #512" tags=[:core] begin + using NonlinearSolveFirstOrder, ForwardDiff + fn_iip = NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p) + u2 = [ForwardDiff.Dual(BigFloat(1.0), 5.0), ForwardDiff.Dual(BigFloat(1.0), 5.0), + ForwardDiff.Dual(BigFloat(1.0), 5.0)] + prob_iip_bf = NonlinearProblem{true}(fn_iip, u2, ForwardDiff.Dual(BigFloat(2.0), 5.0)) + sol = solve(prob_iip_bf, NewtonRaphson()) + @test sol.retcode == ReturnCode.Success +end diff --git a/lib/NonlinearSolveQuasiNewton/Project.toml b/lib/NonlinearSolveQuasiNewton/Project.toml index 2f00863d8..4912e9070 100644 --- a/lib/NonlinearSolveQuasiNewton/Project.toml +++ b/lib/NonlinearSolveQuasiNewton/Project.toml @@ -18,6 +18,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +NonlinearSolveQuasiNewtonForwardDiffExt = "ForwardDiff" + [compat] ADTypes = "1.9.0" Aqua = "0.8" diff --git a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl new file mode 100644 index 000000000..74ec64031 --- /dev/null +++ b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl @@ -0,0 +1,46 @@ +module NonlinearSolveQuasiNewtonForwardDiffExt + +using CommonSolve: CommonSolve, init +using ForwardDiff: ForwardDiff, Dual +using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem + +using NonlinearSolveBase: NonlinearSolveBase, NonlinearSolveForwardDiffCache, nodual_value + +using NonlinearSolveQuasiNewton: QuasiNewtonAlgorithm + +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs... +) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end + +function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs... +) + p = nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) +end + +end diff --git a/lib/NonlinearSolveSpectralMethods/Project.toml b/lib/NonlinearSolveSpectralMethods/Project.toml index bb9367554..a248be107 100644 --- a/lib/NonlinearSolveSpectralMethods/Project.toml +++ b/lib/NonlinearSolveSpectralMethods/Project.toml @@ -14,6 +14,12 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +NonlinearSolveSpectralMethodsForwardDiffExt = "ForwardDiff" + [compat] Aqua = "0.8" BenchmarkTools = "1.5.0" @@ -21,6 +27,7 @@ CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" DiffEqBase = "6.158.3" ExplicitImports = "1.5" +ForwardDiff = "0.10.36" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" LineSearch = "0.1.4" diff --git a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl new file mode 100644 index 000000000..930c4861c --- /dev/null +++ b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl @@ -0,0 +1,46 @@ +module NonlinearSolveSpectralMethodsForwardDiffExt + +using CommonSolve: CommonSolve, init +using ForwardDiff: ForwardDiff, Dual +using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem + +using NonlinearSolveBase: NonlinearSolveBase, NonlinearSolveForwardDiffCache, nodual_value + +using NonlinearSolveSpectralMethods: GeneralizedDFSane + +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs... +) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end + +function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs... +) + p = nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) +end + +end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 4c44cc972..a1b759011 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -14,7 +14,7 @@ using LineSearch: BackTracking using NonlinearSolveBase: NonlinearSolveBase, InternalAPI, AbstractNonlinearSolveAlgorithm, AbstractNonlinearSolveCache, Utils, L2_NORM, enable_timer_outputs, disable_timer_outputs, - NonlinearSolvePolyAlgorithm + NonlinearSolvePolyAlgorithm, pickchunksize using Preferences: set_preferences! using SciMLBase: SciMLBase, NLStats, ReturnCode, AbstractNonlinearProblem, @@ -53,15 +53,6 @@ include("extension_algs.jl") include("default.jl") -const ALL_SOLVER_TYPES = [ - Nothing, AbstractNonlinearSolveAlgorithm, - GeneralizedDFSane, GeneralizedFirstOrderAlgorithm, QuasiNewtonAlgorithm, - LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL, - SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL, - CMINPACK, PETScSNES, - NonlinearSolvePolyAlgorithm -] - include("forward_diff.jl") @setup_workload begin diff --git a/src/forward_diff.jl b/src/forward_diff.jl index 5bb98561c..76fdf6f52 100644 --- a/src/forward_diff.jl +++ b/src/forward_diff.jl @@ -1,3 +1,9 @@ +const EXTENSION_SOLVER_TYPES = [ + LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL, + SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL, + CMINPACK, PETScSNES +] + const DualNonlinearProblem = NonlinearProblem{ <:Union{Number, <:AbstractArray}, iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} @@ -10,48 +16,12 @@ const DualAbstractNonlinearProblem = Union{ DualNonlinearProblem, DualNonlinearLeastSquaresProblem } -for algType in ALL_SOLVER_TYPES - @eval function SciMLBase.__solve( - prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... - ) - sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( - prob, alg, args...; kwargs... - ) - dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) - return SciMLBase.build_solution( - prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original - ) - end -end - -@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache - cache - prob - alg - p - values_p - partials_p -end - -function InternalAPI.reinit!( - cache::NonlinearSolveForwardDiffCache, args...; - p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs... -) - InternalAPI.reinit!( - cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs... - ) - cache.p = p - cache.values_p = nodual_value(p) - cache.partials_p = ForwardDiff.partials(p) - return cache -end - -for algType in ALL_SOLVER_TYPES +for algType in EXTENSION_SOLVER_TYPES @eval function SciMLBase.__init( prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... ) - p = nodual_value(prob.p) - newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) + p = NonlinearSolveBase.nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p) cache = init(newprob, alg, args...; kwargs...) return NonlinearSolveForwardDiffCache( cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) @@ -59,41 +29,16 @@ for algType in ALL_SOLVER_TYPES end end -function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache) - sol = solve!(cache.cache) - prob = cache.prob - uu = sol.u - - fn = prob isa NonlinearLeastSquaresProblem ? - NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f - - Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p) - Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p) - - z_arr = -Jᵤ \ Jₚ - - sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z) - if cache.p isa Number - partials = sumfun((z_arr, cache.p)) - else - partials = sum(sumfun, zip(eachcol(z_arr), cache.p)) - end - - dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p) - return SciMLBase.build_solution( - prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original +for algType in EXTENSION_SOLVER_TYPES + @eval function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... ) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) + end end - -nodual_value(x) = x -nodual_value(x::Dual) = ForwardDiff.value(x) -nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) - -""" - pickchunksize(x) = pickchunksize(length(x)) - pickchunksize(x::Int) - -Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. -""" -@inline pickchunksize(x) = pickchunksize(length(x)) -@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)