Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] feat: add hooks for OverrideInit #517

Draft
wants to merge 29 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1ece4f9
feat: add `get_abstol` and `get_reltol` interface methods
AayushSabharwal Dec 5, 2024
0ac88e0
feat: add `initialize_cache!`
AayushSabharwal Dec 5, 2024
b0a4ba1
feat: implement initialization for polyalg cache
AayushSabharwal Dec 5, 2024
f6f4689
feat: implement initialization for no-init cache
AayushSabharwal Dec 5, 2024
1ff9f7f
feat: implement initialization for first order cache
AayushSabharwal Dec 5, 2024
0377273
feat: implement initialization for `QuasiNewtonCache`
AayushSabharwal Dec 5, 2024
0cd0037
feat: implement initialization for `GeneralizedDFSaneCache`
AayushSabharwal Dec 5, 2024
0befd56
fix: fix `SII.parameter_values` implementation
AayushSabharwal Dec 6, 2024
51a85b8
fixup! feat: add `get_abstol` and `get_reltol` interface methods
AayushSabharwal Dec 6, 2024
4d2979e
fixup! feat: add `initialize_cache!`
AayushSabharwal Dec 6, 2024
1fe4e87
fixup! feat: implement initialization for polyalg cache
AayushSabharwal Dec 6, 2024
7a798f6
fixup! feat: implement initialization for polyalg cache
AayushSabharwal Dec 6, 2024
645cd21
fixup! fixup! feat: add `initialize_cache!`
AayushSabharwal Dec 6, 2024
225ccac
feat: implement initialization for `SimpleNonlinearSolve`
AayushSabharwal Dec 6, 2024
dfd9eee
fixup! feat: implement initialization for `GeneralizedDFSaneCache`
AayushSabharwal Dec 6, 2024
50ee6e5
fix: fix `InternalAPI.reinit_self!` for `GeneralizedDFSaneCache`
AayushSabharwal Dec 6, 2024
226f85f
fixup! feat: implement initialization for `QuasiNewtonCache`
AayushSabharwal Dec 6, 2024
2d3d78e
fixup! feat: implement initialization for first order cache
AayushSabharwal Dec 6, 2024
0f3202d
fixup! feat: implement initialization for no-init cache
AayushSabharwal Dec 6, 2024
1052025
fix: fix `SII.state_values` for `NoInitCache`
AayushSabharwal Dec 6, 2024
ee6897f
feat: run initialiation on `solve!`
AayushSabharwal Dec 6, 2024
94cfcb7
fixup! fixup! feat: add `initialize_cache!`
AayushSabharwal Dec 6, 2024
4993511
fixup! feat: implement initialization for polyalg cache
AayushSabharwal Dec 6, 2024
4883d2d
fixup! fixup! feat: implement initialization for polyalg cache
AayushSabharwal Dec 6, 2024
2c59221
fixup! feat: add `get_abstol` and `get_reltol` interface methods
AayushSabharwal Dec 6, 2024
b95b68d
fixup! fixup! feat: implement initialization for no-init cache
AayushSabharwal Dec 6, 2024
b581b33
fixup! feat: implement initialization for `QuasiNewtonCache`
AayushSabharwal Dec 6, 2024
6f85e87
fixup! feat: implement initialization for `GeneralizedDFSaneCache`
AayushSabharwal Dec 6, 2024
78b4518
fixup! feat: implement initialization for `SimpleNonlinearSolve`
AayushSabharwal Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ include("descent/damped_newton.jl")
include("descent/dogleg.jl")
include("descent/geodesic_acceleration.jl")

include("initialization.jl")
include("solve.jl")

# Unexported Public API
Expand Down
11 changes: 10 additions & 1 deletion lib/NonlinearSolveBase/src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ Abstract Type for all NonlinearSolveBase Caches.
`u0` and any additional keyword arguments.
- `SciMLBase.isinplace(cache)`: whether or not the solver is inplace.
- `CommonSolve.step!(cache; kwargs...)`: See [`CommonSolve.step!`](@ref) for more details.
- `get_abstol(cache)`: get the `abstol` provided to the cache.
- `get_reltol(cache)`: get the `reltol` provided to the cache.

Additionally implements `SymbolicIndexingInterface` interface Functions.

Expand Down Expand Up @@ -304,9 +306,16 @@ end

SciMLBase.isinplace(cache::AbstractNonlinearSolveCache) = SciMLBase.isinplace(cache.prob)

