Skip to content

Commit d7b12cd

Browse files
authored
Merge pull request #152 from mrc-ide/mrc-6351
Support events that happen on times that are exactly hit
2 parents 0dcbf70 + 0602b34 commit d7b12cd

File tree

4 files changed

+168
-14
lines changed

4 files changed

+168
-14
lines changed

DESCRIPTION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: dust2
22
Title: Next Generation dust
3-
Version: 0.3.20
3+
Version: 0.3.21
44
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
55
email = "[email protected]"),
66
person("Imperial College of Science, Technology and Medicine",

inst/include/dust2/continuous/solver.hpp

+45-13
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,17 @@ class solver {
329329
real_type apply_events(real_type t0, real_type h, const real_type* y,
330330
const events_type<real_type>& events,
331331
ode::internals<real_type>& internals) {
332-
size_t idx_first = events.size();
333332
real_type t1 = t0 + h;
334-
real_type sign = 0;
333+
334+
// It might be worth saving this storage space in the solver, but
335+
// I doubt it matters in pratice. We need to save all the found
336+
// events and their signs, even though probably only one will be
337+
// found, because of the possibility that multiple events could
338+
// happen at the same time (e.g., two time-scheduled events or two
339+
// functions that happen to hit their roots at the same time).
340+
std::vector<bool> found(events.size());
341+
std::vector<real_type> sign(events.size());
342+
bool found_any = false;
335343

336344
for (size_t idx_event = 0; idx_event < events.size(); ++idx_event) {
337345
const auto& e = events[idx_event];
@@ -350,20 +358,44 @@ class solver {
350358
// interpolation is expected to be quite fast and accurate.
351359
constexpr real_type eps = 1e-6;
352360
constexpr size_t steps = 100;
353-
auto root = lostturnip::find_result<real_type>(fn, t0, t1, eps, steps);
354-
idx_first = idx_event;
355-
t1 = root.x;
356-
sign = f_t0 < 0 ? 1 : -1;
361+
t1 = lostturnip::find_result<real_type>(fn, t0, t1, eps, steps).x;
362+
// Currently untested - in the case where we have two roots
363+
// that would have been crossed in this time window, the one
364+
// we are currently considering happens first so pre-empts the
365+
// previously found events.
366+
if (found_any) {
367+
std::fill(found.begin(), found.end(), false);
368+
}
369+
} else if (!(f_t1 == 0 && f_t0 != 0)) {
370+
// Consider the case where jump to a root *exactly* at t1;
371+
// this happens in coincident roots and with roots that are
372+
// based in time, and which we arrange for the solver to stop
373+
// at (e.g., while using simulate()).
374+
//
375+
// This test is the inverse of this though, because in the
376+
// case where we *don't* get an exact root we should skip the
377+
// bookkeeping below and try the next event.
378+
continue;
357379
}
358-
if (idx_first < events.size()) {
359-
internals.last.interpolate(t1, y_next_.data());
360-
events[idx_first].action(t1, sign, y_next_.data());
361-
// We need to modify the history here so that search will find
362-
// the right point.
363-
internals.last.t1 = t1;
364-
internals.events.push_back({t1, idx_first, sign});
380+
sign[idx_event] = f_t0 < 0 ? 1 : -1;
381+
found[idx_event] = true;
382+
found_any = true;
383+
}
384+
385+
// If we found at least one event, then reset the solver state
386+
// back to the point of the event and apply all the events in
387+
// turn.
388+
if (found_any) {
389+
internals.last.interpolate(t1, y_next_.data());
390+
for (size_t idx_event = 0; idx_event < events.size(); ++idx_event) {
391+
if (found[idx_event]) {
392+
events[idx_event].action(t1, sign[idx_event], y_next_.data());
393+
internals.events.push_back({t1, idx_event, sign[idx_event]});
394+
}
365395
}
396+
internals.last.t1 = t1;
366397
}
398+
367399
return t1;
368400
}
369401

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#include <dust2/common.hpp>
2+
3+
// [[dust2::class(change)]]
4+
// [[dust2::time_type(continuous)]]
5+
// [[dust2::parameter(r, rank = 0)]]
6+
// [[dust2::parameter(n, rank = 0)]]
7+
// [[dust2::parameter(t_change, rank = 1)]]
8+
// [[dust2::parameter(delta, rank = 1)]]
9+
class change {
10+
public:
11+
change() = delete;
12+
13+
using real_type = double;
14+
15+
struct shared_state {
16+
real_type r;
17+
std::vector<real_type> t_change;
18+
std::vector<real_type> delta;
19+
};
20+
21+
struct internal_state {};
22+
23+
using rng_state_type = monty::random::generator<real_type>;
24+
25+
static dust2::packing packing_state(const shared_state& shared) {
26+
return dust2::packing{{"y", {}}};
27+
}
28+
29+
static void initial(real_type time,
30+
const shared_state& shared,
31+
internal_state& internal,
32+
rng_state_type& rng_state,
33+
real_type * state) {
34+
state[0] = 0;
35+
}
36+
37+
static void rhs(real_type time,
38+
const real_type * state,
39+
const shared_state& shared,
40+
internal_state& internal,
41+
real_type * state_deriv) {
42+
state_deriv[0] = shared.r;
43+
}
44+
45+
static shared_state build_shared(cpp11::list pars) {
46+
const real_type r1 = dust2::r::read_real(pars, "r");
47+
const size_t n = dust2::r::read_int(pars, "n");
48+
std::vector<real_type> t_change(n);
49+
std::vector<real_type> delta(n);
50+
dust2::r::read_real_vector(pars, n, t_change.data(), "t_change", true);
51+
dust2::r::read_real_vector(pars, n, delta.data(), "delta", false);
52+
return shared_state{r1, t_change, delta};
53+
}
54+
55+
static void update_shared(cpp11::list pars, shared_state& shared) {
56+
}
57+
58+
static auto events(const shared_state& shared, internal_state& internal) {
59+
const auto n = shared.t_change.size();
60+
dust2::ode::events_type<real_type> events;
61+
events.reserve(n);
62+
for (size_t i = 0; i < n; ++i) {
63+
auto test = [&, i](const double t, const real_type* y) {
64+
return t - shared.t_change[i];
65+
};
66+
auto action = [&, i](const double t, const double sign, double* y) {
67+
y[0] += shared.delta[i];
68+
};
69+
events.push_back(dust2::ode::event<real_type>({}, test, action));
70+
}
71+
return events;
72+
}
73+
};

tests/testthat/test-zzz-events.R

+49
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,52 @@ test_that("can run system with roots and events", {
2828
## Overall solution:
2929
expect_equal(y[1, ], cmp$y, tolerance = 1e-6)
3030
})
31+
32+
33+
test_that("can run events with events in time", {
34+
gen <- dust_compile("examples/event-time.cpp", quiet = FALSE, debug = TRUE)
35+
36+
# A mix of times that we will hit exactly and bracket
37+
pars <- list(r = 0.2, n = 3, t_change = c(2, 5.1234, 7), delta = rnorm(3))
38+
control <- dust_ode_control(
39+
debug_record_step_times = TRUE,
40+
save_history = TRUE
41+
)
42+
sys <- dust_system_create(gen, pars, ode_control = control)
43+
44+
t <- seq(0, 10, length.out = 101)
45+
y <- dust_system_simulate(sys, t)
46+
47+
info <- dust_system_internals(sys, include_history = TRUE)
48+
expect_equal(
49+
info$events[[1]],
50+
data_frame(time = pars$t_change, index = 1:3, sign = 1)
51+
)
52+
expect_equal(
53+
drop(y),
54+
t * 0.2 + c(0, cumsum(pars$delta))[findInterval(t, pars$t_change) + 1])
55+
})
56+
57+
58+
test_that("can cope with coincident events", {
59+
gen <- dust_compile("examples/event-time.cpp", quiet = FALSE, debug = TRUE)
60+
61+
pars <- list(r = 0.2, n = 3, t_change = c(2, 2, 3), delta = c(1, 3, 5))
62+
control <- dust_ode_control(
63+
debug_record_step_times = TRUE,
64+
save_history = TRUE
65+
)
66+
sys <- dust_system_create(gen, pars, ode_control = control)
67+
68+
t <- seq(0, 10, length.out = 101)
69+
y <- dust_system_simulate(sys, t)
70+
71+
info <- dust_system_internals(sys, include_history = TRUE)
72+
expect_equal(
73+
info$events[[1]],
74+
data_frame(time = pars$t_change, index = 1:3, sign = 1)
75+
)
76+
expect_equal(
77+
drop(y),
78+
t * 0.2 + c(0, cumsum(pars$delta))[findInterval(t, pars$t_change) + 1])
79+
})

0 commit comments

Comments
 (0)