Skip to content

Commit

Permalink
ensure stan compatible names, run-extended
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed Feb 25, 2025
1 parent 4c8b25a commit e7d27bb
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 43 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: dynamite
Title: Bayesian Modeling and Causal Inference for Multivariate
Longitudinal Data
Version: 1.5.5
Version: 1.5.6
Authors@R: c(
person("Santtu", "Tikka", email = "[email protected]",
role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4039-4342")),
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# dynamite 1.5.6

* Variable names and factor levels are now checked and, if needed, modified for compatibility with Stan. Previously, only response variable names were checked. Variable names that contain spaces are now also supported when they are quoted with backticks.

# dynamite 1.5.5

* The package vignettes are now prerendered as some of them took a long time to build.
Expand Down
27 changes: 3 additions & 24 deletions R/dynamiteformula.R
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ dynamiteformula <- function(formula, family, link = NULL) {
dims[[1L]]$specials$resp_type <- resp_parsed$type
dims[[1L]]$response <- resp_parsed$resp
dims[[1L]]$original <- formula
dims[[1L]]$name <- parse_name(resp_parsed$resp)
dims[[1L]]$name <- stan_name(resp_parsed$resp)
} else {
dims <- parse_formula(formula, family)
if (is_binomial(family) || is_multinomial(family)) {
Expand Down Expand Up @@ -320,6 +320,7 @@ parse_formula <- function(x, family) {
rep(formula_parts, n_responses),
formula_parts
)
responses <- str_quote(responses)
formulas <- lapply(paste0(responses, "~", formula_parts), as.formula)
predictors <- lapply(
formulas,
Expand All @@ -334,13 +335,6 @@ parse_formula <- function(x, family) {
},
logical(1L)
)
# predictors <- ulapply(
# formulas,
# function(y) {
# find_nonlags(formula_rhs(y))
# }
# )
# resp_pred <- responses %in% predictors
p <- sum(resp_pred)
stopifnot_(
!any(resp_pred),
Expand All @@ -358,21 +352,6 @@ parse_formula <- function(x, family) {
)
}

#' Parse a Channel Name for a `dynamiteformula` To Be Used in Stan
#'
#' Strips from every element of `x` each run of characters that is neither
#' alphanumeric nor an underscore (`_`), which is the character set the Stan
#' Reference Manual allows in variable names. Only this character filter is
#' applied: the result is not checked against the other naming rules of Stan
#' (a leading letter and no trailing double underscore).
#'
#' @param x A `character` vector.
#' @return A `character` vector of the same length as `x`.
#' @noRd
parse_name <- function(x) {
  # One or more consecutive characters outside the allowed set.
  disallowed <- "[^[:alnum:]_]+"
  gsub(pattern = disallowed, replacement = "", x = x, perl = TRUE)
}

#' @rdname dynamiteformula
#' @param e1 \[`dynamiteformula`]\cr A model formula specification.
#' @param e2 \[`dynamiteformula`]\cr A model formula specification.
Expand Down Expand Up @@ -531,7 +510,7 @@ get_type_formula <- function(x, type = c("fixed", "varying", "random")) {
tr <- attr(ft, "term.labels")
rhs <- paste0(tr[idx], collapse = " + ")
rhs_out <- ifelse_(nzchar(rhs), paste0(" + ", rhs), "")
out_str <- paste0(resp, " ~ ", icpt, rhs_out)
out_str <- paste0(str_quote(deparse1(resp)), " ~ ", icpt, rhs_out)
ifelse_(
has_icpt || nzchar(rhs_out),
as.formula(out_str),
Expand Down
11 changes: 6 additions & 5 deletions R/lags.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ complete_lags <- function(x) {
xlen <- length(x)
if (identical(xlen, 2L)) {
x <- str2lang(
paste0("lag(", deparse1(x[[2L]]), ", ", "1)")
paste0("lag(", str_quote(deparse1(x[[2L]])), ", ", "1)")
)
} else if (identical(xlen, 3L)) {
k <- verify_lag(x[[3L]], deparse1(x))
x <- str2lang(
paste0("lag(", deparse1(x[[2L]]), ", ", k, ")")
paste0("lag(", str_quote(deparse1(x[[2L]])), ", ", k, ")")
)
} else {
stop_(c(
Expand Down Expand Up @@ -190,6 +190,7 @@ extract_lags <- function(x) {
if (length(lag_matches) > 0L) {
lag_map <- do.call("rbind", args = lag_matches)
lag_map <- as.data.frame(lag_map[, -1L, drop = FALSE])
lag_map$var <- str_unquote(lag_map$var)
lag_map$k <- as.integer(lag_map$k)
lag_map$k[is.na(lag_map$k)] <- 1L
lag_map$present <- TRUE
Expand Down Expand Up @@ -283,7 +284,7 @@ parse_lags <- function(dformula, data, group_var, time_var, verbose) {
for (i in seq_len(n_channels)) {
fix_rhs <- complete_lags(formula_rhs(dformula[[i]]$formula))
dformula[[i]]$formula <- as.formula(
paste0(resp_all[i], "~", deparse1(fix_rhs))
paste0(str_quote(resp_all[i]), "~", deparse1(fix_rhs))
)
}
data_names <- names(data)
Expand Down Expand Up @@ -406,7 +407,7 @@ parse_present_lags <- function(dformula, lag_map, y, i, lhs) {
dformula[[j]]$formula <- as.formula(
gsub(
pattern = lag_map$src[k],
replacement = lhs,
replacement = str_quote(lhs),
x = deparse1(dformula[[j]]$formula),
fixed = TRUE
)
Expand Down Expand Up @@ -535,7 +536,7 @@ parse_singleton_lags <- function(dformula, data, group_var,
}
}
channels[[idx]] <- dynamitechannel(
formula = as.formula(paste0(lhs, " ~ ", rhs)),
formula = as.formula(paste0(str_quote(lhs), " ~ ", str_quote(rhs))),
family = deterministic_(),
response = lhs,
specials = spec
Expand Down
10 changes: 0 additions & 10 deletions R/model_matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ full_model.matrix <- function(dformula, data, group_var, fixed, verbose) {
vector(mode = "list", length = length(model_matrices)),
y_names
)
#attr(model_matrix, "assign") <- empty_list
attr(model_matrix, "fixed") <- empty_list
attr(model_matrix, "varying") <- empty_list
attr(model_matrix, "random") <- empty_list
Expand All @@ -62,15 +61,6 @@ full_model.matrix <- function(dformula, data, group_var, fixed, verbose) {
attr(model_matrix, type)[[i]] <- integer(0L)
}
}
# attr(model_matrix, "assign")[[i]] <- sort(
# unique(
# c(
# attr(model_matrix, "fixed")[[i]],
# attr(model_matrix, "varying")[[i]],
# attr(model_matrix, "random")[[i]]
# )
# )
# )
}
model_matrix
}
Expand Down
3 changes: 2 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ initialize_predict <- function(object, newdata, type, eval_type, funs, impute,
resp_stoch <- get_responses(object$dformulas$stoch)
categories <- lapply(
attr(object$stan$responses, "resp_class"),
"attr", "levels"
"attr",
"levels"
)
new_levels <- ifelse_(
length(which_random(object$dformulas$all)) == 0L,
Expand Down
2 changes: 1 addition & 1 deletion R/prepare_stan_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ prepare_stan_input <- function(dformula, data, group_var, time_var,
data[, .SD, .SDcols = resp],
function(x) {
cl <- class(x)
attr(cl, "levels") <- levels(x)
attr(cl, "levels") <- stan_name(levels(x))
cl
}
)
Expand Down
2 changes: 1 addition & 1 deletion R/specials.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ formula_specials <- function(x, original, family) {
resp <- deparse1(formula_lhs(x))
list(
response = resp,
name = parse_name(resp),
name = stan_name(resp),
formula = x,
family = family,
original = original,
Expand Down
35 changes: 35 additions & 0 deletions R/stan_utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,41 @@ stan_supports_glm_likelihood <- function(family, backend, common_intercept) {
)
}

# Words that `stan_name()` refuses to return verbatim: a name that collides
# with one of these gets a suffix appended.
stan_reserved_keywords <- c(
  "int", "real", "vector", "row_vector", "matrix", "ordered",
  "positive_ordered", "simplex", "unit_vector", "cholesky_factor_corr",
  "cholesky_factor_cov", "corr_matrix", "cov_matrix", "functions", "model",
  "parameters", "transformed", "generated", "quantities", "data", "var",
  "return", "if", "else", "while", "for", "in", "break", "continue", "void",
  "reject", "print", "target", "T"
)

#' Ensure that a character string is a valid Stan variable name
#'
#' This function prepares a name such that it is valid for Stan. From
#' Stan Reference Manual: "A variable by itself is a well-formed expression of
#' the same type as the variable. Variables in Stan consist of ASCII strings
#' containing only the basic lower-case and upper-case Roman letters, digits,
#' and the underscore (_) character. Variables must start with a letter
#' (a--z and A--Z) and may not end with two underscores (__)".
#'
#' The steps are applied in this order to every element of `x`:
#' 1. each run of whitespace is replaced by a single underscore,
#' 2. characters other than `a-z`, `A-Z`, `0-9` and `_` are removed,
#' 3. a trailing run of two or more underscores is removed,
#' 4. the prefix `v_` is added when the name does not start with a letter,
#' 5. the suffix `_var` is added when the name, either as is or in lower
#'    case, is one of `stan_reserved_keywords`.
#'
#' @param x A `character` vector.
#' @return A `character` vector of the same length as `x`.
#' @noRd
stan_name <- function(x) {
  x <- gsub("\\s+", "_", x)
  x <- gsub("[^a-zA-Z0-9_]", "", x)
  x <- gsub("_{2,}$", "", x)
  # Covers names that start with a digit or underscore as well as names that
  # became empty after the removals above.
  needs_prefix <- !grepl("^[a-zA-Z]", x)
  x[needs_prefix] <- paste0("v_", x[needs_prefix])
  # Compare the exact spelling in addition to the lower-cased one: matching
  # only `tolower(x)` can never hit the upper-case keyword "T".
  is_reserved <- x %in% stan_reserved_keywords |
    tolower(x) %in% stan_reserved_keywords
  x[is_reserved] <- paste0(x[is_reserved], "_var")
  x
}

# Wrapper methods for backends --------------------------------------------

Expand Down
26 changes: 26 additions & 0 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,32 @@ cs <- function(...) {
paste0(c(...), collapse = ", ")
}

#' Quote strings with spaces
#'
#' Wraps every element of `x` that contains whitespace in backticks so that it
#' can be pasted into a formula string and still be parsed as a single name.
#' Elements without whitespace are returned unchanged. Elements that are
#' already enclosed in backticks are also left as they are, so calling the
#' function on its own output does not add a second pair of backticks.
#'
#' @param x A `character` vector.
#' @return A `character` vector of the same length as `x`. As `vapply()` is
#'   called with its default `USE.NAMES`, an unnamed `character` input yields
#'   a result that is named by the input strings.
#' @noRd
str_quote <- function(x) {
  vapply(
    x,
    function(y) {
      # Same "enclosed in backticks" pattern that `str_unquote()` strips.
      needs_quotes <- grepl("\\s", y) && !grepl("^`.+`$", y)
      if (needs_quotes) {
        paste0("`", y, "`")
      } else {
        y
      }
    },
    character(1L)
  )
}

#' Unquote strings
#'
#' Removes one enclosing pair of backticks from every element of `x` that
#' starts and ends with a backtick and has at least one character in between.
#' Elements that do not have this form are returned unchanged.
#'
#' @param x A `character` vector.
#' @return A `character` vector of the same length as `x`.
#' @noRd
str_unquote <- function(x) {
  # The capture group holds everything between the outer backticks.
  enclosed <- "^`(.+)`$"
  gsub(pattern = enclosed, replacement = "\\1", x = x)
}

#' Create a Comma-separated Character String and Evaluate with glue
#'
#' @param ... `character` strings.
Expand Down
1 change: 1 addition & 0 deletions dynamite.Rproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Version: 1.0
ProjectId: df876877-9fe4-4357-8bc3-21471ccee2b4

RestoreWorkspace: No
SaveWorkspace: No
Expand Down

0 comments on commit e7d27bb

Please sign in to comment.