27
27
# #' @inheritParams dust_system_create
28
28
# #' @inheritParams dust_system_simulate
29
29
# #'
30
- # #' @return A `dust_unfilter ` object, which can be used with
31
- # #' [dust_unfilter_run ]
30
+ # #' @return A `dust_likelihood ` object, which can be used with
31
+ # #' [dust_likelihood_run ]
32
32
# #'
33
33
# #' @export
34
34
dust_filter_create <- function (generator , time_start , data ,
@@ -45,8 +45,6 @@ dust_filter_create <- function(generator, time_start, data,
45
45
time_start <- check_time_start(time_start , data $ time , call = call )
46
46
dt <- check_dt(dt , call = call )
47
47
48
- # # NOTE: there is no preserve_particle_dimension option here because
49
- # # we will always preserve this dimension.
50
48
n_groups <- data $ n_groups
51
49
preserve_group_dimension <- preserve_group_dimension || n_groups > 1
52
50
@@ -65,238 +63,38 @@ dust_filter_create <- function(generator, time_start, data,
65
63
66
64
res <- list2env(
67
65
list (inputs = inputs ,
66
+ initialise = filter_create ,
68
67
initial_rng_state = filter_rng_state(n_particles , n_groups , seed ),
69
68
n_particles = n_particles ,
70
69
n_groups = n_groups ,
71
70
deterministic = FALSE ,
71
+ has_adjoint = FALSE ,
72
72
generator = generator ,
73
73
methods = generator $ methods $ filter ,
74
74
index_state = index_state ,
75
+ preserve_particle_dimension = TRUE ,
75
76
preserve_group_dimension = preserve_group_dimension ),
76
77
parent = emptyenv())
77
- class(res ) <- " dust_filter"
78
+ class(res ) <- c( " dust_filter" , " dust_likelihood " )
78
79
res
79
80
}
80
81
81
82
82
- # #' Create an independent copy of a filter. The new filter is
83
- # #' decoupled from the random number streams of the parent filter. It
84
- # #' is also decoupled from the *state size* of the parent filter, so
85
- # #' you can use this to create a new filter where the system is
86
- # #' fundamentally different but everything else is the same.
87
- # #'
88
- # #' @title Create copy of filter
89
- # #'
90
- # #' @inheritParams dust_filter_run
91
- # #'
92
- # #' @param seed The seed for the filter (see [dust_filter_create])
93
- # #'
94
- # #' @return A new `dust_filter` object
95
- dust_filter_copy <- function (filter , seed = NULL ) {
96
- dst <- new.env(parent = emptyenv())
97
- nms <- c(" inputs" , " n_particles" , " n_groups" , " deterministic" , " methods" ,
98
- " index_state" , " preserve_group_dimension" , " generator" )
99
- for (nm in nms ) {
100
- dst [[nm ]] <- filter [[nm ]]
101
- }
102
- dst $ initial_rng_state <-
103
- filter_rng_state(filter $ n_particles , filter $ n_groups , seed )
104
- class(dst ) <- " dust_filter"
105
- dst
106
- }
107
-
108
-
109
- filter_create <- function (filter , pars ) {
110
- inputs <- filter $ inputs
83
+ filter_create <- function (obj , pars ) {
84
+ inputs <- obj $ inputs
111
85
list2env(
112
- filter $ methods $ alloc(pars ,
113
- inputs $ time_start ,
114
- inputs $ time ,
115
- inputs $ dt ,
116
- inputs $ data ,
117
- inputs $ n_particles ,
118
- inputs $ n_groups ,
119
- inputs $ n_threads ,
120
- inputs $ index_state ,
121
- filter $ initial_rng_state ),
122
- filter )
123
- filter $ initial_rng_state <- NULL
124
- }
125
-
126
-
127
- # #' Run particle filter
128
- # #'
129
- # #' @title Run particle filter
130
- # #'
131
- # #' @param filter A `dust_filter` object, created by
132
- # #' [dust_filter_create]
133
- # #'
134
- # #' @param pars Optional parameters to run the filter with. If not
135
- # #' provided, parameters are not updated
136
- # #'
137
- # #' @param initial Optional initial conditions, as a matrix (state x
138
- # #' particle) or 3d array (state x particle x group). If not
139
- # #' provided, the system initial conditions are used.
140
- # #'
141
- # #' @param save_history Logical, indicating if the simulation history
142
- # #' should be saved while the simulation runs; this has a small
143
- # #' overhead in runtime and in memory. History (particle
144
- # #' trajectories) will be saved at each time in the filter. If the
145
- # #' filter was constructed using a non-`NULL` `index_state` parameter,
146
- # #' the history is restricted to these states.
147
- # #'
148
- # #' @param index_group An optional vector of group indices to run the
149
- # #' filter for. You can use this to run a subset of possible
150
- # #' groups, once the filter is initialised (this argument must be
151
- # #' `NULL` on the **first** call).
152
- # #'
153
- # #' @return A vector of likelihood values, with as many elements as
154
- # #' there are groups.
155
- # #'
156
- # #' @export
157
- dust_filter_run <- function (filter , pars , initial = NULL ,
158
- save_history = FALSE , index_group = NULL ) {
159
- check_is_dust_filter(filter )
160
- index_group <- check_index(index_group , max = filter $ n_groups ,
161
- unique = TRUE )
162
- if (! is.null(pars )) {
163
- pars <- check_pars(pars , filter $ n_groups , index_group ,
164
- filter $ preserve_group_dimension )
165
- }
166
- if (is.null(filter $ ptr )) {
167
- if (is.null(pars )) {
168
- cli :: cli_abort(" 'pars' cannot be NULL, as filter is not initialised" ,
169
- arg = " pars" )
170
- }
171
- if (! is.null(index_group )) {
172
- cli :: cli_abort(
173
- " 'index_group' must be NULL, as filter is not initialised" ,
174
- arg = " index_group" )
175
- }
176
- filter_create(filter , pars )
177
- } else if (! is.null(pars )) {
178
- filter $ methods $ update_pars(filter $ ptr , pars , index_group )
179
- }
180
- filter $ methods $ run(filter $ ptr ,
181
- initial ,
182
- save_history ,
183
- index_group ,
184
- filter $ preserve_group_dimension )
185
- }
186
-
187
-
188
- # #' Fetch the last history created by running a filter. This
189
- # #' errors if the last call to [dust_filter_run] did not use
190
- # #' `save_history = TRUE`.
191
- # #'
192
- # #' @title Fetch last filter history
193
- # #'
194
- # #' @inheritParams dust_filter_run
195
- # #'
196
- # #' @param select_random_particle Logical, indicating if we should
197
- # #' return a history for one randomly selected particle (rather than
198
- # #' the entire history). If this is `TRUE`, the particle will be
199
- # #' selected independently for each group, if the filter is grouped.
200
- # #' This option is intended to help select a representative
201
- # #' trajectory during an MCMC. When `TRUE`, we drop the `particle`
202
- # #' dimension of the return value.
203
- # #'
204
- # #' @return An array. If ungrouped this will have dimensions `state`
205
- # #' x `particle` x `time`, and if grouped then `state` x `particle`
206
- # #' x `group` x `time`. If `select_random_particle = TRUE`, the
207
- # #' second (particle) dimension will be dropped.
208
- # #'
209
- # #' @export
210
- dust_filter_last_history <- function (filter , index_group = NULL ,
211
- select_random_particle = FALSE ) {
212
- check_is_dust_filter(filter )
213
- if (is.null(filter $ ptr )) {
214
- cli :: cli_abort(c(
215
- " History is not current" ,
216
- i = " Filter has not yet been run" ))
217
- }
218
- index_group <- check_index(index_group , max = filter $ n_groups ,
219
- unique = TRUE )
220
- assert_scalar_logical(select_random_particle )
221
- filter $ methods $ last_history(filter $ ptr , index_group ,
222
- select_random_particle ,
223
- filter $ preserve_group_dimension )
224
- }
225
-
226
-
227
- # #' Get the last state from a filter.
228
- # #'
229
- # #' @title Get filter state
230
- # #'
231
- # #' @inheritParams dust_filter_last_history
232
- # #'
233
- # #' @return An array. If ungrouped this will have dimensions `state`
234
- # #' x `particle`, and if grouped then `state` x `particle` x
235
- # #' `group`. If `select_random_particle = TRUE`, the second
236
- # #' (particle) dimension will be dropped. This is the same as the
237
- # #' state returned by [dust_filter_last_history] without the time
238
- # #' dimension but also without any state index applied (i.e., we
239
- # #' always return all state).
240
- # #'
241
- # #' @export
242
- dust_filter_last_state <- function (filter , index_group = NULL ,
243
- select_random_particle = FALSE ) {
244
- check_is_dust_filter(filter )
245
- if (is.null(filter $ ptr )) {
246
- cli :: cli_abort(c(
247
- " History is not current" ,
248
- i = " Filter has not yet been run" ))
249
- }
250
- index_group <- check_index(index_group , max = filter $ n_groups ,
251
- unique = TRUE )
252
- assert_scalar_logical(select_random_particle )
253
- filter $ methods $ last_state(filter $ ptr , index_group ,
254
- select_random_particle ,
255
- filter $ preserve_group_dimension )
256
- }
257
-
258
-
259
- # #' Get random number generator (RNG) state from the particle filter.
260
- # #'
261
- # #' @title Get filter RNG state
262
- # #'
263
- # #' @inheritParams dust_filter_run
264
- # #'
265
- # #' @return A raw vector, this could be quite long. Later we will
266
- # #' describe how you might reseed a filter or system with this state.
267
- # #'
268
- # #' @export
269
- dust_filter_rng_state <- function (filter ) {
270
- check_is_dust_filter(filter )
271
- if (is.null(filter $ ptr )) {
272
- filter $ initial_rng_state
273
- } else {
274
- filter $ methods $ rng_state(filter $ ptr )
275
- }
276
- }
277
-
278
-
279
- # #' @param rng_state A raw vector of random number generator state,
280
- # #' returned by `dust_filter_rng_state`
281
- # #' @rdname dust_filter_rng_state
282
- # #' @export
283
- dust_filter_set_rng_state <- function (filter , rng_state ) {
284
- check_is_dust_filter(filter )
285
- if (is.null(filter $ ptr )) {
286
- assert_raw(rng_state , length(filter $ initial_rng_state ))
287
- filter $ initial_rng_state <- rng_state
288
- } else {
289
- filter $ methods $ set_rng_state(filter $ ptr , rng_state )
290
- }
291
- invisible ()
292
- }
293
-
294
-
295
- check_is_dust_filter <- function (filter , call = parent.frame()) {
296
- if (! inherits(filter , " dust_filter" )) {
297
- cli :: cli_abort(" Expected 'filter' to be a 'dust_filter' object" ,
298
- arg = " filter" , call = call )
299
- }
86
+ obj $ methods $ alloc(pars ,
87
+ inputs $ time_start ,
88
+ inputs $ time ,
89
+ inputs $ dt ,
90
+ inputs $ data ,
91
+ inputs $ n_particles ,
92
+ inputs $ n_groups ,
93
+ inputs $ n_threads ,
94
+ inputs $ index_state ,
95
+ obj $ initial_rng_state ),
96
+ obj )
97
+ obj $ initial_rng_state <- NULL
300
98
}
301
99
302
100
@@ -306,13 +104,3 @@ filter_rng_state <- function(n_particles, n_groups, seed) {
306
104
n_streams <- max(n_groups , 1 ) * (1 + n_particles )
307
105
monty :: monty_rng $ new(n_streams = n_streams , seed = seed )$ state()
308
106
}
309
-
310
-
311
- # #' @export
312
- print.dust_filter <- function (x , ... ) {
313
- cli :: cli_h1(" <dust_filter ({x$generator$name})>" )
314
- cli :: cli_alert_info(format_dimensions(x ))
315
- cli :: cli_bullets(c(
316
- i = " This filter runs in {x$generator$properties$time_type} time" ))
317
- invisible (x )
318
- }
0 commit comments