Skip to content

Commit

Permalink
fix up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jul 13, 2023
1 parent 84dc2cf commit a4a9808
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/src/tutorials/ensemble_modeling.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ data_train = [
R => (t_train,fullR[1:15]),
]
t_ensem = 0:21
data_train = [
data_ensem = [
S => (t_ensem,fullS[1:22]),
I => (t_ensem,fullI[1:22]),
R => (t_ensem,fullR[1:22]),
Expand Down
32 changes: 19 additions & 13 deletions test/ensemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,33 +56,39 @@ sol = solve(enprob; saveat = 1);

weights = [0.2, 0.5, 0.3]

fullS = vec(sum(stack(weights .* sol[:,S]),dims=2))
fullI = vec(sum(stack(weights .* sol[:,I]),dims=2))
fullR = vec(sum(stack(weights .* sol[:,R]),dims=2))

t_train = 0:14
data_train = [
S => (t_train,vec(sum(stack([weights[i] * sol[i][S][1:15] for i in 1:3]), dims = 2))),
I => (t_train,vec(sum(stack([weights[i] * sol[i][I][1:15] for i in 1:3]), dims = 2))),
R => (t_train,vec(sum(stack([weights[i] * sol[i][R][1:15] for i in 1:3]), dims = 2))),
S => (t_train,fullS[1:15]),
I => (t_train,fullI[1:15]),
R => (t_train,fullR[1:15]),
]
t_ensem = 0:21
data_ensem = [
S => (t_ensem,vec(sum(stack([weights[i] * sol[i][S][1:22] for i in 1:3]), dims = 2))),
I => (t_ensem,vec(sum(stack([weights[i] * sol[i][I][1:22] for i in 1:3]), dims = 2))),
R => (t_ensem,vec(sum(stack([weights[i] * sol[i][R][1:22] for i in 1:3]), dims = 2))),
S => (t_ensem,fullS[1:22]),
I => (t_ensem,fullI[1:22]),
R => (t_ensem,fullR[1:22]),
]
t_forecast = 0:30
data_forecast = [
S => (t_forecast,vec(sum(stack([weights[i] * sol[i][S][1:end] for i in 1:3]), dims = 2))),
I => (t_forecast,vec(sum(stack([weights[i] * sol[i][I][1:end] for i in 1:3]), dims = 2))),
R => (t_forecast,vec(sum(stack([weights[i] * sol[i][R][1:end] for i in 1:3]), dims = 2))),
S => (t_forecast,fullS),
I => (t_forecast,fullI),
R => (t_forecast,fullR),
]

sol = solve(enprob; saveat = t_ensem);

@test ensemble_weights(sol, data_ensem) [0.2, 0.5, 0.3]

probs = [prob, prob2, prob3]
ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3]
datas = [data_train,data_train,data_train]
probs = (prob, prob2, prob3)
ps = Tuple([β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3)
datas = (data_train,data_train,data_train)
enprobs = bayesian_ensemble(probs, ps, datas)

sol = solve(enprobs; saveat = t_ensem);
ensemble_weights(sol, data_ensem)
ensemble_weights(sol, data_ensem)

bayesian_datafit(probs, ps, datas)

0 comments on commit a4a9808

Please sign in to comment.