-
-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
359 additions
and
512 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
37 changes: 37 additions & 0 deletions
37
lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,40 @@ | ||
module NonlinearSolveFirstOrder | ||
|
||
using Reexport: @reexport | ||
using PrecompileTools: @compile_workload, @setup_workload | ||
|
||
using ArrayInterface: ArrayInterface | ||
using CommonSolve: CommonSolve | ||
using ConcreteStructs: @concrete | ||
using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches | ||
using LinearAlgebra: LinearAlgebra, Diagonal, dot, inv, diag | ||
using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase | ||
using MaybeInplace: @bb | ||
using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, | ||
AbstractNonlinearSolveCache, AbstractResetCondition, | ||
AbstractResetConditionCache, AbstractApproximateJacobianStructure, | ||
AbstractJacobianCache, AbstractJacobianInitialization, | ||
AbstractApproximateJacobianUpdateRule, AbstractDescentDirection, | ||
AbstractApproximateJacobianUpdateRuleCache, | ||
Utils, InternalAPI, get_timer_output, @static_timeit, | ||
update_trace!, L2_NORM, NewtonDescent | ||
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode | ||
using SciMLOperators: AbstractSciMLOperator | ||
using StaticArraysCore: StaticArray, Size, MArray | ||
|
||
include("raphson.jl") | ||
include("gauss_newton.jl") | ||
include("levenberg_marquardt.jl") | ||
include("trust_region.jl") | ||
include("pseudo_transient.jl") | ||
|
||
include("solve.jl") | ||
|
||
@reexport using SciMLBase, NonlinearSolveBase | ||
|
||
export NewtonRaphson, PseudoTransient | ||
export GaussNewton, LevenbergMarquardt, TrustRegion | ||
|
||
export GeneralizedFirstOrderAlgorithm | ||
|
||
end |
Empty file.
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,303 @@ | ||
""" | ||
GeneralizedFirstOrderAlgorithm(; | ||
descent, linesearch = missing, | ||
trustregion = missing, autodiff = nothing, vjp_autodiff = nothing, | ||
jvp_autodiff = nothing, max_shrink_times::Int = typemax(Int), | ||
concrete_jac = Val(false), name::Symbol = :unknown | ||
) | ||
This is a Generalization of First-Order (uses Jacobian) Nonlinear Solve Algorithms. The most | ||
common example of this is Newton-Raphson Method. | ||
First Order here refers to the order of differentiation, and should not be confused with the | ||
order of convergence. | ||
### Keyword Arguments | ||
- `trustregion`: Globalization using a Trust Region Method. This needs to follow the | ||
[`NonlinearSolve.AbstractTrustRegionMethod`](@ref) interface. | ||
- `descent`: The descent method to use to compute the step. This needs to follow the | ||
[`NonlinearSolve.AbstractDescentAlgorithm`](@ref) interface. | ||
- `max_shrink_times`: The maximum number of times the trust region radius can be shrunk | ||
before the algorithm terminates. | ||
""" | ||
@concrete struct GeneralizedFirstOrderAlgorithm <: AbstractNonlinearSolveAlgorithm | ||
linesearch | ||
trustregion | ||
descent | ||
max_shrink_times::Int | ||
|
||
autodiff | ||
vjp_autodiff | ||
jvp_autodiff | ||
|
||
concrete_jac <: Union{Val{false}, Val{true}} | ||
name::Symbol | ||
end | ||
|
||
function GeneralizedFirstOrderAlgorithm(; | ||
descent, linesearch = missing, trustregion = missing, autodiff = nothing, | ||
vjp_autodiff = nothing, jvp_autodiff = nothing, max_shrink_times::Int = typemax(Int), | ||
concrete_jac = Val(false), name::Symbol = :unknown) | ||
return GeneralizedFirstOrderAlgorithm( | ||
linesearch, trustregion, descent, max_shrink_times, | ||
autodiff, vjp_autodiff, jvp_autodiff, | ||
concrete_jac, name | ||
) | ||
end | ||
|
||
@concrete mutable struct GeneralizedFirstOrderAlgorithmCache <: AbstractNonlinearSolveCache | ||
# Basic Requirements | ||
fu | ||
u | ||
u_cache | ||
p | ||
du # Aliased to `get_du(descent_cache)` | ||
J # Aliased to `jac_cache.J` | ||
alg <: GeneralizedFirstOrderAlgorithm | ||
prob <: AbstractNonlinearProblem | ||
globalization <: Union{Val{:LineSearch}, Val{:TrustRegion}, Val{:None}} | ||
|
||
# Internal Caches | ||
jac_cache | ||
descent_cache | ||
linesearch_cache | ||
trustregion_cache | ||
|
||
# Counters | ||
stats::NLStats | ||
nsteps::Int | ||
maxiters::Int | ||
maxtime | ||
max_shrink_times::Int | ||
|
||
# Timer | ||
timer | ||
total_time::Float64 | ||
|
||
# State Affect | ||
make_new_jacobian::Bool | ||
|
||
# Termination & Tracking | ||
termination_cache | ||
trace | ||
retcode::ReturnCode.T | ||
force_stop::Bool | ||
kwargs | ||
end | ||
|
||
# XXX: Implement | ||
# function __reinit_internal!( | ||
# cache::GeneralizedFirstOrderAlgorithmCache{iip}, args...; p = cache.p, u0 = cache.u, | ||
# alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...) where {iip} | ||
# if iip | ||
# recursivecopy!(cache.u, u0) | ||
# cache.prob.f(cache.fu, cache.u, p) | ||
# else | ||
# cache.u = __maybe_unaliased(u0, alias_u0) | ||
# set_fu!(cache, cache.prob.f(cache.u, p)) | ||
# end | ||
# cache.p = p | ||
|
||
# __reinit_internal!(cache.stats) | ||
# cache.nsteps = 0 | ||
# cache.maxiters = maxiters | ||
# cache.maxtime = maxtime | ||
# cache.total_time = 0.0 | ||
# cache.force_stop = false | ||
# cache.retcode = ReturnCode.Default | ||
# cache.make_new_jacobian = true | ||
|
||
# reset!(cache.trace) | ||
# reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...) | ||
# reset_timer!(cache.timer) | ||
# end | ||
|
||
NonlinearSolveBase.@internal_caches(GeneralizedFirstOrderAlgorithmCache, | ||
:jac_cache, :descent_cache, :linesearch_cache, :trustregion_cache) | ||
|
||
# function SciMLBase.__init( | ||
# prob::AbstractNonlinearProblem{uType, iip}, alg::GeneralizedFirstOrderAlgorithm, | ||
# args...; stats = empty_nlstats(), alias_u0 = false, maxiters = 1000, | ||
# abstol = nothing, reltol = nothing, maxtime = nothing, | ||
# termination_condition = nothing, internalnorm = L2_NORM, | ||
# linsolve_kwargs = (;), kwargs...) where {uType, iip} | ||
# autodiff = select_jacobian_autodiff(prob, alg.autodiff) | ||
# jvp_autodiff = if alg.jvp_autodiff === nothing && alg.autodiff !== nothing && | ||
# (ADTypes.mode(alg.autodiff) isa ADTypes.ForwardMode || | ||
# ADTypes.mode(alg.autodiff) isa ADTypes.ForwardOrReverseMode) | ||
# select_forward_mode_autodiff(prob, alg.autodiff) | ||
# else | ||
# select_forward_mode_autodiff(prob, alg.jvp_autodiff) | ||
# end | ||
# vjp_autodiff = if alg.vjp_autodiff === nothing && alg.autodiff !== nothing && | ||
# (ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode || | ||
# ADTypes.mode(alg.autodiff) isa ADTypes.ForwardOrReverseMode) | ||
# select_reverse_mode_autodiff(prob, alg.autodiff) | ||
# else | ||
# select_reverse_mode_autodiff(prob, alg.vjp_autodiff) | ||
# end | ||
|
||
# timer = get_timer_output() | ||
# @static_timeit timer "cache construction" begin | ||
# (; f, u0, p) = prob | ||
# u = __maybe_unaliased(u0, alias_u0) | ||
# fu = evaluate_f(prob, u) | ||
# @bb u_cache = copy(u) | ||
|
||
# linsolve = get_linear_solver(alg.descent) | ||
|
||
# abstol, reltol, termination_cache = NonlinearSolveBase.init_termination_cache( | ||
# prob, abstol, reltol, fu, u, termination_condition, Val(:regular)) | ||
# linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs) | ||
|
||
# jac_cache = construct_jacobian_cache( | ||
# prob, alg, f, fu, u, p; stats, autodiff, linsolve, jvp_autodiff, vjp_autodiff) | ||
# J = jac_cache(nothing) | ||
|
||
# descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol, | ||
# reltol, internalnorm, linsolve_kwargs, timer) | ||
# du = get_du(descent_cache) | ||
|
||
# has_linesearch = alg.linesearch !== missing && alg.linesearch !== nothing | ||
# has_trustregion = alg.trustregion !== missing && alg.trustregion !== nothing | ||
|
||
# if has_trustregion && has_linesearch | ||
# error("TrustRegion and LineSearch methods are algorithmically incompatible.") | ||
# end | ||
|
||
# GB = :None | ||
# linesearch_cache = nothing | ||
# trustregion_cache = nothing | ||
|
||
# if has_trustregion | ||
# supports_trust_region(alg.descent) || error("Trust Region not supported by \ | ||
# $(alg.descent).") | ||
# trustregion_cache = __internal_init( | ||
# prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs..., | ||
# autodiff, jvp_autodiff, vjp_autodiff) | ||
# GB = :TrustRegion | ||
# end | ||
|
||
# if has_linesearch | ||
# supports_line_search(alg.descent) || error("Line Search not supported by \ | ||
# $(alg.descent).") | ||
# linesearch_cache = init( | ||
# prob, alg.linesearch, fu, u; stats, autodiff = jvp_autodiff, kwargs...) | ||
# GB = :LineSearch | ||
# end | ||
|
||
# trace = init_nonlinearsolve_trace( | ||
# prob, alg, u, fu, ApplyArray(__zero, J), du; kwargs...) | ||
|
||
# return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}( | ||
# fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache, | ||
# trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer, | ||
# 0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs) | ||
# end | ||
# end | ||
|
||
# function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB}; | ||
# recompute_jacobian::Union{Nothing, Bool} = nothing, kwargs...) where {iip, GB} | ||
# @static_timeit cache.timer "jacobian" begin | ||
# if (recompute_jacobian === nothing || recompute_jacobian) && cache.make_new_jacobian | ||
# J = cache.jac_cache(cache.u) | ||
# new_jacobian = true | ||
# else | ||
# J = cache.jac_cache(nothing) | ||
# new_jacobian = false | ||
# end | ||
# end | ||
|
||
# @static_timeit cache.timer "descent" begin | ||
# if cache.trustregion_cache !== nothing && | ||
# hasfield(typeof(cache.trustregion_cache), :trust_region) | ||
# descent_result = __internal_solve!( | ||
# cache.descent_cache, J, cache.fu, cache.u; new_jacobian, | ||
# trust_region = cache.trustregion_cache.trust_region, cache.kwargs...) | ||
# else | ||
# descent_result = __internal_solve!( | ||
# cache.descent_cache, J, cache.fu, cache.u; new_jacobian, cache.kwargs...) | ||
# end | ||
# end | ||
|
||
# if !descent_result.linsolve_success | ||
# if new_jacobian | ||
# # Jacobian Information is current and linear solve failed terminate the solve | ||
# cache.retcode = ReturnCode.InternalLinearSolveFailed | ||
# cache.force_stop = true | ||
# return | ||
# else | ||
# # Jacobian Information is not current and linear solve failed, recompute | ||
# # Jacobian | ||
# if !haskey(cache.kwargs, :verbose) || cache.kwargs[:verbose] | ||
# @warn "Linear Solve Failed but Jacobian Information is not current. \ | ||
# Retrying with updated Jacobian." | ||
# end | ||
# # In the 2nd call the `new_jacobian` is guaranteed to be `true`. | ||
# cache.make_new_jacobian = true | ||
# __step!(cache; recompute_jacobian = true, kwargs...) | ||
# return | ||
# end | ||
# end | ||
|
||
# δu, descent_intermediates = descent_result.δu, descent_result.extras | ||
|
||
# if descent_result.success | ||
# cache.make_new_jacobian = true | ||
# if GB === :LineSearch | ||
# @static_timeit cache.timer "linesearch" begin | ||
# linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu) | ||
# linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode) | ||
# α = linesearch_sol.step_size | ||
# end | ||
# if linesearch_failed | ||
# cache.retcode = ReturnCode.InternalLineSearchFailed | ||
# cache.force_stop = true | ||
# end | ||
# @static_timeit cache.timer "step" begin | ||
# @bb axpy!(α, δu, cache.u) | ||
# evaluate_f!(cache, cache.u, cache.p) | ||
# end | ||
# elseif GB === :TrustRegion | ||
# @static_timeit cache.timer "trustregion" begin | ||
# tr_accepted, u_new, fu_new = __internal_solve!( | ||
# cache.trustregion_cache, J, cache.fu, | ||
# cache.u, δu, descent_intermediates) | ||
# if tr_accepted | ||
# @bb copyto!(cache.u, u_new) | ||
# @bb copyto!(cache.fu, fu_new) | ||
# α = true | ||
# else | ||
# α = false | ||
# cache.make_new_jacobian = false | ||
# end | ||
# if hasfield(typeof(cache.trustregion_cache), :shrink_counter) && | ||
# cache.trustregion_cache.shrink_counter > cache.max_shrink_times | ||
# cache.retcode = ReturnCode.ShrinkThresholdExceeded | ||
# cache.force_stop = true | ||
# end | ||
# end | ||
# elseif GB === :None | ||
# @static_timeit cache.timer "step" begin | ||
# @bb axpy!(1, δu, cache.u) | ||
# evaluate_f!(cache, cache.u, cache.p) | ||
# end | ||
# α = true | ||
# else | ||
# error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \ | ||
# :TrustRegion, :None)") | ||
# end | ||
# check_and_update!(cache, cache.fu, cache.u, cache.u_cache) | ||
# else | ||
# α = false | ||
# cache.make_new_jacobian = false | ||
# end | ||
|
||
# update_trace!(cache, α) | ||
# @bb copyto!(cache.u_cache, cache.u) | ||
|
||
# callback_into_cache!(cache) | ||
|
||
# return nothing | ||
# end |
Empty file.
Oops, something went wrong.