Skip to content

Commit 8364b88

Browse files
committed
Harmonise filter and unfilter interface (part 1)
1 parent d1bf3b1 commit 8364b88

39 files changed

+817
-998
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
PACKAGE := $(shell grep '^Package:' DESCRIPTION | sed -E 's/^Package:[[:space:]]+//')
2-
RSCRIPT = Rscript --no-init-file
2+
RSCRIPT = Rscript
33

44
all:
55
${RSCRIPT} -e 'pkgbuild::compile_dll()'

NAMESPACE

+8-12
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method(dim,dust_system)
4-
S3method(print,dust_filter)
4+
S3method(print,dust_likelihood)
55
S3method(print,dust_system)
66
S3method(print,dust_system_generator)
7-
S3method(print,dust_unfilter)
87
export(dust_compile)
98
export(dust_filter_create)
109
export(dust_filter_data)
11-
export(dust_filter_last_history)
12-
export(dust_filter_last_state)
13-
export(dust_filter_monty)
14-
export(dust_filter_rng_state)
15-
export(dust_filter_run)
16-
export(dust_filter_set_rng_state)
10+
export(dust_likelihood_last_gradient)
11+
export(dust_likelihood_last_history)
12+
export(dust_likelihood_last_state)
13+
export(dust_likelihood_monty)
14+
export(dust_likelihood_rng_state)
15+
export(dust_likelihood_run)
16+
export(dust_likelihood_set_rng_state)
1717
export(dust_ode_control)
1818
export(dust_package)
1919
export(dust_system_compare_data)
@@ -32,8 +32,4 @@ export(dust_system_state)
3232
export(dust_system_time)
3333
export(dust_system_update_pars)
3434
export(dust_unfilter_create)
35-
export(dust_unfilter_last_gradient)
36-
export(dust_unfilter_last_history)
37-
export(dust_unfilter_last_state)
38-
export(dust_unfilter_run)
3935
useDynLib(dust2, .registration = TRUE)

R/cpp11.R

+10-10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R/dust.R

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R/interface-filter.R

