Skip to content

Commit

Permalink
no stanfit for get_parameter_dims(), run-extended
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed Aug 13, 2024
1 parent 54dc55d commit 60b1564
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 48 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# dynamite 1.5.4

* Obtaining the model parameter dimensions via `get_parameter_dims()` no longer requires a compiled Stan model. This leads to a significant performance improvement when applied to `dynamiteformula` objects.
* Model fitting using `cmdstanr` backend no longer relies on `rstan::read_stan_csv()` to construct the fit object. Instead, the resulting `CmdStanMCMC` object is used directly. This should provide a substantial performance improvement in some instances. For `dynamice()`, samples from different imputed datasets are combined using `cmdstanr::as_cmdstan_fit()` instead.

# dynamite 1.5.3
Expand Down
91 changes: 43 additions & 48 deletions R/getters.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ get_data.dynamiteformula <- function(x, data, time, group = NULL, ...) {
#' @rdname get_data
#' @export
get_data.dynamitefit <- function(x, ...) {
if (!is.null(x$stan_input)) {
return(x$stan_input$sampling_vars)
}
out <- dynamite(
dformula = eval(formula(x)),
data = x$data,
Expand Down Expand Up @@ -232,64 +235,56 @@ get_parameter_dims <- function(x, ...) {
#' @export
get_parameter_dims.dynamiteformula <- function(x, data, time,
group = NULL, ...) {
out <- try(
suppressWarnings(
dynamite(
dformula = x,
data = data,
time = time,
group = group,
algorithm = "Fixed_param",
chains = 1,
iter = 1,
refresh = 0,
backend = "rstan",
verbose_stan = FALSE,
...
)
),
silent = TRUE
)
stopifnot_(
!inherits(out, "try-error"),
c(
"Unable to determine parameter dimensions:",
`x` = attr(out, "condition")$message
)
out <- dynamite(
dformula = x,
data = data,
time = time,
group = group,
debug = list(no_compile = TRUE, stan_input = TRUE, model_code = FALSE),
...
)
get_parameter_dims(out)
get_parameter_dims(out, ...)
}

#' @rdname get_parameter_dims
#' @export
get_parameter_dims.dynamitefit <- function(x, ...) {
stopifnot_(
!is.null(x$stanfit),
"No Stan model fit is available."
)
if (x$backend == "cmdstanr") {
return(
get_parameter_dims.dynamiteformula(
x = eval(formula(x)),
data = x$data,
time = x$time_var,
group = x$group_var,
...
)
)
}
pars_text <- get_code(x, blocks = "parameters")
pars <- get_parameters(pars_text)
# TODO no inits
out <- rstan::get_inits(x$stanfit)[[1L]]
out <- out[names(out) %in% pars]
lapply(
out,
pars_text <- strsplit(pars_text, split = "\n")[[1L]]
pars_text <- pars_text[grepl(";", pars_text)]
par_regex <- regexec(
pattern = "^.+\\s([^\\s]+);.*$",
text = pars_text,
perl = TRUE
)
par_matches <- regmatches(pars_text, par_regex)
par_names <- vapply(par_matches, "[[", character(1L), 2L)
dim_regex <- regexec(
pattern = "^[^\\[]+\\[([^\\]]+)\\].+",
text = pars_text,
perl = TRUE
)
dim_matches <- regmatches(pars_text, dim_regex)
dim_names <- lapply(
dim_matches,
function(y) {
d <- dim(y)
ifelse_(is.null(d), 1L, d)
if (length(y) > 0L) {
paste0("c(", y[2L], ")")
} else {
"1"
}
}
)
e <- list2env(get_data(x, ...))
stats::setNames(
lapply(
dim_names,
function(y) {
eval(str2lang(y), envir = e)
}
),
par_names
)
}

#' Internal Parameter Block Variable Name Extraction
Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/test-output.R
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,20 @@ test_that("gets can be got", {
a_y_c = 1L
)
)
stanfit_dims <- gaussian_example_fit$stanfit@par_dims
stanfit_dims[lengths(stanfit_dims) == 0] <- 1L
gaussian_dims <- get_parameter_dims(gaussian_example_fit)
stanfit_dims <- stanfit_dims[names(gaussian_dims)]
expect_equal(gaussian_dims, stanfit_dims)
expect_equal(
get_parameter_dims(
obs(y ~ -1 + z + varying(~ x + lag(y)) + random(~1),
family = "gaussian"
) + random_spec() + splines(df = 20),
gaussian_example, time = "time", group = "id"
),
gaussian_dims
)
})

test_that("credible intervals can be computed", {
Expand Down

0 comments on commit 60b1564

Please sign in to comment.