improve early calculation #903

Open: wants to merge 9 commits into main (base branch)
1 change: 1 addition & 0 deletions NEWS.md
@@ -19,6 +19,7 @@
 - A bug was fixed where an internal function for applying a default cdf cutoff failed due to a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
 - All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs.
 - A bug was fixed where `plot.dist_spec()` wasn't throwing an informative error due to an incomplete check for the max of the specified delay. By @jamesmbaazam in #858 and reviewed by @.
+- Updated the early dynamics calculation to use the full linear model if available. Also changed the prior for initial infections to be approximately Poisson. By @sbfnk in # and reviewed by

 ## Package changes

30 changes: 18 additions & 12 deletions R/create.R
@@ -455,28 +455,34 @@ create_obs_model <- function(obs = obs_opts(), dates) {
 #' @return A list containing `prior_infections` and `prior_growth`.
 #' @keywords internal
 estimate_early_dynamics <- function(cases, seeding_time) {
-  first_week <- data.table::data.table(
-    confirm = cases[seq_len(min(7, length(cases)))],
-    t = seq_len(min(7, length(cases)))
+  initial_period <- data.table::data.table(
+    confirm = cases[seq_len(min(7, seeding_time, length(cases)))],
+    t = seq_len(min(7, seeding_time, length(cases))) - 1
   )[!is.na(confirm)]

-  # Calculate prior infections
-  prior_infections <- log(mean(first_week$confirm, na.rm = TRUE))
-  prior_infections <- ifelse(
-    is.na(prior_infections) || is.null(prior_infections),
-    0, prior_infections
-  )
-
+  prior_infections <- 0
   # Calculate prior growth
-  if (seeding_time > 1 && nrow(first_week) > 1) {
+  if (seeding_time > 1 && nrow(initial_period) > 1) {
     safe_lm <- purrr::safely(stats::lm)
-    prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]]
+    prior_growth <- safe_lm(log(confirm) ~ t, data = initial_period)[[1]]
+    prior_infections <- ifelse(
+      is.null(prior_growth), 0, prior_growth$coefficients[1]
+    )
     prior_growth <- ifelse(
       is.null(prior_growth), 0, prior_growth$coefficients[2]
     )
   } else {
     prior_growth <- 0
   }
+
+  # Calculate prior infections
+  if (prior_infections == 0) {
+    prior_infections <- log(mean(initial_period$confirm, na.rm = TRUE))
+    if (is.na(prior_infections) || is.null(prior_infections)) {
+      prior_infections <- 0
+    }
+  }
+
   return(list(
     prior_infections = prior_infections,
     prior_growth = prior_growth
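
The intercept-and-slope logic in the hunk above can be illustrated outside the package. The following is a minimal sketch, not the package's implementation: `early_dynamics_sketch()` is a hypothetical helper name, it uses base R only (no `data.table` or `purrr`), it assumes at least two non-missing, strictly positive counts in the initial period, and it omits the fallbacks for failed fits that the diff handles.

```r
# Hypothetical stand-alone helper (not part of the package): fit
# log(cases) ~ t over the initial period, with t starting at 0 so that the
# intercept is the fitted log case count on the first day.
early_dynamics_sketch <- function(cases, seeding_time) {
  n <- min(7, seeding_time, length(cases))
  initial_period <- data.frame(confirm = cases[seq_len(n)], t = seq_len(n) - 1)
  # Drop missing observations before fitting, as the diff does
  initial_period <- initial_period[!is.na(initial_period$confirm), ]
  fit <- stats::lm(log(confirm) ~ t, data = initial_period)
  list(
    prior_infections = unname(stats::coef(fit)[1]), # intercept, log scale
    prior_growth = unname(stats::coef(fit)[2]) # slope, daily growth rate
  )
}

# Cases doubling daily, starting from 1
est <- early_dynamics_sketch(2^(0:6), seeding_time = 7)
```

With `2^(0:6)` the intercept is approximately 0 and the slope approximately `log(2)`, in line with the updated expectation `log(2^0)` in the exponential growth test of this PR.
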
16 changes: 12 additions & 4 deletions inst/stan/estimate_infections.stan
@@ -47,7 +47,6 @@ parameters {
   vector[fixed ? 0 : gp_type == 1 ? 2*M : M] eta; // unconstrained noise
   // Rt
   array[estimate_r] real initial_infections; // seed infections
-  array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate
   array[bp_n > 0 ? 1 : 0] real<lower = 0> bp_sd; // standard deviation of breakpoint effect
   vector[bp_n] bp_effects; // Rt breakpoint effects
   // observation model
@@ -62,6 +61,12 @@ transformed parameters {
   vector[ot_h] reports; // estimated reported cases
   vector[ot] obs_reports; // observed estimated reported cases
   vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf;
+  array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate
+
+  if (num_elements(initial_growth) > 0) {
+    initial_growth[1] = prior_growth +
+      (prior_infections - initial_infections[1]) / seeding_time;
+  }

   // GP in noise - spectral densities
   profile("update gp") {
@@ -95,9 +100,13 @@ transformed parameters {
       );
     }
     profile("infections") {
+      real frac_obs = get_param(
+        frac_obs_id, params_fixed_lookup, params_variable_lookup, params_value,
+        params
+      );
       infections = generate_infections(
         R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop,
-        future_time
+        future_time, obs_scale, frac_obs
       );
     }
   } else {
@@ -201,8 +210,7 @@ model {
   // priors on Rt
   profile("rt lp") {
     rt_lp(
-      initial_infections, initial_growth, bp_effects, bp_sd, bp_n,
-      seeding_time, prior_infections, prior_growth
+      initial_infections, bp_effects, bp_sd, bp_n, prior_infections
     );
   }
 }
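
The deterministic definition of `initial_growth` in the transformed parameters block can be rearranged to show what it constrains. This is plain algebra on the assignment in the diff, writing $i_0$ for `initial_infections[1]`, $g$ for `initial_growth[1]`, $a$ for `prior_infections`, $b$ for `prior_growth` and $s$ for `seeding_time`; these symbols are introduced here for illustration only.

```latex
\begin{align}
g &= b + \frac{a - i_0}{s} \\
\Rightarrow \quad i_0 + g\, s &= a + b\, s
\end{align}
```

Read this way, a sampled log initial infection value above the regression intercept is paired with a growth rate below the regression slope (and vice versa), so that the log-linear seeding line and the regression line take the same value at $t = s$. Whether $t = s$ coincides exactly with the last seeding day depends on the indexing inside `generate_infections`, which is not fully shown in this diff.
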
5 changes: 4 additions & 1 deletion inst/stan/functions/infections.stan
@@ -20,7 +20,7 @@ real update_infectiousness(vector infections, vector gt_rev_pmf,
 // generate infections by using Rt = Rt-1 * sum(reversed generation time pmf * infections)
 vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
                            array[] real initial_infections, array[] real initial_growth,
-                           int pop, int ht) {
+                           int pop, int ht, int obs_scale, real frac_obs) {
   // time indices and storage
   int ot = num_elements(oR);
   int nht = ot - ht;
@@ -32,6 +32,9 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
   vector[ot] infectiousness;
   // Initialise infections using daily growth
   infections[1] = exp(initial_infections[1]);
+  if (obs_scale) {
+    infections[1] = infections[1] / frac_obs;
+  }
   if (uot > 1) {
     real growth = exp(initial_growth[1]);
     for (s in 2:uot) {
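
The seeding step touched by this hunk can be sketched in R. This is an illustration under assumptions, not a port of the Stan function: `seed_infections_sketch()` is a hypothetical name, and the loop body (`infections[s] = infections[s - 1] * growth`) is assumed, since the hunk ends before it. The assumption is consistent with the expected values in `tests/testthat/test-stan-infections.R`.

```r
# Hypothetical R sketch of the seeding phase only (the renewal step is omitted).
seed_infections_sketch <- function(initial_infections, initial_growth, uot,
                                   obs_scale = FALSE, frac_obs = 1) {
  infections <- numeric(uot)
  infections[1] <- exp(initial_infections)
  if (obs_scale) {
    # Scale up by the reporting fraction, as in the added Stan lines
    infections[1] <- infections[1] / frac_obs
  }
  if (uot > 1) {
    growth <- exp(initial_growth)
    for (s in 2:uot) {
      # Assumed loop body: constant exponential growth during seeding
      infections[s] <- infections[s - 1] * growth
    }
  }
  infections
}

seeded <- seed_infections_sketch(log(500), -0.02, uot = 4)
scaled <- seed_infections_sketch(log(500), -0.02, uot = 4,
                                 obs_scale = TRUE, frac_obs = 0.5)
```

Rounded, `seeded` gives 500, 490, 480, 471, the first four expected values of the test case with a seeding time of 4 and growth `-0.02`. With `frac_obs = 0.5`, every seeded value is doubled.
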
10 changes: 3 additions & 7 deletions inst/stan/functions/rt.stan
@@ -57,18 +57,14 @@ vector update_Rt(int t, real R0, vector noise, array[] int bps,
  * @param prior_infections Prior mean for initial infections
  * @param prior_growth Prior mean for initial growth rates
  */
-void rt_lp(array[] real initial_infections, array[] real initial_growth,
-           vector bp_effects, array[] real bp_sd, int bp_n, int seeding_time,
-           real prior_infections, real prior_growth) {
+void rt_lp(array[] real initial_infections, vector bp_effects,
+           array[] real bp_sd, int bp_n, real prior_infections) {
   //breakpoint effects on Rt
   if (bp_n > 0) {
     bp_sd[1] ~ normal(0, 0.1) T[0,];
     bp_effects ~ normal(0, bp_sd[1]);
   }
   // initial infections
-  initial_infections ~ normal(prior_infections, 0.2);
+  initial_infections ~ normal(prior_infections, sqrt(prior_infections));

-  if (seeding_time > 1) {
-    initial_growth ~ normal(prior_growth, 0.2);
-  }
 }
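
The NEWS entry describes the new prior as "approximately Poisson". The sketch below only illustrates the general fact behind that phrase: a Poisson distribution with mean `lambda` is well approximated by a normal distribution with mean `lambda` and standard deviation `sqrt(lambda)` when `lambda` is large. It is not taken from the package, and `lambda <- 100` is an arbitrary choice. In the diff the standard deviation is `sqrt(prior_infections)`, where `prior_infections` is the log-scale value returned by `estimate_early_dynamics()`, so this sketch should not be read as a check of the prior itself.

```r
# Compare Poisson(lambda) with its normal approximation N(lambda, sqrt(lambda)).
lambda <- 100 # arbitrary illustrative mean
k <- 80:120
pois_cdf <- stats::ppois(k, lambda)
# Continuity correction: evaluate the normal CDF at k + 0.5
norm_cdf <- stats::pnorm(k + 0.5, mean = lambda, sd = sqrt(lambda))
max_abs_diff <- max(abs(pois_cdf - norm_cdf))
```

For `lambda = 100` the two CDFs differ by less than 0.02 across this range.
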
2 changes: 1 addition & 1 deletion inst/stan/simulate_infections.stan
@@ -62,7 +62,7 @@ generated quantities {

     infections[i] = to_row_vector(generate_infections(
       to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i],
-      initial_growth[i], pop, future_time
+      initial_growth[i], pop, future_time, obs_scale, frac_obs[i]
     ));

     if (delay_id) {
14 changes: 7 additions & 7 deletions tests/testthat/test-estimate-early-dynamics.R
@@ -10,7 +10,7 @@ test_that("estimate_early_dynamics works", {
   # Check values
   expect_identical(
     round(prior_estimates$prior_infections, 2),
-    4.53
+    3.21
   )
   expect_identical(
     round(prior_estimates$prior_growth, 2),
@@ -21,29 +21,29 @@ test_that("estimate_early_dynamics works", {
 test_that("estimate_early_dynamics handles NA values correctly", {
   cases <- c(10, 20, NA, 40, 50, NA, 70)
   prior_estimates <- estimate_early_dynamics(cases, 7)
-  expect_equal(
-    prior_estimates$prior_infections,
-    log(mean(c(10, 20, 40, 50, 70), na.rm = TRUE))
+  expect_identical(
+    round(prior_estimates$prior_infections, 2),
+    2.55
   )
   expect_true(!is.na(prior_estimates$prior_growth))
 })

 test_that("estimate_early_dynamics handles exponential growth", {
   cases <- 2^(c(0:6)) # Exponential growth
   prior_estimates <- estimate_early_dynamics(cases, 7)
-  expect_equal(prior_estimates$prior_infections, log(mean(cases[1:7])))
+  expect_equal(prior_estimates$prior_infections, log(2^0))
   expect_true(prior_estimates$prior_growth > 0) # Growth should be positive
 })

 test_that("estimate_early_dynamics handles exponential decline", {
   cases <- rev(2^(c(0:6))) # Exponential decline
   prior_estimates <- estimate_early_dynamics(cases, 7)
-  expect_equal(prior_estimates$prior_infections, log(mean(cases[1:7])))
+  expect_equal(prior_estimates$prior_infections, log(2^6))
   expect_true(prior_estimates$prior_growth < 0) # Growth should be negative
 })

 test_that("estimate_early_dynamics correctly handles seeding time less than 2", {
   cases <- c(5, 10, 20) # Less than 7 days of data
   prior_estimates <- estimate_early_dynamics(cases, 1)
   expect_equal(prior_estimates$prior_growth, 0) # Growth should be 0 if seeding time is <= 1
-})
+})
16 changes: 8 additions & 8 deletions tests/testthat/test-stan-infections.R
@@ -25,35 +25,35 @@ gt_rev_pmf <- get_delay_rev_pmf(
 # test generate infections
 test_that("generate_infections works as expected", {
   expect_equal(
-    round(generate_infections(c(1, rep(1, 9)), 10, gt_rev_pmf, log(1000), 0, 0, 0), 0),
+    round(generate_infections(c(1, rep(1, 9)), 10, gt_rev_pmf, log(1000), 0, 0, 0, 0, 0), 0),
     c(rep(1000, 10), 995, 996, rep(997, 8))
   )
   expect_equal(
-    round(generate_infections(c(1, rep(1.1, 9)), 10, gt_rev_pmf, log(20), 0.03, 0, 0), 0),
+    round(generate_infections(c(1, rep(1.1, 9)), 10, gt_rev_pmf, log(20), 0.03, 0, 0, 0, 0), 0),
     c(20, 21, 21, 22, 23, 23, 24, 25, 25, 26, 24, 27, 28, 29, 30, 30, 31, 32, 33, 34)
   )
   expect_equal(
-    round(generate_infections(c(1, rep(1.1, 9)), 10, gt_rev_pmf, log(100), 0, 0, 0), 0),
+    round(generate_infections(c(1, rep(1.1, 9)), 10, gt_rev_pmf, log(100), 0, 0, 0, 0, 0), 0),
     c(rep(100, 10), 99, 110, 112, 115, 119, 122, 126, 130, 134, 138)
   )
   expect_equal(
-    round(generate_infections(c(1, rep(1, 9)), 4, gt_rev_pmf, log(500), -0.02, 0, 0), 0),
+    round(generate_infections(c(1, rep(1, 9)), 4, gt_rev_pmf, log(500), -0.02, 0, 0, 0, 0), 0),
     c(500, 490, 480, 471, 382, 403, 408, rep(409, 7))
   )
   expect_equal(
-    round(generate_infections(c(1, rep(1.1, 9)), 4, gt_rev_pmf, log(500), 0, 0, 0), 0),
+    round(generate_infections(c(1, rep(1.1, 9)), 4, gt_rev_pmf, log(500), 0, 0, 0, 0, 0), 0),
     c(rep(500, 4), 394, 460, 475, 489, 505, 520, 536, 553, 570, 588)
   )
   expect_equal(
-    round(generate_infections(c(1, rep(1, 9)), 1, gt_rev_pmf, log(40), numeric(0), 0, 0), 0),
+    round(generate_infections(c(1, rep(1, 9)), 1, gt_rev_pmf, log(40), numeric(0), 0, 0, 0, 0), 0),
     c(40, 8, 11, 12, 12, rep(13, 6))
   )
   expect_equal(
-    round(generate_infections(c(1, rep(1.1, 9)), 1, gt_rev_pmf, log(100), 0.01, 0, 0), 0),
+    round(generate_infections(c(1, rep(1.1, 9)), 1, gt_rev_pmf, log(100), 0.01, 0, 0, 0, 0), 0),
     c(100, 20, 31, 35, 36, 37, 38, 39, 41, 42, 43)
   )
   expect_equal(
-    round(generate_infections(c(1, rep(1, 9)), 10, gt_rev_pmf, log(1000), 0, 100000, 4), 0),
+    round(generate_infections(c(1, rep(1, 9)), 10, gt_rev_pmf, log(1000), 0, 100000, 4, 0, 0), 0),
     c(rep(1000, 10), 995, 996, rep(997, 4), 980, 965, 947, 926)
   )
 })
4 changes: 2 additions & 2 deletions vignettes/estimate_infections.Rmd
@@ -36,12 +36,12 @@ These infections are then mapped to observations via discrete convolutions with
 The model is initialised before the first observed data point by assuming constant exponential growth for the mean of modelled delays from infection to case report (called `seeding_time` $t_\mathrm{seed}$ in the model):

 \begin{align}
-  I_0 &\sim \mathrm{LogNormal}(I_\mathrm{obs}, 0.2) \\
+  I_0 &\sim \mathrm{LogNormal}(I_\mathrm{obs}, \sqrt{I_\mathrm{obs}}) \\
   r &\sim \mathrm{Normal}(r_\mathrm{obs}, 0.2)\\
   I_{0 < t \leq t_\mathrm{seed}} &= I_0 \exp \left(r t \right)
 \end{align}

-where $I_{t}$ is the number of latent infections on day $t$, $r$ is the estimate of the initial growth rate, and $I_\mathrm{obs}$ and $r_\mathrm{obs}$ are estimated from the first week of observed data: $I_\mathrm{obs}$ as the mean of reported cases in the first 7 days (or the mean of all cases if fewer than 7 days of data are given), divided by the prior mean reporting fraction if less than 1 (see [Delays and scaling]); and $r_\mathrm{obs}$ as the point estimate from fitting a linear regression model to the first 7 days of data (or all data if fewer than 7 days of data are given),
+where $I_{t}$ is the number of latent infections on day $t$, $r$ is the estimate of the initial growth rate, and $I_\mathrm{obs}$ and $r_\mathrm{obs}$ are estimated from the first week of observed data as the point estimates of, respectively, the intercept and slope from fitting a linear regression model to the first 7 days of data (or all data if fewer than 7 days of data are given),

 \begin{equation}
   log(C_t) = a + r_\mathrm{obs} t + \epsilon_t