Skip to content

Commit

Permalink
Merge pull request #447 from isaacsas/vrj_remake_fix
Browse files Browse the repository at this point in the history
ExtendedJumpArrays remake fix
  • Loading branch information
isaacsas authored Aug 24, 2024
2 parents 9e272f7 + d87c468 commit c2c2580
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 23 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Expand All @@ -38,6 +39,7 @@ RandomNumbers = "1.5"
RecursiveArrayTools = "3.12"
Reexport = "1.0"
SciMLBase = "2.46"
Setfield = "1"
StaticArrays = "1.9"
SymbolicIndexingInterface = "0.3.13"
UnPack = "1.0.2"
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

<!-- [![Coverage Status](https://coveralls.io/repos/github/SciML/JumpProcesses.jl/badge.svg?branch=master)](https://coveralls.io/github/SciML/JumpProcesses.jl?branch=master)
[![codecov](https://codecov.io/gh/SciML/JumpProcesses.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/SciML/JumpProcesses.jl) -->

[![Build Status](https://github.com/SciML/JumpProcesses.jl/workflows/CI/badge.svg)](https://github.com/SciML/JumpProcesses.jl/actions?query=workflow%3ACI)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)
Expand Down
1 change: 1 addition & 0 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using FunctionWrappers, UnPack
using Graphs
using SciMLBase: SciMLBase, isdenseplot
using Base.FastMath: add_fast
using Setfield: @set, @set!

import DiffEqBase: DiscreteCallback, init, solve, solve!, plot_indices, initialize!
import Base: size, getindex, setindex!, length, similar, show, merge!, merge
Expand Down
71 changes: 50 additions & 21 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,37 +86,64 @@ function JumpProblem(p::P, a::A, dj::J, jc::C, vj::J2, rj::J3, mj::J4,
JumpProblem{iip, P, A, C, J, J2, J3, J4, R, K}(p, a, dj, jc, vj, rj, mj, rng, kwargs)
end

# for remaking
######## remaking ######

# for a problem where prob.u0 is an ExtendedJumpArray, create an ExtendedJumpArray that
# aliases and resets prob.u0.jump_u while having newu0 as the new u component.
function remake_extended_u0(prob, newu0, rng)
jump_u = prob.u0.jump_u
ttype = eltype(prob.tspan)
@. jump_u = -randexp(rng, ttype)
ExtendedJumpArray(newu0, jump_u)
end

Base.@pure remaker_of(prob::T) where {T <: JumpProblem} = DiffEqBase.parameterless_type(T)
function DiffEqBase.remake(thing::JumpProblem; kwargs...)
T = remaker_of(thing)
function DiffEqBase.remake(jprob::JumpProblem; kwargs...)
T = remaker_of(jprob)

errmesg = """
JumpProblems can currently only be remade with new u0, p, tspan or prob fields. To change other fields create a new JumpProblem. Feel free to open an issue on JumpProcesses to discuss further.
"""
!issubset(keys(kwargs), (:u0, :p, :tspan, :prob)) && error(errmesg)

if :prob keys(kwargs)
dprob = DiffEqBase.remake(thing.prob; kwargs...)
# Update u0 when we are wrapping via ExtendedJumpArrays. If the user passes an
# ExtendedJumpArray we assume they properly initialized it
prob = jprob.prob
if (prob.u0 isa ExtendedJumpArray) && (:u0 in keys(kwargs))
newu0 = kwargs[:u0]
# if newu0 is of the wrapped type, initialize a new ExtendedJumpArray
if typeof(newu0) == typeof(prob.u0.u)
u0 = remake_extended_u0(prob, newu0, jprob.rng)
_kwargs = @set! kwargs[:u0] = u0
elseif typeof(newu0) != typeof(prob.u0)
error("Passed in u0 is incompatible with current u0 which has type: $(typeof(prob.u0.u)).")
else
_kwargs = kwargs
end
dprob = DiffEqBase.remake(jprob.prob; _kwargs...)
else
dprob = DiffEqBase.remake(jprob.prob; kwargs...)
end

# if the parameters were changed we must remake the MassActionJump too
if (:p keys(kwargs)) && using_params(thing.massaction_jump)
update_parameters!(thing.massaction_jump, dprob.p; kwargs...)
if (:p keys(kwargs)) && using_params(jprob.massaction_jump)
update_parameters!(jprob.massaction_jump, dprob.p; kwargs...)
end
else
any(k -> k in keys(kwargs), (:u0, :p, :tspan)) &&
error("If remaking a JumpProblem you can not pass both prob and any of u0, p, or tspan.")
dprob = kwargs[:prob]

# we can't know if p was changed, so we must remake the MassActionJump
if using_params(thing.massaction_jump)
update_parameters!(thing.massaction_jump, dprob.p; kwargs...)
if using_params(jprob.massaction_jump)
update_parameters!(jprob.massaction_jump, dprob.p; kwargs...)
end
end

T(dprob, thing.aggregator, thing.discrete_jump_aggregation, thing.jump_callback,
thing.variable_jumps, thing.regular_jump, thing.massaction_jump, thing.rng,
thing.kwargs)
T(dprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback,
jprob.variable_jumps, jprob.regular_jump, jprob.massaction_jump, jprob.rng,
jprob.kwargs)
end

# when setindex! is used.
Expand Down Expand Up @@ -272,6 +299,14 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
solkwargs)
end

# extends prob.u0 to an ExtendedJumpArray with Njumps integrated intensity values,
# of type prob.tspan
function extend_u0(prob, Njumps, rng)
ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:Njumps])
return u0
end

function extend_problem(prob::DiffEqBase.AbstractDiscreteProblem, jumps; rng = DEFAULT_RNG)
error("General `VariableRateJump`s require a continuous problem, like an ODE/SDE/DDE/DAE problem. To use a `DiscreteProblem` bounded `VariableRateJump`s must be used. See the JumpProcesses docs.")
end
Expand All @@ -296,9 +331,7 @@ function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAUL
end
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0,
[-randexp(rng, ttype) for i in 1:length(jumps)])
u0 = extend_u0(prob, length(jumps), rng)
f = ODEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys,
observed = prob.f.observed)
remake(prob; f, u0)
Expand Down Expand Up @@ -334,8 +367,7 @@ function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAUL
end
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:length(jumps)])
u0 = extend_u0(prob, length(jumps), rng)
f = SDEFunction{isinplace(prob)}(jump_f, jump_g; sys = prob.f.sys,
observed = prob.f.observed)
remake(prob; f, g = jump_g, u0)
Expand All @@ -361,8 +393,7 @@ function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAUL
end
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:length(jumps)])
u0 = extend_u0(prob, length(jumps), rng)
f = DDEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys,
observed = prob.f.observed)
remake(prob; f, u0)
Expand All @@ -389,9 +420,7 @@ function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAUL
end
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0,
[-randexp(rng, ttype) for i in 1:length(jumps)])
u0 = extend_u0(prob, length(jumps), rng)
f = DAEFunction{isinplace(prob)}(jump_f, sys = prob.f.sys,
observed = prob.f.observed)
remake(prob; f, u0)
Expand Down
28 changes: 27 additions & 1 deletion test/remake_test.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using JumpProcesses, DiffEqBase
using JumpProcesses, DiffEqBase, OrdinaryDiffEq
using StableRNGs
rng = StableRNG(12345)

