Skip to content

Commit

Permalink
refactor: Move dual nonlinear solving to NonlinearSolveBase (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY authored Dec 11, 2024
1 parent 2284348 commit 2002bd4
Show file tree
Hide file tree
Showing 15 changed files with 294 additions and 93 deletions.
4 changes: 2 additions & 2 deletions docs/src/basics/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,10 @@ nothing # hide
```

And boom! Type stable again. We always recommend picking the chunksize via
[`NonlinearSolve.pickchunksize`](@ref), however, if you manually specify the chunksize, it
[`NonlinearSolveBase.pickchunksize`](@ref), however, if you manually specify the chunksize, it
must be `≤ length of input`. Note that a very large chunksize can lead to excessive
compilation times and slowdown.

```@docs
NonlinearSolve.pickchunksize
NonlinearSolveBase.pickchunksize
```
98 changes: 95 additions & 3 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,35 @@ module NonlinearSolveBaseForwardDiffExt

using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using CommonSolve: solve
using CommonSolve: CommonSolve, solve, solve!, init
using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using ForwardDiff: ForwardDiff, Dual, pickchunksize
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI,
NonlinearSolvePolyAlgorithm, NonlinearSolveForwardDiffCache

const DI = DifferentiationInterface

# Algorithm types for which this extension generates dual-aware `SciMLBase.__solve` and
# `SciMLBase.__init` methods (see the `@eval` loops that iterate over this list).
const GENERAL_SOLVER_TYPES = [
Nothing, NonlinearSolvePolyAlgorithm
]

# Dispatch aliases for problems that carry ForwardDiff duals.
# The first type parameter is restricted to a `Number` or `AbstractArray`, and the third to
# a `Dual{T, V, P}` or an array of such duals.
# NOTE(review): presumably these are the `u0` type and the parameter (`p`) type per
# SciMLBase's type-parameter order — confirm against SciMLBase.
const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Same constraint as above, for nonlinear least-squares problems.
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Union used in method signatures that handle either dual-carrying problem type.
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

function NonlinearSolveBase.additional_incompatible_backend_check(
prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
return !ForwardDiff.can_dual(eltype(prob.u0))
Expand Down Expand Up @@ -102,4 +120,78 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution(
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
end

# Generate one dual-aware `SciMLBase.__solve` method per entry of `GENERAL_SOLVER_TYPES`.
# Each method delegates to `nonlinearsolve_forwarddiff_solve`, which returns a solution
# together with partials, then rebuilds a dual-valued `u` and wraps it in a solution
# object for the original (dual-carrying) problem.
for AlgT in GENERAL_SOLVER_TYPES
    @eval function SciMLBase.__solve(
            prob::DualAbstractNonlinearProblem, alg::$(AlgT), args...; kwargs...
    )
        primal_sol, u_partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
            prob, alg, args...; kwargs...
        )
        u_dual = NonlinearSolveBase.nonlinearsolve_dual_solution(
            primal_sol.u, u_partials, prob.p
        )
        return SciMLBase.build_solution(
            prob, alg, u_dual, primal_sol.resid;
            retcode = primal_sol.retcode,
            stats = primal_sol.stats,
            original = primal_sol.original
        )
    end
end

# Re-initialise a `NonlinearSolveForwardDiffCache`.
#
# The wrapped solver cache is re-initialised with the dual-stripped `p` and `u0`; the
# wrapper then records the (possibly dual) `p`, its stripped value and its partials.
# `p` defaults to the stored `cache.p`; `u0` defaults to the wrapped cache's current `u`.
# Returns the same `cache` object.
function InternalAPI.reinit!(
        cache::NonlinearSolveForwardDiffCache, args...;
        p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
)
    primal_p = NonlinearSolveBase.nodual_value(p)
    primal_u0 = NonlinearSolveBase.nodual_value(u0)
    InternalAPI.reinit!(cache.cache; p = primal_p, u0 = primal_u0, kwargs...)

    cache.p = p
    # Computed afresh (not reusing `primal_p`) so the stored value is a separate object
    # from the one handed to the wrapped cache.
    cache.values_p = NonlinearSolveBase.nodual_value(p)
    cache.partials_p = ForwardDiff.partials(p)
    return cache
end

# Generate one dual-aware `SciMLBase.__init` method per entry of `GENERAL_SOLVER_TYPES`.
# Each method strips the duals from `p` and `u0`, initialises the solver on the resulting
# problem, and wraps that cache together with the original dual `p`, its stripped value
# and its partials.
for AlgT in GENERAL_SOLVER_TYPES
    @eval function SciMLBase.__init(
            prob::DualAbstractNonlinearProblem, alg::$(AlgT), args...; kwargs...
    )
        primal_p = NonlinearSolveBase.nodual_value(prob.p)
        primal_u0 = NonlinearSolveBase.nodual_value(prob.u0)
        primal_prob = SciMLBase.remake(prob; u0 = primal_u0, p = primal_p)
        inner_cache = init(primal_prob, alg, args...; kwargs...)
        return NonlinearSolveForwardDiffCache(
            inner_cache, primal_prob, alg, prob.p, primal_p, ForwardDiff.partials(prob.p)
        )
    end
end

# Solve the wrapped (dual-free) cache, then attach derivative information to the result.
#
# Steps:
#   1. Run `solve!` on the wrapped cache.
#   2. Evaluate the Jacobians w.r.t. `p` and `u` at the solution (for least-squares
#      problems, a function generated by `nlls_generate_vjp_function` replaces `prob.f`).
#   3. Solve `(-Jᵤ) \ Jₚ` and weight the result by the partials carried by `cache.p`.
#   4. Rebuild a dual-valued `u` and return it in a new solution object.
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
    primal_sol = solve!(cache.cache)
    prob = cache.prob
    u = primal_sol.u

    if prob isa NonlinearLeastSquaresProblem
        residual_fn = NonlinearSolveBase.nlls_generate_vjp_function(prob, primal_sol, u)
    else
        residual_fn = prob.f
    end

    jac_p = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, residual_fn, u, cache.values_p)
    jac_u = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, residual_fn, u, cache.values_p)

    # NOTE(review): presumably the sensitivity of `u` w.r.t. `p` obtained by implicit
    # differentiation of the residual — confirm.
    sensitivities = -jac_u \ jac_p

    # Multiply every entry of `z` by the partials of the matching parameter.
    weight_by_partials = ((z, pᵢ),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(pᵢ), z)
    partials = if cache.p isa Number
        weight_by_partials((sensitivities, cache.p))
    else
        # One column of `sensitivities` per parameter; accumulate the contributions.
        sum(weight_by_partials, zip(eachcol(sensitivities), cache.p))
    end

    u_dual = NonlinearSolveBase.nonlinearsolve_dual_solution(u, partials, cache.p)
    return SciMLBase.build_solution(
        prob, cache.alg, u_dual, primal_sol.resid;
        retcode = primal_sol.retcode,
        stats = primal_sol.stats,
        original = primal_sol.original
    )
end

# Strip ForwardDiff dual information from a value.
# Fallback: anything that is not a `Dual` (or an array of them) is returned unchanged.
function NonlinearSolveBase.nodual_value(x)
    return x
end

# A single dual number: return its primal value.
function NonlinearSolveBase.nodual_value(x::Dual)
    return ForwardDiff.value(x)
end

# An array of dual numbers: return the element-wise primal values.
function NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual})
    return map(ForwardDiff.value, x)
end

# Chunk size for a collection-like input: derived from its length via
# `ForwardDiff.pickchunksize`.
@inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x))
# Chunk size for an explicit input length.
# Accepts any `Integer`, not just `Int`: previously a non-`Int` integer such as `Int32(20)`
# fell through to the generic method above, where `length(x) == 1`, and silently produced
# a chunk size of 1.
@inline NonlinearSolveBase.pickchunksize(x::Integer) = ForwardDiff.pickchunksize(x)

end
4 changes: 4 additions & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ include("descent/geodesic_acceleration.jl")

include("solve.jl")

include("forward_diff.jl")

# Unexported Public API
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
Expand All @@ -83,4 +85,6 @@ export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogle

export NonlinearSolvePolyAlgorithm

export pickchunksize

end
8 changes: 8 additions & 0 deletions lib/NonlinearSolveBase/src/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Cache wrapper used when a nonlinear problem carries ForwardDiff dual numbers.
# The field descriptions below reflect how the ForwardDiff-aware `__init` methods
# construct this type.
@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
# Solver cache obtained from `init` on the dual-stripped problem.
cache
# The dual-stripped problem that `cache` was built from.
prob
# Algorithm passed to `init`.
alg
# Original parameters, possibly containing duals.
p
# Parameters with dual information removed (`nodual_value(p)`).
values_p
# Result of `ForwardDiff.partials(p)`.
partials_p
end
9 changes: 9 additions & 0 deletions lib/NonlinearSolveBase/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ function nonlinearsolve_dual_solution end
function nonlinearsolve_∂f_∂p end
function nonlinearsolve_∂f_∂u end
function nlls_generate_vjp_function end
# Stub only: methods are supplied by the ForwardDiff extension, which strips dual
# information from scalars and arrays and returns any other value unchanged.
function nodual_value end

"""
    pickchunksize(x) = pickchunksize(length(x))
    pickchunksize(x::Int)

Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length.

This is a function stub: its methods are defined in the ForwardDiff extension of
NonlinearSolveBase and forward to `ForwardDiff.pickchunksize`, so ForwardDiff must be
loaded for the function to be callable.
"""
function pickchunksize end

# Nonlinear Solve Termination Conditions
abstract type AbstractNonlinearTerminationMode end
Expand Down
3 changes: 2 additions & 1 deletion lib/NonlinearSolveFirstOrder/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
Expand All @@ -86,4 +87,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
6 changes: 4 additions & 2 deletions lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
Utils, InternalAPI, get_timer_output, @static_timeit,
update_trace!, L2_NORM, NonlinearSolvePolyAlgorithm,
NewtonDescent, DampedNewtonDescent, GeodesicAcceleration,
Dogleg
Dogleg, NonlinearSolveForwardDiffCache
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode,
NonlinearFunction,
NonlinearLeastSquaresProblem, NonlinearProblem, NoSpecialize
using SciMLJacobianOperators: VecJacOperator, JacVecOperator, StatefulJacobianOperator

using FiniteDiff: FiniteDiff # Default Finite Difference Method
using ForwardDiff: ForwardDiff # Default Forward Mode AD
using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD

include("raphson.jl")
include("gauss_newton.jl")
Expand All @@ -41,6 +41,8 @@ include("poly_algs.jl")

include("solve.jl")

include("forward_diff.jl")

@setup_workload begin
nonlinear_functions = (
(NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1),
Expand Down
34 changes: 34 additions & 0 deletions lib/NonlinearSolveFirstOrder/src/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Dispatch aliases for problems that carry ForwardDiff duals.
# The first type parameter is restricted to a `Number` or `AbstractArray`, and the third to
# a `Dual{T, V, P}` or an array of such duals.
# NOTE(review): presumably these are the `u0` type and the parameter (`p`) type per
# SciMLBase's type-parameter order — confirm against SciMLBase.
const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Same constraint as above, for nonlinear least-squares problems.
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Union used in method signatures that handle either dual-carrying problem type.
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

# Build a solver cache for a dual-carrying problem solved with a
# `GeneralizedFirstOrderAlgorithm`.
#
# Duals are stripped from `p` and `u0`, the solver is initialised on the resulting problem,
# and that cache is wrapped together with the original dual `p`, its stripped value and
# its partials.
function SciMLBase.__init(
        prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs...
)
    primal_p = NonlinearSolveBase.nodual_value(prob.p)
    primal_u0 = NonlinearSolveBase.nodual_value(prob.u0)
    primal_prob = SciMLBase.remake(prob; u0 = primal_u0, p = primal_p)
    inner_cache = init(primal_prob, alg, args...; kwargs...)
    return NonlinearSolveForwardDiffCache(
        inner_cache, primal_prob, alg, prob.p, primal_p, ForwardDiff.partials(prob.p)
    )
end

# Solve a dual-carrying problem with a `GeneralizedFirstOrderAlgorithm`.
#
# `nonlinearsolve_forwarddiff_solve` returns a solution together with partials; these are
# recombined into a dual-valued `u`, which is returned in a solution object built for the
# original problem.
function SciMLBase.__solve(
        prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs...
)
    primal_sol, u_partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
        prob, alg, args...; kwargs...
    )
    u_dual = NonlinearSolveBase.nonlinearsolve_dual_solution(
        primal_sol.u, u_partials, prob.p
    )
    return SciMLBase.build_solution(
        prob, alg, u_dual, primal_sol.resid;
        retcode = primal_sol.retcode,
        stats = primal_sol.stats,
        original = primal_sol.original
    )
end
10 changes: 10 additions & 0 deletions lib/NonlinearSolveFirstOrder/test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,13 @@
@test sol.retcode == ReturnCode.Success
@test jac_calls == 0
end

@testitem "Dual of BigFloat: Issue #512" tags=[:core] begin
    using NonlinearSolveFirstOrder, ForwardDiff

    # Dual number with a BigFloat value and a single partial of 5.0.
    bigdual(v) = ForwardDiff.Dual(BigFloat(v), 5.0)

    # In-place residual: du = u .* u .- p
    residual! = (du, u, p) -> du .= u .* u .- p
    fn_iip = NonlinearFunction{true}(residual!)

    # Both the initial guess and the parameter are BigFloat-valued duals.
    u2 = [bigdual(1.0) for _ in 1:3]
    prob_iip_bf = NonlinearProblem{true}(fn_iip, u2, bigdual(2.0))

    sol = solve(prob_iip_bf, NewtonRaphson())
    @test sol.retcode == ReturnCode.Success
end
6 changes: 6 additions & 0 deletions lib/NonlinearSolveQuasiNewton/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
NonlinearSolveQuasiNewtonForwardDiffExt = "ForwardDiff"

[compat]
ADTypes = "1.9.0"
Aqua = "0.8"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module NonlinearSolveQuasiNewtonForwardDiffExt

using CommonSolve: CommonSolve, init
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem

using NonlinearSolveBase: NonlinearSolveBase, NonlinearSolveForwardDiffCache, nodual_value

using NonlinearSolveQuasiNewton: QuasiNewtonAlgorithm

# Dispatch aliases for problems that carry ForwardDiff duals.
# The first type parameter is restricted to a `Number` or `AbstractArray`, and the third to
# a `Dual{T, V, P}` or an array of such duals.
# NOTE(review): presumably these are the `u0` type and the parameter (`p`) type per
# SciMLBase's type-parameter order — confirm against SciMLBase.
const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Same constraint as above, for nonlinear least-squares problems.
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Union used in method signatures that handle either dual-carrying problem type.
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

# Solve a dual-carrying problem with a `QuasiNewtonAlgorithm`.
#
# `nonlinearsolve_forwarddiff_solve` returns a solution together with partials; these are
# recombined into a dual-valued `u`, which is returned in a solution object built for the
# original problem.
function SciMLBase.__solve(
        prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs...
)
    primal_sol, u_partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
        prob, alg, args...; kwargs...
    )
    u_dual = NonlinearSolveBase.nonlinearsolve_dual_solution(
        primal_sol.u, u_partials, prob.p
    )
    return SciMLBase.build_solution(
        prob, alg, u_dual, primal_sol.resid;
        retcode = primal_sol.retcode,
        stats = primal_sol.stats,
        original = primal_sol.original
    )
end

# Build a solver cache for a dual-carrying problem solved with a `QuasiNewtonAlgorithm`.
#
# Duals are stripped from `p` and `u0`, the solver is initialised on the resulting problem,
# and that cache is wrapped together with the original dual `p`, its stripped value and
# its partials.
function SciMLBase.__init(
        prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs...
)
    primal_p = nodual_value(prob.p)
    primal_u0 = nodual_value(prob.u0)
    primal_prob = SciMLBase.remake(prob; u0 = primal_u0, p = primal_p)
    inner_cache = init(primal_prob, alg, args...; kwargs...)
    return NonlinearSolveForwardDiffCache(
        inner_cache, primal_prob, alg, prob.p, primal_p, ForwardDiff.partials(prob.p)
    )
end

end
7 changes: 7 additions & 0 deletions lib/NonlinearSolveSpectralMethods/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
NonlinearSolveSpectralMethodsForwardDiffExt = "ForwardDiff"

[compat]
Aqua = "0.8"
BenchmarkTools = "1.5.0"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.158.3"
ExplicitImports = "1.5"
ForwardDiff = "0.10.36"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LineSearch = "0.1.4"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module NonlinearSolveSpectralMethodsForwardDiffExt

using CommonSolve: CommonSolve, init
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem

using NonlinearSolveBase: NonlinearSolveBase, NonlinearSolveForwardDiffCache, nodual_value

using NonlinearSolveSpectralMethods: GeneralizedDFSane

# Dispatch aliases for problems that carry ForwardDiff duals.
# The first type parameter is restricted to a `Number` or `AbstractArray`, and the third to
# a `Dual{T, V, P}` or an array of such duals.
# NOTE(review): presumably these are the `u0` type and the parameter (`p`) type per
# SciMLBase's type-parameter order — confirm against SciMLBase.
const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Same constraint as above, for nonlinear least-squares problems.
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Union used in method signatures that handle either dual-carrying problem type.
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

# Solve a dual-carrying problem with `GeneralizedDFSane`.
#
# `nonlinearsolve_forwarddiff_solve` returns a solution together with partials; these are
# recombined into a dual-valued `u`, which is returned in a solution object built for the
# original problem.
function SciMLBase.__solve(
        prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs...
)
    primal_sol, u_partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
        prob, alg, args...; kwargs...
    )
    u_dual = NonlinearSolveBase.nonlinearsolve_dual_solution(
        primal_sol.u, u_partials, prob.p
    )
    return SciMLBase.build_solution(
        prob, alg, u_dual, primal_sol.resid;
        retcode = primal_sol.retcode,
        stats = primal_sol.stats,
        original = primal_sol.original
    )
end

# Build a solver cache for a dual-carrying problem solved with `GeneralizedDFSane`.
#
# Duals are stripped from `p` and `u0`, the solver is initialised on the resulting problem,
# and that cache is wrapped together with the original dual `p`, its stripped value and
# its partials.
function SciMLBase.__init(
        prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs...
)
    primal_p = nodual_value(prob.p)
    primal_u0 = nodual_value(prob.u0)
    primal_prob = SciMLBase.remake(prob; u0 = primal_u0, p = primal_p)
    inner_cache = init(primal_prob, alg, args...; kwargs...)
    return NonlinearSolveForwardDiffCache(
        inner_cache, primal_prob, alg, prob.p, primal_p, ForwardDiff.partials(prob.p)
    )
end

end
Loading

0 comments on commit 2002bd4

Please sign in to comment.