function get_abstol(cache::AbstractNonlinearSolveCache)
get_abstol(cache.termination_cache)
end
function get_reltol(cache::AbstractNonlinearSolveCache)
get_reltol(cache.termination_cache)
end

## SII Interface
SII.symbolic_container(cache::AbstractNonlinearSolveCache) = cache.prob
SII.parameter_values(cache::AbstractNonlinearSolveCache) = SII.parameter_values(cache.prob)
SII.parameter_values(cache::AbstractNonlinearSolveCache) = cache.p
SII.state_values(cache::AbstractNonlinearSolveCache) = get_u(cache)

function Base.getproperty(cache::AbstractNonlinearSolveCache, sym::Symbol)
Expand Down
60 changes: 60 additions & 0 deletions lib/NonlinearSolveBase/src/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
struct NonlinearSolveDefaultInit <: SciMLBase.DAEInitializationAlgorithm end

function run_initialization!(cache, initializealg = cache.initializealg, prob = cache.prob)
_run_initialization!(cache, initializealg, prob, Val(SciMLBase.isinplace(cache)))
end

function _run_initialization!(
cache, ::NonlinearSolveDefaultInit, prob, isinplace::Union{Val{true}, Val{false}})
if SciMLBase.has_initialization_data(prob.f) &&
prob.f.initialization_data isa SciMLBase.OverrideInitData
return _run_initialization!(cache, SciMLBase.OverrideInit(), prob, isinplace)
end
return cache, true
end

function _run_initialization!(cache, initalg::SciMLBase.OverrideInit, prob,
isinplace::Union{Val{true}, Val{false}})
if cache isa AbstractNonlinearSolveCache && isdefined(cache.alg, :autodiff)
autodiff = cache.alg.autodiff
else
autodiff = ADTypes.AutoForwardDiff()
end
alg = initialization_alg(prob.f.initialization_data.initializeprob, autodiff)
if alg === nothing && cache isa AbstractNonlinearSolveCache
alg = cache.alg
end
u0, p, success = SciMLBase.get_initial_values(
prob, cache, prob.f, initalg, isinplace; nlsolve_alg = alg,
abstol = get_abstol(cache), reltol = get_reltol(cache))
cache = update_initial_values!(cache, u0, p)
if cache isa AbstractNonlinearSolveCache && isdefined(cache, :retcode) && !success
cache.retcode = ReturnCode.InitialFailure
end

return cache, success
end

function get_abstol(prob::AbstractNonlinearProblem)
get_tolerance(get(prob.kwargs, :abstol, nothing), eltype(SII.state_values(prob)))
end
function get_reltol(prob::AbstractNonlinearProblem)
get_tolerance(get(prob.kwargs, :reltol, nothing), eltype(SII.state_values(prob)))
end

initialization_alg(initprob, autodiff) = nothing

function update_initial_values!(cache::AbstractNonlinearSolveCache, u0, p)
InternalAPI.reinit!(cache; u0, p)
cache.prob = SciMLBase.remake(cache.prob; u0, p)
return cache
end

function update_initial_values!(prob::AbstractNonlinearProblem, u0, p)
return SciMLBase.remake(prob; u0, p)
end

function _run_initialization!(
cache::AbstractNonlinearSolveCache, ::SciMLBase.NoInit, prob, isinplace)
return cache, true
end
32 changes: 28 additions & 4 deletions lib/NonlinearSolveBase/src/polyalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ end
u0
u0_aliased
alias_u0::Bool

initializealg
end

function update_initial_values!(cache::NonlinearSolvePolyAlgorithmCache, u0, p)
foreach(cache.caches) do subcache
update_initial_values!(subcache, u0, p)
end
cache.prob = SciMLBase.remake(cache.prob; u0, p)
return cache
end

function NonlinearSolveBase.get_abstol(cache::NonlinearSolvePolyAlgorithmCache)
NonlinearSolveBase.get_abstol(cache.caches[cache.current])
end
function NonlinearSolveBase.get_reltol(cache::NonlinearSolvePolyAlgorithmCache)
NonlinearSolveBase.get_reltol(cache.caches[cache.current])
end

function SII.symbolic_container(cache::NonlinearSolvePolyAlgorithmCache)
Expand All @@ -67,6 +84,9 @@ end
function SII.state_values(cache::NonlinearSolvePolyAlgorithmCache)
SII.state_values(SII.symbolic_container(cache))
end
function SII.parameter_values(cache::NonlinearSolvePolyAlgorithmCache)
SII.parameter_values(SII.symbolic_container(cache))
end