Expand Down Expand Up @@ -67,3 +67,29 @@ sol3 = solve(jprob3, SSAStepper())
# test error handling
@test_throws ErrorException jprob4=remake(jprob, prob = dprob2, p = p2)
@test_throws ErrorException jprob5=remake(jprob, aggregator = RSSA())

# test for #446
let
f(du, u, p, t) = (du .= 0; nothing)
prob = ODEProblem(f, [0.0], (0.0, 1.0))
rrate(u, p, t) = u[1]
aaffect!(integrator) = (integrator.u[1] += 1; nothing)
vrj = VariableRateJump(rrate, aaffect!)
jprob = JumpProblem(prob, vrj; rng)
sol = solve(jprob, Tsit5())
@test all(==(0.0), sol[1, :])
u0 = [4.0]
jprob2 = remake(jprob; u0)
@test jprob2.prob.u0 isa ExtendedJumpArray
@test jprob2.prob.u0.u === u0
sol = solve(jprob2, Tsit5())
u = sol[1, :]
@test length(u) > 2
@test all(>(u0[1]), u[3:end])
u0 = deepcopy(jprob2.prob.u0)
u0.u .= 0
jprob3 = remake(jprob2; u0)
sol = solve(jprob3, Tsit5())
@test all(==(0.0), sol[1, :])
@test_throws ErrorException jprob4=remake(jprob, u0 = 1)
end

0 comments on commit c2c2580

Please sign in to comment.