-
-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] feat: add hooks for OverrideInit
#517
base: master
Are you sure you want to change the base?
Changes from all commits
1ece4f9
0ac88e0
b0a4ba1
f6f4689
1ff9f7f
0377273
0cd0037
0befd56
51a85b8
4d2979e
1fe4e87
7a798f6
645cd21
225ccac
dfd9eee
50ee6e5
226f85f
2d3d78e
0f3202d
1052025
ee6897f
94cfcb7
4993511
4883d2d
2c59221
b95b68d
b581b33
6f85e87
78b4518
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
struct NonlinearSolveDefaultInit <: SciMLBase.DAEInitializationAlgorithm end | ||
|
||
function run_initialization!(cache, initializealg = cache.initializealg, prob = cache.prob) | ||
_run_initialization!(cache, initializealg, prob, Val(SciMLBase.isinplace(cache))) | ||
end | ||
|
||
function _run_initialization!( | ||
cache, ::NonlinearSolveDefaultInit, prob, isinplace::Union{Val{true}, Val{false}}) | ||
if SciMLBase.has_initialization_data(prob.f) && | ||
prob.f.initialization_data isa SciMLBase.OverrideInitData | ||
return _run_initialization!(cache, SciMLBase.OverrideInit(), prob, isinplace) | ||
end | ||
return cache, true | ||
end | ||
|
||
function _run_initialization!(cache, initalg::SciMLBase.OverrideInit, prob, | ||
isinplace::Union{Val{true}, Val{false}}) | ||
if cache isa AbstractNonlinearSolveCache && isdefined(cache.alg, :autodiff) | ||
autodiff = cache.alg.autodiff | ||
else | ||
autodiff = ADTypes.AutoForwardDiff() | ||
end | ||
alg = initialization_alg(prob.f.initialization_data.initializeprob, autodiff) | ||
if alg === nothing && cache isa AbstractNonlinearSolveCache | ||
alg = cache.alg | ||
end | ||
u0, p, success = SciMLBase.get_initial_values( | ||
prob, cache, prob.f, initalg, isinplace; nlsolve_alg = alg, | ||
abstol = get_abstol(cache), reltol = get_reltol(cache)) | ||
cache = update_initial_values!(cache, u0, p) | ||
if cache isa AbstractNonlinearSolveCache && isdefined(cache, :retcode) && !success | ||
cache.retcode = ReturnCode.InitialFailure | ||
end | ||
|
||
return cache, success | ||
end | ||
|
||
function get_abstol(prob::AbstractNonlinearProblem) | ||
get_tolerance(get(prob.kwargs, :abstol, nothing), eltype(SII.state_values(prob))) | ||
end | ||
function get_reltol(prob::AbstractNonlinearProblem) | ||
get_tolerance(get(prob.kwargs, :reltol, nothing), eltype(SII.state_values(prob))) | ||
end | ||
|
||
initialization_alg(initprob, autodiff) = nothing | ||
|
||
function update_initial_values!(cache::AbstractNonlinearSolveCache, u0, p) | ||
InternalAPI.reinit!(cache; u0, p) | ||
cache.prob = SciMLBase.remake(cache.prob; u0, p) | ||
return cache | ||
end | ||
|
||
function update_initial_values!(prob::AbstractNonlinearProblem, u0, p) | ||
return SciMLBase.remake(prob; u0, p) | ||
end | ||
|
||
function _run_initialization!( | ||
cache::AbstractNonlinearSolveCache, ::SciMLBase.NoInit, prob, isinplace) | ||
return cache, true | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,6 +87,8 @@ end | |
retcode::ReturnCode.T | ||
force_stop::Bool | ||
kwargs | ||
|
||
initializealg | ||
end | ||
|
||
function InternalAPI.reinit_self!( | ||
|
@@ -121,7 +123,7 @@ function SciMLBase.__init( | |
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000, | ||
abstol = nothing, reltol = nothing, maxtime = nothing, | ||
termination_condition = nothing, internalnorm = L2_NORM, | ||
linsolve_kwargs = (;), kwargs... | ||
linsolve_kwargs = (;), initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), kwargs... | ||
) | ||
@set! alg.autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff) | ||
provided_jvp_autodiff = alg.jvp_autodiff !== nothing | ||
|
@@ -206,13 +208,17 @@ function SciMLBase.__init( | |
prob, alg, u, fu, J, du; kwargs... | ||
) | ||
|
||
return GeneralizedFirstOrderAlgorithmCache( | ||
cache = GeneralizedFirstOrderAlgorithmCache( | ||
fu, u, u_cache, prob.p, du, J, alg, prob, globalization, | ||
jac_cache, descent_cache, linesearch_cache, trustregion_cache, | ||
stats, 0, maxiters, maxtime, alg.max_shrink_times, timer, | ||
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs | ||
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs, | ||
initializealg | ||
) | ||
end | ||
|
||
NonlinearSolveBase.run_initialization!(cache) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be inside the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes |
||
return cache | ||
end | ||
|
||
function InternalAPI.step!( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,6 +93,16 @@ end | |
force_stop::Bool | ||
force_reinit::Bool | ||
kwargs | ||
|
||
# Initialization | ||
initializealg | ||
end | ||
|
||
function NonlinearSolveBase.get_abstol(cache::QuasiNewtonCache) | ||
NonlinearSolveBase.get_abstol(cache.termination_cache) | ||
end | ||
function NonlinearSolveBase.get_reltol(cache::QuasiNewtonCache) | ||
NonlinearSolveBase.get_reltol(cache.termination_cache) | ||
end | ||
|
||
function InternalAPI.reinit_self!( | ||
|
@@ -130,7 +140,8 @@ function SciMLBase.__init( | |
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxtime = nothing, | ||
maxiters = 1000, abstol = nothing, reltol = nothing, | ||
linsolve_kwargs = (;), termination_condition = nothing, | ||
internalnorm::F = L2_NORM, kwargs... | ||
internalnorm::F = L2_NORM, initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), | ||
kwargs... | ||
) where {F} | ||
timer = get_timer_output() | ||
@static_timeit timer "cache construction" begin | ||
|
@@ -204,15 +215,18 @@ function SciMLBase.__init( | |
uses_jacobian_inverse = inverted_jac, kwargs... | ||
) | ||
|
||
return QuasiNewtonCache( | ||
cache = QuasiNewtonCache( | ||
fu, u, u_cache, prob.p, du, J, alg, prob, globalization, | ||
initialization_cache, descent_cache, linesearch_cache, | ||
trustregion_cache, update_rule_cache, reinit_rule_cache, | ||
inv_workspace, stats, 0, 0, alg.max_resets, maxiters, maxtime, | ||
alg.max_shrink_times, 0, timer, 0.0, termination_cache, trace, | ||
ReturnCode.Default, false, false, kwargs | ||
ReturnCode.Default, false, false, kwargs, initializealg | ||
) | ||
end | ||
|
||
NonlinearSolveBase.run_initialization!(cache) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again with the |
||
return cache | ||
end | ||
|
||
function InternalAPI.step!( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably should be this