function Base.show(io::IO, ::MIME"text/plain", cache::NonlinearSolvePolyAlgorithmCache)
println(io, "NonlinearSolvePolyAlgorithmCache with \
Expand Down Expand Up @@ -97,7 +117,8 @@ end
function SciMLBase.__init(
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...;
stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000,
internalnorm = L2_NORM, alias_u0 = false, verbose = true, kwargs...
internalnorm = L2_NORM, alias_u0 = false, verbose = true,
initializealg = NonlinearSolveDefaultInit(), kwargs...
)
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
Expand All @@ -109,18 +130,21 @@ function SciMLBase.__init(
u0_aliased = alias_u0 ? copy(u0) : u0
alias_u0 && (prob = SciMLBase.remake(prob; u0 = u0_aliased))

return NonlinearSolvePolyAlgorithmCache(
cache = NonlinearSolvePolyAlgorithmCache(
alg.static_length, prob,
map(alg.algs) do solver
SciMLBase.__init(
prob, solver, args...;
stats, maxtime, internalnorm, alias_u0, verbose, kwargs...
stats, maxtime, internalnorm, alias_u0, verbose,
initializealg = SciMLBase.NoInit(), kwargs...
)
end,
alg, -1, alg.start_index, 0, stats, 0.0, maxtime,
ReturnCode.Default, false, maxiters, internalnorm,
u0, u0_aliased, alias_u0
u0, u0_aliased, alias_u0, initializealg
)
run_initialization!(cache)
return cache
end

@generated function InternalAPI.step!(
Expand Down
65 changes: 62 additions & 3 deletions lib/NonlinearSolveBase/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ function SciMLBase.__solve(
end

function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
if cache.retcode == ReturnCode.InitialFailure
return SciMLBase.build_solution(
cache.prob, cache.alg, get_u(cache), get_fu(cache);
cache.retcode, cache.stats, cache.trace
)
end

while not_terminated(cache)
CommonSolve.step!(cache)
end
Expand Down Expand Up @@ -40,6 +47,17 @@ end
sol_syms = [gensym("sol") for i in 1:N]
u_result_syms = [gensym("u_result") for i in 1:N]

push!(calls,
quote
if cache.retcode == ReturnCode.InitialFailure
u = $(SII.state_values)(cache)
return build_solution_less_specialize(
cache.prob, cache.alg, u, $(Utils.evaluate_f)(cache.prob, u);
retcode = cache.retcode
)
end
end)

for i in 1:N
push!(calls,
quote
Expand Down Expand Up @@ -111,7 +129,8 @@ end

@generated function __generated_polysolve(
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...;
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, kwargs...
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true,
initializealg = NonlinearSolveDefaultInit(), kwargs...
) where {N}
sol_syms = [gensym("sol") for _ in 1:N]
prob_syms = [gensym("prob") for _ in 1:N]
Expand All @@ -123,9 +142,23 @@ end
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing
end
end]

push!(calls,
quote
prob, success = $(run_initialization!)(prob, initializealg, prob)
if !success
u = $(SII.state_values)(prob)
return build_solution_less_specialize(
prob, alg, u, $(Utils.evaluate_f)(prob, u);
retcode = $(ReturnCode.InitialFailure))
end
end)

push!(calls, quote
u0 = prob.u0
u0_aliased = alias_u0 ? zero(u0) : u0
end]
end)
for i in 1:N
cur_sol = sol_syms[i]
push!(calls,
Expand Down Expand Up @@ -246,6 +279,23 @@ end
alg
args
kwargs::Any
initializealg

retcode::ReturnCode.T
end

function get_abstol(cache::NonlinearSolveNoInitCache)
get(cache.kwargs, :abstol, get_tolerance(nothing, eltype(cache.prob.u0)))
end
function get_reltol(cache::NonlinearSolveNoInitCache)
get(cache.kwargs, :reltol, get_tolerance(nothing, eltype(cache.prob.u0)))
end

SII.parameter_values(cache::NonlinearSolveNoInitCache) = SII.parameter_values(cache.prob)
SII.state_values(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob)

function update_parameter_object!(cache::NonlinearSolveNoInitCache, p)
SciMLBase.reinit!(cache, cache.prob.u0, p)
end

get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob)
Expand All @@ -264,11 +314,20 @@ end

function SciMLBase.__init(
prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...;
initializealg = NonlinearSolveDefaultInit(),
kwargs...
)
return NonlinearSolveNoInitCache(prob, alg, args, kwargs)
cache = NonlinearSolveNoInitCache(
prob, alg, args, kwargs, initializealg, ReturnCode.Success)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
prob, alg, args, kwargs, initializealg, ReturnCode.Success)
prob, alg, args, kwargs, initializealg, ReturnCode.Default)

Probably should be this

run_initialization!(cache)
return cache
end

function CommonSolve.solve!(cache::NonlinearSolveNoInitCache)
if cache.retcode == ReturnCode.InitialFailure
u = SII.state_values(cache)
return SciMLBase.build_solution(
cache.prob, cache.alg, u, Utils.evaluate_f(cache.prob, u); cache.retcode)
end
return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...)
end
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/src/termination_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ const AbsNormModes = Union{
u_diff_cache::uType
end

get_abstol(cache::NonlinearTerminationModeCache) = cache.abstol
get_reltol(cache::NonlinearTerminationModeCache) = cache.reltol

function update_u!!(cache::NonlinearTerminationModeCache, u)
cache.u === nothing && return
if cache.u isa AbstractArray && ArrayInterface.can_setindex(cache.u)
Expand Down
12 changes: 9 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ end
retcode::ReturnCode.T
force_stop::Bool
kwargs

initializealg
end

function InternalAPI.reinit_self!(
Expand Down Expand Up @@ -121,7 +123,7 @@ function SciMLBase.__init(
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000,
abstol = nothing, reltol = nothing, maxtime = nothing,
termination_condition = nothing, internalnorm = L2_NORM,
linsolve_kwargs = (;), kwargs...
linsolve_kwargs = (;), initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), kwargs...
)
@set! alg.autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
provided_jvp_autodiff = alg.jvp_autodiff !== nothing
Expand Down Expand Up @@ -206,13 +208,17 @@ function SciMLBase.__init(
prob, alg, u, fu, J, du; kwargs...
)

return GeneralizedFirstOrderAlgorithmCache(
cache = GeneralizedFirstOrderAlgorithmCache(
fu, u, u_cache, prob.p, du, J, alg, prob, globalization,
jac_cache, descent_cache, linesearch_cache, trustregion_cache,
stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs,
initializealg
)
end

NonlinearSolveBase.run_initialization!(cache)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be inside the @static_timeit?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

return cache
end

function InternalAPI.step!(
Expand Down
20 changes: 17 additions & 3 deletions lib/NonlinearSolveQuasiNewton/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ end
force_stop::Bool
force_reinit::Bool
kwargs

# Initialization
initializealg
end

function NonlinearSolveBase.get_abstol(cache::QuasiNewtonCache)
NonlinearSolveBase.get_abstol(cache.termination_cache)
end
function NonlinearSolveBase.get_reltol(cache::QuasiNewtonCache)
NonlinearSolveBase.get_reltol(cache.termination_cache)
end

function InternalAPI.reinit_self!(
Expand Down Expand Up @@ -130,7 +140,8 @@ function SciMLBase.__init(
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxtime = nothing,
maxiters = 1000, abstol = nothing, reltol = nothing,
linsolve_kwargs = (;), termination_condition = nothing,
internalnorm::F = L2_NORM, kwargs...
internalnorm::F = L2_NORM, initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(),
kwargs...
) where {F}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
Expand Down Expand Up @@ -204,15 +215,18 @@ function SciMLBase.__init(
uses_jacobian_inverse = inverted_jac, kwargs...
)

return QuasiNewtonCache(
cache = QuasiNewtonCache(
fu, u, u_cache, prob.p, du, J, alg, prob, globalization,
initialization_cache, descent_cache, linesearch_cache,
trustregion_cache, update_rule_cache, reinit_rule_cache,
inv_workspace, stats, 0, 0, alg.max_resets, maxiters, maxtime,
alg.max_shrink_times, 0, timer, 0.0, termination_cache, trace,
ReturnCode.Default, false, false, kwargs
ReturnCode.Default, false, false, kwargs, initializealg
)
end

NonlinearSolveBase.run_initialization!(cache)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again with the @static_timeit

return cache
end

function InternalAPI.step!(
Expand Down
Loading
Loading