Skip to content

Commit f272dd6

Browse files
committed
Update test helper
1 parent 275b5bb commit f272dd6

File tree

6 files changed

+75
-55
lines changed

6 files changed

+75
-55
lines changed

R/cpp11.R

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R/test.R

+5-14
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,8 @@
11
test_trajectories <- function(time, state, order = NULL,
2-
index_state = NULL, index_group = NULL,
3-
select_particle = NULL, reorder = FALSE) {
2+
index_state = NULL, index_group = NULL,
3+
select_particle = NULL, times_snapshot = NULL,
4+
save_state = TRUE, reorder = FALSE) {
45
test_trajectories_(time, state, order,
5-
index_state, index_group, select_particle, reorder)
6-
}
7-
8-
9-
test_snapshots <- function(time, save_snapshots, state, order = NULL,
10-
index_group = NULL, select_particle = NULL,
11-
reorder = FALSE) {
12-
if (is.numeric(save_snapshots)) {
13-
save_snapshots <- time %in% save_snapshots
14-
}
15-
test_snapshots_(time, save_snapshots, state, order,
16-
index_group, select_particle, reorder)
6+
index_state, index_group, select_particle,
7+
times_snapshot, save_state, reorder)
178
}

inst/include/dust2/trajectories.hpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ class trajectories {
107107
return position_state_;
108108
}
109109

