Skip to content

Commit 993567f

Browse files
committed
Fix logic around computing trajectory state index
1 parent a3b9cbc commit 993567f

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

R/monty.R

+8-8
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ dust_likelihood_monty <- function(obj, packer, initial = NULL, domain = NULL,
110110

111111
domain <- monty::monty_domain_expand(domain, packer)
112112
save_trajectories <- validate_save_trajectories(save_trajectories)
113+
save_trajectories$uninitialised <-
114+
save_trajectories$enabled && !is.null(save_trajectories$subset)
113115
observer <- dust_observer(obj, save_state, save_trajectories$enabled,
114116
save_snapshots)
115117

@@ -150,12 +152,11 @@ dust_likelihood_monty <- function(obj, packer, initial = NULL, domain = NULL,
150152
## object.
151153
density <- function(x) {
152154
pars <- packer_unpack(packer, x)
153-
was_uninitialised <- dust_likelihood_ensure_initialised(obj, pars)
154-
needs_trajectories_index <- was_uninitialised &&
155-
save_trajectories$enabled && !is.null(save_trajectories$subset)
156-
if (needs_trajectories_index) {
155+
dust_likelihood_ensure_initialised(obj, pars)
156+
if (save_trajectories$uninitialised) {
157157
env$save_trajectories$index <-
158158
obj$packer_state$subset(save_trajectories$subset)$index
159+
env$save_trajectories$uninitialised <- FALSE
159160
}
160161
ptr <- obj$ptr
161162
if (!identical(x, attr(ptr, "last_pars"))) {
@@ -186,12 +187,11 @@ dust_likelihood_monty <- function(obj, packer, initial = NULL, domain = NULL,
186187
} else {
187188
density <- function(x) {
188189
pars <- packer_unpack(packer, x)
189-
was_uninitialised <- dust_likelihood_ensure_initialised(obj, pars)
190-
needs_trajectories_index <- was_uninitialised &&
191-
save_trajectories$enabled && !is.null(save_trajectories$subset)
192-
if (needs_trajectories_index) {
190+
dust_likelihood_ensure_initialised(obj, pars)
191+
if (save_trajectories$uninitialised) {
193192
env$save_trajectories$index <-
194193
obj$packer_state$subset(save_trajectories$subset)$index
194+
env$save_trajectories$uninitialised <- FALSE
195195
}
196196
ll <- dust_likelihood_run(
197197
obj,

tests/testthat/test-monty.R

+40
Original file line numberDiff line numberDiff line change
@@ -363,3 +363,43 @@ test_that("can get snapshots from model", {
363363
expect_equal(names(res$observations), "snapshots")
364364
expect_equal(dim(res$observations$snapshots), c(5, 2, 27, 3))
365365
})
366+
367+
368+
test_that("cope with changing the trajectory index", {
369+
d <- data.frame(
370+
time = 1:5,
371+
incidence = c(12, 23, 25, 36, 30))
372+
373+
filter <- dust2::dust_filter_create(sir, time_start = 0,
374+
data = d, dt = 0.25,
375+
n_particles = 200)
376+
377+
packer <- monty::monty_packer(scalar = c("beta", "gamma"),
378+
fixed = list(I0 = 10, N = 1000))
379+
vcv <- diag(2) * 0.01
380+
sampler <- monty::monty_sampler_random_walk(vcv)
381+
prior <- monty::monty_dsl({
382+
beta ~ Exponential(mean = 1)
383+
gamma ~ Exponential(mean = 0.5)
384+
})
385+
386+
likelihood <- dust2::dust_likelihood_monty(filter, packer,
387+
save_trajectories = TRUE)
388+
posterior <- likelihood + prior
389+
samples <- monty::monty_sample(posterior, sampler, n_steps = 3,
390+
initial = c(0.3, 0.1),
391+
n_chains = 2)
392+
expect_equal(
393+
dim(samples$observations$trajectories),
394+
c(5, 5, 3, 2)) # state, time, steps, chains
395+
396+
likelihood <- dust2::dust_likelihood_monty(filter, packer,
397+
save_trajectories = "cases_inc")
398+
posterior <- likelihood + prior
399+
samples <- monty::monty_sample(posterior, sampler, n_steps = 3,
400+
initial = c(0.3, 0.1),
401+
n_chains = 2)
402+
expect_equal(
403+
dim(samples$observations$trajectories),
404+
c(1, 5, 3, 2)) # state, time, steps, chains
405+
})

0 commit comments

Comments
 (0)