+20-232
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
##' @inheritParams dust_system_create
2828
##' @inheritParams dust_system_simulate
2929
##'
30-
##' @return A `dust_unfilter` object, which can be used with
31-
##' [dust_unfilter_run]
30+
##' @return A `dust_likelihood` object, which can be used with
31+
##' [dust_likelihood_run]
3232
##'
3333
##' @export
3434
dust_filter_create <- function(generator, time_start, data,
@@ -45,8 +45,6 @@ dust_filter_create <- function(generator, time_start, data,
4545
time_start <- check_time_start(time_start, data$time, call = call)
4646
dt <- check_dt(dt, call = call)
4747

48-
## NOTE: there is no preserve_particle_dimension option here because
49-
## we will always preserve this dimension.
5048
n_groups <- data$n_groups
5149
preserve_group_dimension <- preserve_group_dimension || n_groups > 1
5250

@@ -65,238 +63,38 @@ dust_filter_create <- function(generator, time_start, data,
6563

6664
res <- list2env(
6765
list(inputs = inputs,
66+
initialise = filter_create,
6867
initial_rng_state = filter_rng_state(n_particles, n_groups, seed),
6968
n_particles = n_particles,
7069
n_groups = n_groups,
7170
deterministic = FALSE,
71+
has_adjoint = FALSE,
7272
generator = generator,
7373
methods = generator$methods$filter,
7474
index_state = index_state,
75+
preserve_particle_dimension = TRUE,
7576
preserve_group_dimension = preserve_group_dimension),
7677
parent = emptyenv())
77-
class(res) <- "dust_filter"
78+
class(res) <- c("dust_filter", "dust_likelihood")
7879
res
7980
}
8081

8182

82-
##' Create an independent copy of a filter. The new filter is
83-
##' decoupled from the random number streams of the parent filter. It
84-
##' is also decoupled from the *state size* of the parent filter, so
85-
##' you can use this to create a new filter where the system is
86-
##' fundamentally different but everything else is the same.
87-
##'
88-
##' @title Create copy of filter
89-
##'
90-
##' @inheritParams dust_filter_run
91-
##'
92-
##' @param seed The seed for the filter (see [dust_filter_create])
93-
##'
94-
##' @return A new `dust_filter` object
95-
dust_filter_copy <- function(filter, seed = NULL) {
96-
dst <- new.env(parent = emptyenv())
97-
nms <- c("inputs", "n_particles", "n_groups", "deterministic", "methods",
98-
"index_state", "preserve_group_dimension", "generator")
99-
for (nm in nms) {
100-
dst[[nm]] <- filter[[nm]]
101-
}
102-
dst$initial_rng_state <-
103-
filter_rng_state(filter$n_particles, filter$n_groups, seed)
104-
class(dst) <- "dust_filter"
105-
dst
106-
}
107-
108-
109-
filter_create <- function(filter, pars) {
110-
inputs <- filter$inputs
83+
filter_create <- function(obj, pars) {
84+
inputs <- obj$inputs
11185
list2env(
112-
filter$methods$alloc(pars,
113-
inputs$time_start,
114-
inputs$time,
115-
inputs$dt,
116-
inputs$data,
117-
inputs$n_particles,
118-
inputs$n_groups,
119-
inputs$n_threads,
120-
inputs$index_state,
121-
filter$initial_rng_state),
122-
filter)
123-
filter$initial_rng_state <- NULL
124-
}
125-
126-
127-
##' Run particle filter
128-
##'
129-
##' @title Run particle filter
130-
##'
131-
##' @param filter A `dust_filter` object, created by
132-
##' [dust_filter_create]
133-
##'
134-
##' @param pars Optional parameters to run the filter with. If not
135-
##' provided, parameters are not updated
136-
##'
137-
##' @param initial Optional initial conditions, as a matrix (state x
138-
##' particle) or 3d array (state x particle x group). If not
139-
##' provided, the system initial conditions are used.
140-
##'
141-
##' @param save_history Logical, indicating if the simulation history
142-
##' should be saved while the simulation runs; this has a small
143-
##' overhead in runtime and in memory. History (particle
144-
##' trajectories) will be saved at each time in the filter. If the
145-
##' filter was constructed using a non-`NULL` `index_state` parameter,
146-
##' the history is restricted to these states.
147-
##'
148-
##' @param index_group An optional vector of group indices to run the
149-
##' filter for. You can use this to run a subset of possible
150-
##' groups, once the filter is initialised (this argument must be
151-
##' `NULL` on the **first** call).
152-
##'
153-
##' @return A vector of likelihood values, with as many elements as
154-
##' there are groups.
155-
##'
156-
##' @export
157-
dust_filter_run <- function(filter, pars, initial = NULL,
158-
save_history = FALSE, index_group = NULL) {
159-
check_is_dust_filter(filter)
160-
index_group <- check_index(index_group, max = filter$n_groups,
161-
unique = TRUE)
162-
if (!is.null(pars)) {
163-
pars <- check_pars(pars, filter$n_groups, index_group,
164-
filter$preserve_group_dimension)
165-
}
166-
if (is.null(filter$ptr)) {
167-
if (is.null(pars)) {
168-
cli::cli_abort("'pars' cannot be NULL, as filter is not initialised",
169-
arg = "pars")
170-
}
171-
if (!is.null(index_group)) {
172-
cli::cli_abort(
173-
"'index_group' must be NULL, as filter is not initialised",
174-
arg = "index_group")
175-
}
176-
filter_create(filter, pars)
177-
} else if (!is.null(pars)) {
178-
filter$methods$update_pars(filter$ptr, pars, index_group)
179-
}
180-
filter$methods$run(filter$ptr,
181-
initial,
182-
save_history,
183-
index_group,
184-
filter$preserve_group_dimension)
185-
}
186-
187-
188-
##' Fetch the last history created by running a filter. This
189-
##' errors if the last call to [dust_filter_run] did not use
190-
##' `save_history = TRUE`.
191-
##'
192-
##' @title Fetch last filter history
193-
##'
194-
##' @inheritParams dust_filter_run
195-
##'
196-
##' @param select_random_particle Logical, indicating if we should
197-
##' return a history for one randomly selected particle (rather than
198-
##' the entire history). If this is `TRUE`, the particle will be
199-
##' selected independently for each group, if the filter is grouped.
200-
##' This option is intended to help select a representative
201-
##' trajectory during an MCMC. When `TRUE`, we drop the `particle`
202-
##' dimension of the return value.
203-
##'
204-
##' @return An array. If ungrouped this will have dimensions `state`
205-
##' x `particle` x `time`, and if grouped then `state` x `particle`
206-
##' x `group` x `time`. If `select_random_particle = TRUE`, the
207-
##' second (particle) dimension will be dropped.
208-
##'
209-
##' @export
210-
dust_filter_last_history <- function(filter, index_group = NULL,
211-
select_random_particle = FALSE) {
212-
check_is_dust_filter(filter)
213-
if (is.null(filter$ptr)) {
214-
cli::cli_abort(c(
215-
"History is not current",
216-
i = "Filter has not yet been run"))
217-
}
218-
index_group <- check_index(index_group, max = filter$n_groups,
219-
unique = TRUE)
220-
assert_scalar_logical(select_random_particle)
221-
filter$methods$last_history(filter$ptr, index_group,
222-
select_random_particle,
223-
filter$preserve_group_dimension)
224-
}
225-
226-
227-
##' Get the last state from a filter.
228-
##'
229-
##' @title Get filter state
230-
##'
231-
##' @inheritParams dust_filter_last_history
232-
##'
233-
##' @return An array. If ungrouped this will have dimensions `state`
234-
##' x `particle`, and if grouped then `state` x `particle` x
235-
##' `group`. If `select_random_particle = TRUE`, the second
236-
##' (particle) dimension will be dropped. This is the same as the
237-
##' state returned by [dust_filter_last_history] without the time
238-
##' dimension but also without any state index applied (i.e., we
239-
##' always return all state).
240-
##'
241-
##' @export
242-
dust_filter_last_state <- function(filter, index_group = NULL,
243-
select_random_particle = FALSE) {
244-
check_is_dust_filter(filter)
245-
if (is.null(filter$ptr)) {
246-
cli::cli_abort(c(
247-
"History is not current",
248-
i = "Filter has not yet been run"))
249-
}
250-
index_group <- check_index(index_group, max = filter$n_groups,
251-
unique = TRUE)
252-
assert_scalar_logical(select_random_particle)
253-
filter$methods$last_state(filter$ptr, index_group,
254-
select_random_particle,
255-
filter$preserve_group_dimension)
256-
}
257-
258-
259-
##' Get random number generator (RNG) state from the particle filter.
260-
##'
261-
##' @title Get filter RNG state
262-
##'
263-
##' @inheritParams dust_filter_run
264-
##'
265-
##' @return A raw vector, this could be quite long. Later we will
266-
##' describe how you might reseed a filter or system with this state.
267-
##'
268-
##' @export
269-
dust_filter_rng_state <- function(filter) {
270-
check_is_dust_filter(filter)
271-
if (is.null(filter$ptr)) {
272-
filter$initial_rng_state
273-
} else {
274-
filter$methods$rng_state(filter$ptr)
275-
}
276-
}
277-
278-
279-
##' @param rng_state A raw vector of random number generator state,
280-
##' returned by `dust_filter_rng_state`
281-
##' @rdname dust_filter_rng_state
282-
##' @export
283-
dust_filter_set_rng_state <- function(filter, rng_state) {
284-
check_is_dust_filter(filter)
285-
if (is.null(filter$ptr)) {
286-
assert_raw(rng_state, length(filter$initial_rng_state))
287-
filter$initial_rng_state <- rng_state
288-
} else {
289-
filter$methods$set_rng_state(filter$ptr, rng_state)
290-
}
291-
invisible()
292-
}
293-
294-
295-
check_is_dust_filter <- function(filter, call = parent.frame()) {
296-
if (!inherits(filter, "dust_filter")) {
297-
cli::cli_abort("Expected 'filter' to be a 'dust_filter' object",
298-
arg = "filter", call = call)
299-
}
86+
obj$methods$alloc(pars,
87+
inputs$time_start,
88+
inputs$time,
89+
inputs$dt,
90+
inputs$data,
91+
inputs$n_particles,
92+
inputs$n_groups,
93+
inputs$n_threads,
94+
inputs$index_state,
95+
obj$initial_rng_state),
96+
obj)
97+
obj$initial_rng_state <- NULL
30098
}
30199

302100

@@ -306,13 +104,3 @@ filter_rng_state <- function(n_particles, n_groups, seed) {
306104
n_streams <- max(n_groups, 1) * (1 + n_particles)
307105
monty::monty_rng$new(n_streams = n_streams, seed = seed)$state()
308106
}
309-
310-
311-
##' @export
312-
print.dust_filter <- function(x, ...) {
313-
cli::cli_h1("<dust_filter ({x$generator$name})>")
314-
cli::cli_alert_info(format_dimensions(x))
315-
cli::cli_bullets(c(
316-
i = "This filter runs in {x$generator$properties$time_type} time"))
317-
invisible(x)
318-
}

0 commit comments

Comments
 (0)