diff --git a/src/datafit.jl b/src/datafit.jl index 29e4f1fa..f47e02a0 100644 --- a/src/datafit.jl +++ b/src/datafit.jl @@ -206,24 +206,24 @@ function bayes_unpack_data(prob, p::AbstractVector{<:Pair}) (pdist, IndexKeyMap(prob, pkeys)) end -Turing.@model function bayesianODE(prob, t, pdist, pkeys, data, noise_prior) +Turing.@model function bayesianODE(prob, alg, t, pdist, pkeys, data, datamap, noise_prior) σ ~ noise_prior pprior ~ product_distribution(pdist) prob = _remake(prob, (prob.tspan[1], t[end]), pkeys, pprior) - sol = solve(prob, saveat = t) + sol = solve(prob, alg, saveat = t) if !SciMLBase.successful_retcode(sol) Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf) return nothing end for i in eachindex(data) - data[i].second ~ MvNormal(sol[data[i].first], σ^2 * I) + data[i] ~ MvNormal(datamap(sol), σ^2 * I) end return nothing end -Turing.@model function bayesianODE(prob, +Turing.@model function bayesianODE(prob, alg, pdist, pkeys, ts, @@ -236,7 +236,7 @@ Turing.@model function bayesianODE(prob, pprior ~ product_distribution(pdist) prob = _remake(prob, (prob.tspan[1], lastt), pkeys, pprior) - sol = solve(prob) + sol = solve(prob, alg) if !SciMLBase.successful_retcode(sol) Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf) return nothing @@ -264,18 +264,19 @@ end Base.length(ws::WeightedSol) = length(first(ws.sols)) Base.size(ws::WeightedSol) = (length(first(ws.sols)),) function Base.getindex(ws::WeightedSol{T}, i::Int) where {T} - s = zero(T) - w = zero(T) - for j in eachindex(ws.weights) + s::T = zero(T) + w::T = zero(T) + @inbounds for j in eachindex(ws.weights) w += ws.weights[j] s += ws.weights[j] * ws.sols[j][i] end return s + (one(T) - w) * ws.sols[end][i] end -function WeightedSol(sols, select, weights) - T = eltype(weights) - s = map(Base.Fix2(getindex, select), sols) - WeightedSol{T}(s, weights) +function WeightedSol(sols, select, i::Int, weights) + s = map(sols, select) do sol, sel + @view(sol[sel.indices[i], :]) + end + WeightedSol{eltype(weights)}(s, weights) end function bayes_unpack_data(probs, p::Tuple{Vararg{<:AbstractVector{<:Pair}}}, data) pdist, pkeys = bayes_unpack_data(probs, p) @@ -305,25 +306,27 @@ function flatten(x::Tuple) reduce(vcat, x), Grouper(map(length, x)) end -function getsols(probs, probspkeys, ppriors, t::AbstractArray) - map(probs, probspkeys, ppriors) do prob, pkeys, pprior +function getsols(probs, algs, probspkeys, ppriors, t::AbstractArray) + map(probs, algs, probspkeys, ppriors) do prob, alg, pkeys, pprior newprob = _remake(prob, (prob.tspan[1], t[end]), pkeys, pprior) - solve(newprob, saveat = t) + solve(newprob, alg, saveat = t) end end -function getsols(probs, probspkeys, ppriors, lastt::Number) - map(probs, probspkeys, ppriors) do prob, pkeys, pprior +function getsols(probs, algs, probspkeys, ppriors, lastt::Number) + map(probs, algs, probspkeys, ppriors) do prob, alg, pkeys, pprior newprob = _remake(prob, (prob.tspan[1], lastt), pkeys, pprior) - solve(newprob) + solve(newprob, alg) end end Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector}, + algs, t, pdist, grouppriorsfunc, probspkeys, data, + datamaps, noise_prior) σ ~ noise_prior ppriors ~ product_distribution(pdist) @@ -331,17 +334,18 @@ Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector}, Nprobs = length(probs) Nprobs⁻¹ = inv(Nprobs) weights ~ MvNormal(Distributions.Fill(Nprobs⁻¹, Nprobs - 1), Nprobs⁻¹) - sols = getsols(probs, probspkeys, grouppriorsfunc(ppriors), t) + sols = getsols(probs, algs, probspkeys, grouppriorsfunc(ppriors), t) if !all(SciMLBase.successful_retcode, sols) Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf) return nothing end for i in eachindex(data) - data[i].second ~ MvNormal(WeightedSol(sols, data[i].first, weights), σ^2 * I) + data[i] ~ MvNormal(WeightedSol(sols, datamaps, i, weights), σ^2 * I) end return nothing end Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector}, + algs, pdist, grouppriorsfunc, probspkeys, @@ -353,7 +357,7 @@ Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector}, σ ~ noise_prior ppriors ~ product_distribution(pdist) - sols = getsols(probs, probspkeys, grouppriorsfunc(ppriors), lastt) + sols = getsols(probs, algs, probspkeys, grouppriorsfunc(ppriors), lastt) Nprobs = length(probs) Nprobs⁻¹ = inv(Nprobs) @@ -411,7 +415,14 @@ function bayesian_datafit(prob, nchains = 4, niter = 1000) (pdist, pkeys) = bayes_unpack_data(prob, p) - model = bayesianODE(prob, t, pdist, pkeys, data, noise_prior) + model = bayesianODE(prob, + first(default_algorithm(prob)), + t, + pdist, + pkeys, + last.(data), + IndexKeyMap(prob, data), + noise_prior) chain = Turing.sample(model, Turing.NUTS(0.65), mcmcensemble, @@ -430,7 +441,15 @@ function bayesian_datafit(prob, nchains = 4, niter = 1_000) pdist, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(prob, p, data) - model = bayesianODE(prob, pdist, pkeys, ts, lastt, timeseries, datakeys, noise_prior) + model = bayesianODE(prob, + first(default_algorithm(prob)), + pdist, + pkeys, + ts, + lastt, + timeseries, + datakeys, + noise_prior) chain = Turing.sample(model, Turing.NUTS(0.65), mcmcensemble, @@ -451,7 +470,10 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector}, (pdist_, pkeys) = bayes_unpack_data(p) pdist, grouppriorsfunc = flatten(pdist_) - model = ensemblebayesianODE(probs, t, pdist, grouppriorsfunc, pkeys, data, noise_prior) + model = ensemblebayesianODE(probs, + map(first ∘ default_algorithm, probs), + t, pdist, grouppriorsfunc, pkeys, last.(data), + map(Base.Fix2(IndexKeyMap, data), probs), noise_prior) chain = Turing.sample(model, Turing.NUTS(0.65), mcmcensemble, @@ -472,6 +494,7 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector}, pdist_, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(p, data) pdist, grouppriorsfunc = flatten(pdist_) model = ensemblebayesianODE(probs, + map(first ∘ default_algorithm, probs), pdist, grouppriorsfunc, pkeys, diff --git a/src/keyindexmap.jl b/src/keyindexmap.jl index a2a863af..b773bdf9 100644 --- a/src/keyindexmap.jl +++ b/src/keyindexmap.jl @@ -3,6 +3,7 @@ struct IndexKeyMap indices::Vector{Int} end +# probs support function IndexKeyMap(prob, keys) params = ModelingToolkit.parameters(prob.f.sys) indices = Vector{Int}(undef, length(keys)) @@ -12,7 +13,8 @@ function IndexKeyMap(prob, keys) return IndexKeyMap(indices) end -Base.@propagate_inbounds function (ikm::IndexKeyMap)(prob, v::AbstractVector) +Base.@propagate_inbounds function (ikm::IndexKeyMap)(prob::SciMLBase.AbstractDEProblem, + v::AbstractVector) @boundscheck checkbounds(v, length(ikm.indices)) def = prob.p ret = Vector{Base.promote_eltype(v, def)}(undef, length(def)) @@ -22,8 +24,20 @@ Base.@propagate_inbounds function (ikm::IndexKeyMap)(prob, v::AbstractVector) end return ret end - function _remake(prob, tspan, ikm::IndexKeyMap, pprior) p = ikm(prob, pprior) remake(prob; tspan, p) end + +# data support +function IndexKeyMap(prob, data::AbstractVector{<:Pair}) + states = ModelingToolkit.states(prob.f.sys) + indices = Vector{Int}(undef, length(data)) + for i in eachindex(data) + indices[i] = findfirst(Base.Fix1(isequal, data[i].first), states) + end + return IndexKeyMap(indices) +end +function (ikm::IndexKeyMap)(sol::SciMLBase.AbstractTimeseriesSolution) + (@view(sol[i, :]) for i in ikm.indices) +end