110+
auto n_snapshots() const {
111+
return position_snapshot_;
112+
}
113+
110114
auto& index_group() const {
111115
return index_group_;
112116
}
@@ -245,7 +249,7 @@ class trajectories {
245249
std::copy_n(iter_src_j, n_state_total_, iter_dest_j);
246250
} else {
247251
auto iter_dest_j = iter_dest + j * n_state_total_;
248-
for (size_t k = 0; k < n_particles; ++k) {
252+
for (size_t k = 0; k < n_particles_; ++k) {
249253
auto iter_src_k = iter_src +
250254
j * n_state_total_ * n_particles_ +
251255
index_particle[j * n_particles_ + k] * n_state_total_;
@@ -263,13 +267,13 @@ class trajectories {
263267

264268
if (reorder_[i]) {
265269
const auto iter_order = order_.begin() + i * len_order_;
266-
for (size_t j = 0; j < n_groups; ++j) {
270+
for (size_t j = 0; j < n_groups_; ++j) {
267271
const auto iter_order_j = iter_order + j * n_particles_;
268272
if (use_select_particle) {
269273
index_particle[j] = *(iter_order_j + index_particle[j]);
270274
} else {
271275
const auto index_particle_j = index_particle.begin() + j * n_particles_;
272-
for (size_t k = 0; k < n_groups; ++k) {
276+
for (size_t k = 0; k < n_groups_; ++k) {
273277
const auto index_particle_k = index_particle_j + k;
274278
*index_particle_k = *(iter_order_j + *index_particle_k);
275279
}

src/cpp11.cpp

+4-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/test.cpp

+36-11
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ cpp11::sexp test_trajectories_(cpp11::doubles r_time,
3131
cpp11::sexp r_index_state,
3232
cpp11::sexp r_index_group,
3333
cpp11::sexp r_select_particle,
34+
cpp11::sexp r_times_snapshot,
35+
bool save_state,
3436
bool reorder) {
3537
const size_t n_times = r_time.size();
3638
cpp11::sexp el0 = r_state[0];
@@ -57,9 +59,11 @@ cpp11::sexp test_trajectories_(cpp11::doubles r_time,
5759
"select_particle");
5860
}
5961

62+
const auto times_snapshot = r_times_snapshot == R_NilValue ?
63+
std::vector<double>() :
64+
cpp11::as_cpp<std::vector<double>>(r_times_snapshot);
65+
6066
dust2::trajectories<double> h(n_state, n_particles, n_groups, n_times);
61-
const auto save_state = true;
62-
const std::vector<double> times_snapshot;
6367
h.set_index_and_reset(index_state, index_group, save_state, times_snapshot);
6468
for (size_t i = 0; i < static_cast<size_t>(r_state.size()); ++i) {
6569
if (r_order == R_NilValue) {
@@ -79,16 +83,37 @@ cpp11::sexp test_trajectories_(cpp11::doubles r_time,
7983
const auto n_groups_out = h.n_groups();
8084
const auto n_times_out = h.n_times();
8185
cpp11::writable::doubles ret_time(static_cast<int>(n_times_out));
82-
const size_t len = n_state_out * n_particles_out * n_groups_out * n_times_out;
83-
cpp11::writable::doubles ret_state(static_cast<int>(len));
84-
h.export_time(REAL(ret_time));
85-
h.export_state(REAL(ret_state), reorder, select_particle);
8686

87-
if (use_select_particle) {
88-
dust2::r::set_array_dims(ret_state, {n_state_out, n_particles_out * n_groups_out, n_times_out});
89-
} else {
90-
dust2::r::set_array_dims(ret_state, {n_state_out, n_particles_out, n_groups_out, n_times_out});
87+
cpp11::sexp ret_state = R_NilValue;
88+
cpp11::sexp ret_snapshots = R_NilValue;
89+
90+
if (save_state) {
91+
const size_t len = n_state_out * n_particles_out * n_groups_out * n_times_out;
92+
cpp11::writable::doubles arr(static_cast<int>(len));
93+
h.export_time(REAL(ret_time));
94+
h.export_state(REAL(arr), reorder, select_particle);
95+
96+
if (use_select_particle) {
97+
dust2::r::set_array_dims(arr, {n_state_out, n_particles_out * n_groups_out, n_times_out});
98+
} else {
99+
dust2::r::set_array_dims(arr, {n_state_out, n_particles_out, n_groups_out, n_times_out});
100+
}
101+
ret_state = arr;
102+
}
103+
104+
if (!times_snapshot.empty()) {
105+
const auto n_snapshots = h.n_snapshots();
106+
const size_t len = n_state * n_particles_out * n_groups_out * n_snapshots;
107+
cpp11::writable::doubles arr(static_cast<int>(len));
108+
h.export_snapshots(REAL(arr), reorder, select_particle);
109+
110+
if (use_select_particle) {
111+
dust2::r::set_array_dims(arr, {n_state, n_particles_out * n_groups_out, n_snapshots});
112+
} else {
113+
dust2::r::set_array_dims(arr, {n_state, n_particles_out, n_groups_out, n_snapshots});
114+
}
115+
ret_snapshots = arr;
91116
}
92117

93-
return cpp11::writable::list{ret_time, ret_state};
118+
return cpp11::writable::list{ret_time, ret_state, ret_snapshots};
94119
}

tests/testthat/test-filter-details.R

+21-21
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ test_that("can use trajectories", {
3535
s_arr <- array(unlist(s), c(n_state, n_particles, n_groups, n_time))
3636

3737
expect_equal(test_trajectories(time, s, reorder = TRUE),
38-
list(time, s_arr))
38+
list(time, s_arr, NULL))
3939
expect_equal(test_trajectories(time, s, reorder = FALSE),
40-
list(time, s_arr))
40+
list(time, s_arr, NULL))
4141
expect_equal(test_trajectories(time, s[1:3], reorder = TRUE),
42-
list(time[1:3], s_arr[, , , 1:3]))
42+
list(time[1:3], s_arr[, , , 1:3], NULL))
4343
expect_equal(test_trajectories(time, s, order = vector("list", length(time)),
4444
reorder = TRUE),
45-
list(time, s_arr))
45+
list(time, s_arr, NULL))
4646

4747
res <- test_trajectories(time, s, select_particle = c(6, 4, 2))[[2]]
4848
expect_equal(dim(res), c(n_state, n_groups, n_time))
@@ -94,21 +94,21 @@ test_that("can reorder trajectories with no groups", {
9494
## Pass in, but ignore index
9595
expect_equal(
9696
test_trajectories(time, state, order = order, reorder = FALSE),
97-
list(time, state_arr))
97+
list(time, state_arr, NULL))
9898

9999
## Really simple, add an index that does not reorder anything:
100100
expect_equal(
101101
test_trajectories(time, state[1], order = order[1], reorder = TRUE),
102-
list(time[1], state_arr[, , , 1, drop = FALSE]))
102+
list(time[1], state_arr[, , , 1, drop = FALSE], NULL))
103103
expect_equal(
104104
test_trajectories(time, state[1:2], order = list(NULL, 0:6),
105105
reorder = TRUE),
106-
list(time[1:2], state_arr[, , , 1:2, drop = FALSE]))
106+
list(time[1:2], state_arr[, , , 1:2, drop = FALSE], NULL))
107107

108108
## Proper reordering with the full index:
109109
expect_equal(
110110
test_trajectories(time, state, order = order, reorder = TRUE),
111-
list(time, true))
111+
list(time, true, NULL))
112112

113113
expect_equal(
114114
test_trajectories(time, state, order = order, reorder = FALSE,
@@ -149,9 +149,9 @@ test_that("can reorder trajectories on the way out", {
149149

150150
state_arr <- array(unlist(state), dim(true))
151151
expect_equal(test_trajectories(time, state, order = order, reorder = FALSE),
152-
list(time, state_arr))
152+
list(time, state_arr, NULL))
153153
expect_equal(test_trajectories(time, state, order = order, reorder = TRUE),
154-
list(time, true))
154+
list(time, true, NULL))
155155
})
156156

157157

@@ -171,14 +171,14 @@ test_that("can extract trajectories with group index, no reordering", {
171171
s_arr <- array(unlist(s), c(n_state, n_particles, n_groups, n_time))
172172

173173
expect_equal(test_trajectories(time, s, index_group = NULL),
174-
list(time, s_arr))
174+
list(time, s_arr, NULL))
175175
expect_equal(test_trajectories(time, s, index_group = seq_len(n_groups)),
176-
list(time, s_arr))
176+
list(time, s_arr, NULL))
177177

178178
expect_equal(test_trajectories(time, s, index_group = 2),
179-
list(time, s_arr[, , 2, , drop = FALSE]))
179+
list(time, s_arr[, , 2, , drop = FALSE], NULL))
180180
expect_equal(test_trajectories(time, s, index_group = c(3, 1)),
181-
list(time, s_arr[, , c(3, 1), , drop = FALSE]))
181+
list(time, s_arr[, , c(3, 1), , drop = FALSE], NULL))
182182

183183
m <- test_trajectories(time, s, select_particle = c(6, 4, 2))[[2]]
184184
expect_equal(dim(m), c(n_state, n_groups, n_time))
@@ -226,34 +226,34 @@ test_that("can reorder trajectories on the way out", {
226226
state_arr <- array(unlist(state), dim(true))
227227
expect_equal(
228228
test_trajectories(time, state, order = order, reorder = TRUE),
229-
list(time, true))
229+
list(time, true, NULL))
230230
expect_equal(
231231
test_trajectories(time, state, order = order, reorder = TRUE,
232232
index_group = 1:3),
233-
list(time, true))
233+
list(time, true, NULL))
234234

235235
expect_equal(
236236
test_trajectories(time, state, order = order, reorder = TRUE,
237237
index_group = 1),
238-
list(time, true[, , 1, , drop = FALSE]))
238+
list(time, true[, , 1, , drop = FALSE], NULL))
239239
expect_equal(
240240
test_trajectories(time, state, order = order, reorder = TRUE,
241241
index_group = 2),
242-
list(time, true[, , 2, , drop = FALSE]))
242+
list(time, true[, , 2, , drop = FALSE], NULL))
243243
expect_equal(
244244
test_trajectories(time, state, order = order, reorder = TRUE,
245245
index_group = 3),
246-
list(time, true[, , 3, , drop = FALSE]))
246+
list(time, true[, , 3, , drop = FALSE], NULL))
247247

248248
expect_equal(
249249
test_trajectories(time, state, order = order, reorder = TRUE,
250250
index_group = 3:2),
251-
list(time, true[, , 3:2, , drop = FALSE]))
251+
list(time, true[, , 3:2, , drop = FALSE], NULL))
252252

253253
expect_equal(
254254
test_trajectories(time, state, order = order, reorder = TRUE,
255255
index_group = 3:1),
256-
list(time, true[, , 3:1, ]))
256+
list(time, true[, , 3:1, ], NULL))
257257

258258
m <- test_trajectories(time, state, order = order, reorder = TRUE,
259259
select_particle = c(6, 4, 2))[[2]]

0 commit comments

Comments
 (0)