diff --git a/DESCRIPTION b/DESCRIPTION index df33309f..03789f45 100755 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -67,7 +67,7 @@ Language: en-US BugReports: https://github.com/mjskay/tidybayes/issues/new URL: https://mjskay.github.io/tidybayes/, https://github.com/mjskay/tidybayes/ VignetteBuilder: knitr -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 LazyData: true Encoding: UTF-8 Collate: diff --git a/NAMESPACE b/NAMESPACE index 799244a8..6f369c37 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -20,6 +20,9 @@ S3method(epred_rvars,brmsfit) S3method(epred_rvars,default) S3method(epred_rvars,stanreg) S3method(fitted_draws,default) +S3method(flip_aes,"function") +S3method(flip_aes,character) +S3method(flip_aes,data.frame) S3method(gather_emmeans_draws,default) S3method(gather_emmeans_draws,emm_list) S3method(get_variables,default) diff --git a/NEWS.md b/NEWS.md index b0543aa4..800f72d8 100755 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,12 @@ # tidybayes (development version) -Buf fixes: +New features: + +* Add support for `draw_indices` parameter in `spread_draws()` and + `gather_draws()`. (#323) + + +Bug fixes: * Support for matrix columns in `nest_rvars()` and `unnest_rvars()`. (#316) diff --git a/R/compare_levels.R b/R/compare_levels.R index 5251d5d0..bc4147df 100755 --- a/R/compare_levels.R +++ b/R/compare_levels.R @@ -88,15 +88,7 @@ comparison_types = within(list(), { #' in the output `variable` column instead converting the unevaluated #' expression to a string. You can also use [emmeans_comparison()] to generate #' a comparison function based on contrast methods from the `emmeans` package. -#' @param draw_indices Character vector of column names in `data` that -#' should be treated as indices when making the comparison (i.e. values of -#' `variable` within each level of `by` will be compared at each -#' unique combination of levels of `draw_indices`). Columns in `draw_indices` -#' not found in `data` are ignored. 
The default is `c(".chain",".iteration",".draw")`, -#' which are the same names used for chain/iteration/draw indices returned by -#' [spread_draws()] or [gather_draws()]; thus if you are using `compare_levels` -#' with [spread_draws()] or [gather_draws()] you generally should not need to change this -#' value. +#' @template param-draw_indices #' @param ignore_groups character vector of names of groups to ignore by #' default in the input grouping. This is primarily provided to make it #' easier to pipe output of [add_epred_draws()] into this function, diff --git a/R/flip_aes.R b/R/flip_aes.R index 48451f33..34a2f5b9 100755 --- a/R/flip_aes.R +++ b/R/flip_aes.R @@ -29,17 +29,20 @@ flip_aes = function(x, lookup = flip_aes_lookup) { UseMethod("flip_aes") } +#' @export flip_aes.character = function(x, lookup = flip_aes_lookup) { flipped = lookup[x] x[!is.na(flipped)] = flipped[!is.na(flipped)] x } +#' @export flip_aes.data.frame = function(x, lookup = flip_aes_lookup) { names(x) = flip_aes(names(x), lookup = lookup) x } +#' @export flip_aes.function = function(x, lookup = flip_aes_lookup) { name = force(deparse(substitute(x))) function(...) 
{ diff --git a/R/gather_draws.R b/R/gather_draws.R index 1579455d..ac03d92e 100755 --- a/R/gather_draws.R +++ b/R/gather_draws.R @@ -10,13 +10,26 @@ #' @importFrom dplyr bind_rows group_by_at #' @importFrom rlang enquos #' @export -gather_draws = function(model, ..., regex = FALSE, sep = "[, ]", ndraws = NULL, seed = NULL, n) { +gather_draws = function( + model, + ..., + regex = FALSE, + sep = "[, ]", + ndraws = NULL, + seed = NULL, + draw_indices = c(".chain", ".iteration", ".draw"), + n +) { ndraws = .Deprecated_argument_alias(ndraws, n) draws = sample_draws_from_model_(model, ndraws, seed) + draw_indices = intersect(draw_indices, names(draws)) tidysamples = lapply(enquos(...), function(variable_spec) { - gather_variables(spread_draws_(draws, variable_spec, regex = regex, sep = sep)) + gather_variables( + spread_draws_(draws, variable_spec, regex = regex, sep = sep, draw_indices = draw_indices), + exclude = c(draw_indices, ".row") + ) }) #get the groups from all the samples --- when we bind them together, diff --git a/R/spread_draws.R b/R/spread_draws.R index 8df4e755..6455ea9d 100755 --- a/R/spread_draws.R +++ b/R/spread_draws.R @@ -202,6 +202,7 @@ globalVariables(c("..")) #' @param sep Separator used to separate dimensions in variable names, as a regular expression. #' @template param-ndraws #' @template param-seed +#' @template param-draw_indices #' @template param-deprecated-n #' @return A data frame. 
#' @author Matthew Kay @@ -232,13 +233,29 @@ globalVariables(c("..")) #' @importFrom dplyr inner_join group_by_at #' @rdname spread_draws #' @export -spread_draws = function(model, ..., regex = FALSE, sep = "[, ]", ndraws = NULL, seed = NULL, n) { +spread_draws = function( + model, + ..., + regex = FALSE, + sep = "[, ]", + ndraws = NULL, + seed = NULL, + draw_indices = c(".chain", ".iteration", ".draw"), + n +) { ndraws = .Deprecated_argument_alias(ndraws, n) draws = sample_draws_from_model_(model, ndraws, seed) + draw_indices = intersect(draw_indices, names(draws)) tidysamples = lapply(enquos(...), function(variable_spec) { - spread_draws_(draws, variable_spec, regex = regex, sep = sep) + spread_draws_( + draws, + variable_spec, + regex = regex, + sep = sep, + draw_indices = draw_indices + ) }) #get the groups from all the samples --- when we join them together, @@ -260,7 +277,13 @@ spread_draws = function(model, ..., regex = FALSE, sep = "[, ]", ndraws = NULL, #' @importFrom dplyr mutate group_by_at #' @importFrom tidyr spread #' @importFrom rlang has_name -spread_draws_ = function(draws, variable_spec, regex = FALSE, sep = "[, ]") { +spread_draws_ = function( + draws, + variable_spec, + regex = FALSE, + sep = "[, ]", + draw_indices = c(".chain", ".iteration", ".draw") +) { #parse a variable spec in the form variable_name[dimension_name_1, dimension_name_2, ..] 
| wide_dimension spec = parse_variable_spec(variable_spec) variable_names = spec[[1]] @@ -268,7 +291,7 @@ spread_draws_ = function(draws, variable_spec, regex = FALSE, sep = "[, ]") { wide_dimension_name = spec[[3]] #extract the draws into a long format data frame - long_draws = spread_draws_long_(draws, variable_names, dimension_names, regex = regex, sep = sep) + long_draws = spread_draws_long_(draws, variable_names, dimension_names, regex = regex, sep = sep, draw_indices = draw_indices) #convert variable and/or dimensions back into usable data types #that were set on the model using recover_types @@ -309,7 +332,14 @@ spread_draws_ = function(draws, variable_spec, regex = FALSE, sep = "[, ]") { ## dimension_names: a character vector of dimension names #' @importFrom tidyr spread separate gather #' @importFrom dplyr summarise_all group_by_at -spread_draws_long_ = function(draws, variable_names, dimension_names, regex = FALSE, sep = "[, ]") { +spread_draws_long_ = function( + draws, + variable_names, + dimension_names, + regex = FALSE, + sep = "[, ]", + draw_indices = c(".chain", ".iteration", ".draw") +) { if (!regex) { variable_names = escape_regex(variable_names) } @@ -326,7 +356,7 @@ spread_draws_long_ = function(draws, variable_names, dimension_names, regex = FA } variable_names = colnames(draws)[variable_names_index] - unnest_legacy(draws[, c(".chain", ".iteration", ".draw", variable_names)]) + unnest_legacy(draws[, c(draw_indices, variable_names)]) } else { dimension_sep_regex = sep @@ -399,11 +429,11 @@ spread_draws_long_ = function(draws, variable_names, dimension_names, regex = FA # some dimensions were requested to be nested as list columns containing arrays. # thus we have to ADD CHAIN INFO then UNNEST, then NEST DIMENSIONS then SPREAD # 2. ADD CHAIN INFO - nested_draws[[".chain_info"]] = list(draws[,c(".chain", ".iteration", ".draw")]) + nested_draws[[".chain_info"]] = list(draws[, draw_indices]) # 3. 
UNNEST long_draws = unnest_legacy(nested_draws) # NEST DIMENSIONS - long_draws = nest_dimensions_(long_draws, temp_dimension_names, nested_dimension_names) + long_draws = nest_dimensions_(long_draws, temp_dimension_names, nested_dimension_names, draw_indices) # 1. SPREAD long_draws = spread(long_draws, ".variable", ".value") } else { @@ -411,7 +441,7 @@ spread_draws_long_ = function(draws, variable_names, dimension_names, regex = FA # 1. SPREAD nested_draws = spread(nested_draws, ".variable", ".value") # 2. ADD CHAIN INFO - nested_draws[[".chain_info"]] = list(draws[,c(".chain", ".iteration", ".draw")]) + nested_draws[[".chain_info"]] = list(draws[, draw_indices]) # 3. UNNEST long_draws = unnest_legacy(nested_draws) } @@ -429,7 +459,12 @@ spread_draws_long_ = function(draws, variable_names, dimension_names, regex = FA ## dimension_names: dimensions not used for nesting ## nested_dimension_names: dimensions to be nested #' @importFrom dplyr filter summarise_at -nest_dimensions_ = function(long_draws, dimension_names, nested_dimension_names) { +nest_dimensions_ = function( + long_draws, + dimension_names, + nested_dimension_names, + draw_indices = c(".chain", ".iteration", ".draw") +) { ragged = FALSE value_name = ".value" value = as.name(value_name) @@ -443,7 +478,7 @@ nest_dimensions_ = function(long_draws, dimension_names, nested_dimension_names) } long_draws = group_by_at(long_draws, - c(".chain", ".iteration", ".draw", ".variable", dimension_names) %>% + c(draw_indices, ".variable", dimension_names) %>% # nested dimension names must come at the end of the group list # (minus the last nested dimension) so that we summarise in the # correct order diff --git a/R/tidybayes-package.R b/R/tidybayes-package.R index 937839ff..ce408c65 100644 --- a/R/tidybayes-package.R +++ b/R/tidybayes-package.R @@ -1,6 +1,5 @@ #' Tidy Data and 'Geoms' for Bayesian Models #' -#' @docType package #' @name tidybayes-package #' @aliases tidybayes #' @@ -34,4 +33,4 @@ #' Wickham, Hadley. 
(2014). Tidy data. _Journal of Statistical Software_, #' 59(10), 1-23. \doi{10.18637/jss.v059.i10}. #' -NULL +"_PACKAGE" diff --git a/R/ungather_draws.R b/R/ungather_draws.R index 799ed7d1..aa337a06 100755 --- a/R/ungather_draws.R +++ b/R/ungather_draws.R @@ -14,7 +14,7 @@ globalVariables(c("..dimension_values")) ungather_draws = function( data, ..., variable = ".variable", value = ".value", draw_indices = c(".chain", ".iteration", ".draw"), drop_indices = FALSE ) { - + draw_indices = intersect(draw_indices, names(data)) variable_specs = enquos(...) if (length(variable_specs) == 0) { diff --git a/R/unspread_draws.R b/R/unspread_draws.R index b73cd624..fdcd0881 100755 --- a/R/unspread_draws.R +++ b/R/unspread_draws.R @@ -20,10 +20,7 @@ globalVariables(c("..dimension_values")) #' @param data A tidy data frame of draws, such as one output by `spread_draws` or `gather_draws`. #' @param ... Expressions in the form of #' `variable_name[dimension_1, dimension_2, ...]`. See [spread_draws()]. -#' @param draw_indices Character vector of column names in `data` that -#' should be treated as indices of draws. The default is `c(".chain",".iteration",".draw")`, -#' which are the same names used for chain, iteration, and draw indices returned by -#' [spread_draws()] or [gather_draws()]. +#' @template param-draw_indices #' @param drop_indices Drop the columns specified by `draw_indices` from the resulting data frame. Default `FALSE`. #' @param variable The name of the column in `data` that contains the names of variables from the model. #' @param value The name of the column in `data` that contains draws from the variables. 
@@ -62,6 +59,7 @@ globalVariables(c("..dimension_values")) #' @rdname unspread_draws #' @export unspread_draws = function(data, ..., draw_indices = c(".chain", ".iteration", ".draw"), drop_indices = FALSE) { + draw_indices = intersect(draw_indices, names(data)) result = lapply(enquos(...), function(variable_spec) { unspread_draws_(data, variable_spec, draw_indices = draw_indices) diff --git a/man-roxygen/param-draw_indices.R b/man-roxygen/param-draw_indices.R new file mode 100755 index 00000000..232bdada --- /dev/null +++ b/man-roxygen/param-draw_indices.R @@ -0,0 +1,5 @@ +#' @param draw_indices Character vector of column names that should be treated +#' as indices of draws. Operations are done within combinations of these values. +#' The default is `c(".chain", ".iteration", ".draw")`, which are the same names +#' used for chain, iteration, and draw indices returned by [tidy_draws()]. +#' Names in `draw_indices` that are not found in the data are ignored. diff --git a/man/compare_levels.Rd b/man/compare_levels.Rd index 531f9b99..97b3984c 100755 --- a/man/compare_levels.Rd +++ b/man/compare_levels.Rd @@ -53,15 +53,11 @@ in the output \code{variable} column instead converting the unevaluated expression to a string. You can also use \code{\link[=emmeans_comparison]{emmeans_comparison()}} to generate a comparison function based on contrast methods from the \code{emmeans} package.} -\item{draw_indices}{Character vector of column names in \code{data} that -should be treated as indices when making the comparison (i.e. values of -\code{variable} within each level of \code{by} will be compared at each -unique combination of levels of \code{draw_indices}). Columns in \code{draw_indices} -not found in \code{data} are ignored. 
The default is \code{c(".chain",".iteration",".draw")}, -which are the same names used for chain/iteration/draw indices returned by -\code{\link[=spread_draws]{spread_draws()}} or \code{\link[=gather_draws]{gather_draws()}}; thus if you are using \code{compare_levels} -with \code{\link[=spread_draws]{spread_draws()}} or \code{\link[=gather_draws]{gather_draws()}} you generally should not need to change this -value.} +\item{draw_indices}{Character vector of column names that should be treated +as indices of draws. Operations are done within combinations of these values. +The default is \code{c(".chain", ".iteration", ".draw")}, which are the same names +used for chain, iteration, and draw indices returned by \code{\link[=tidy_draws]{tidy_draws()}}. +Names in \code{draw_indices} that are not found in the data are ignored.} \item{ignore_groups}{character vector of names of groups to ignore by default in the input grouping. This is primarily provided to make it diff --git a/man/reexports.Rd b/man/reexports.Rd index 598cf0d3..3e178e21 100755 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -137,7 +137,7 @@ These objects are imported from other packages. Follow the links below to see their documentation. 
\describe{ - \item{ggdist}{\code{\link[ggdist:theme_ggdist]{axis_titles_bottom_left}}, \code{\link[ggdist]{curve_interval}}, \code{\link[ggdist]{cut_cdf_qi}}, \code{\link[ggdist:lkjcorr_marginal]{dlkjcorr_marginal}}, \code{\link[ggdist:student_t]{dstudent_t}}, \code{\link[ggdist:theme_ggdist]{facet_title_left_horizontal}}, \code{\link[ggdist:theme_ggdist]{facet_title_right_horizontal}}, \code{\link[ggdist:tidy-format-translators]{from_broom_names}}, \code{\link[ggdist:tidy-format-translators]{from_ggmcmc_names}}, \code{\link[ggdist]{geom_dots}}, \code{\link[ggdist]{geom_dotsinterval}}, \code{\link[ggdist]{geom_interval}}, \code{\link[ggdist]{geom_lineribbon}}, \code{\link[ggdist]{geom_pointinterval}}, \code{\link[ggdist]{geom_slab}}, \code{\link[ggdist]{geom_slabinterval}}, \code{\link[ggdist:ggdist-ggproto]{GeomDots}}, \code{\link[ggdist:ggdist-ggproto]{GeomDotsinterval}}, \code{\link[ggdist:ggdist-ggproto]{GeomInterval}}, \code{\link[ggdist:ggdist-ggproto]{GeomLineribbon}}, \code{\link[ggdist:ggdist-ggproto]{GeomPointinterval}}, \code{\link[ggdist:ggdist-ggproto]{GeomSlab}}, \code{\link[ggdist:ggdist-ggproto]{GeomSlabinterval}}, \code{\link[ggdist:scales]{guide_colorbar2}}, \code{\link[ggdist:scales]{guide_colourbar2}}, \code{\link[ggdist:point_interval]{hdci}}, \code{\link[ggdist:point_interval]{hdi}}, \code{\link[ggdist]{marginalize_lkjcorr}}, \code{\link[ggdist:point_interval]{mean_hdci}}, \code{\link[ggdist:point_interval]{mean_hdi}}, \code{\link[ggdist:point_interval]{mean_qi}}, \code{\link[ggdist:point_interval]{median_hdci}}, \code{\link[ggdist:point_interval]{median_hdi}}, \code{\link[ggdist:point_interval]{median_qi}}, \code{\link[ggdist:point_interval]{Mode}}, \code{\link[ggdist:point_interval]{mode_hdci}}, \code{\link[ggdist:point_interval]{mode_hdi}}, \code{\link[ggdist:point_interval]{mode_qi}}, \code{\link[ggdist]{parse_dist}}, \code{\link[ggdist:lkjcorr_marginal]{plkjcorr_marginal}}, \code{\link[ggdist]{point_interval}}, 
\code{\link[ggdist:student_t]{pstudent_t}}, \code{\link[ggdist:point_interval]{qi}}, \code{\link[ggdist:lkjcorr_marginal]{qlkjcorr_marginal}}, \code{\link[ggdist:student_t]{qstudent_t}}, \code{\link[ggdist:parse_dist]{r_dist_name}}, \code{\link[ggdist:lkjcorr_marginal]{rlkjcorr_marginal}}, \code{\link[ggdist:student_t]{rstudent_t}}, \code{\link[ggdist:scales]{scale_interval_alpha_continuous}}, \code{\link[ggdist:scales]{scale_interval_alpha_discrete}}, \code{\link[ggdist:scales]{scale_interval_color_continuous}}, \code{\link[ggdist:scales]{scale_interval_color_discrete}}, \code{\link[ggdist:scales]{scale_interval_colour_continuous}}, \code{\link[ggdist:scales]{scale_interval_colour_discrete}}, \code{\link[ggdist:scales]{scale_interval_linetype_continuous}}, \code{\link[ggdist:scales]{scale_interval_linetype_discrete}}, \code{\link[ggdist:scales]{scale_interval_size_continuous}}, \code{\link[ggdist:scales]{scale_interval_size_discrete}}, \code{\link[ggdist:scales]{scale_point_alpha_continuous}}, \code{\link[ggdist:scales]{scale_point_alpha_discrete}}, \code{\link[ggdist:scales]{scale_point_color_continuous}}, \code{\link[ggdist:scales]{scale_point_color_discrete}}, \code{\link[ggdist:scales]{scale_point_colour_continuous}}, \code{\link[ggdist:scales]{scale_point_colour_discrete}}, \code{\link[ggdist:scales]{scale_point_fill_continuous}}, \code{\link[ggdist:scales]{scale_point_fill_discrete}}, \code{\link[ggdist:scales]{scale_point_size_continuous}}, \code{\link[ggdist:scales]{scale_point_size_discrete}}, \code{\link[ggdist:scales]{scale_slab_alpha_continuous}}, \code{\link[ggdist:scales]{scale_slab_alpha_discrete}}, \code{\link[ggdist:scales]{scale_slab_color_continuous}}, \code{\link[ggdist:scales]{scale_slab_color_discrete}}, \code{\link[ggdist:scales]{scale_slab_colour_continuous}}, \code{\link[ggdist:scales]{scale_slab_colour_discrete}}, \code{\link[ggdist:scales]{scale_slab_fill_continuous}}, \code{\link[ggdist:scales]{scale_slab_fill_discrete}}, 
\code{\link[ggdist:scales]{scale_slab_linetype_continuous}}, \code{\link[ggdist:scales]{scale_slab_linetype_discrete}}, \code{\link[ggdist:scales]{scale_slab_shape_continuous}}, \code{\link[ggdist:scales]{scale_slab_shape_discrete}}, \code{\link[ggdist:scales]{scale_slab_size_continuous}}, \code{\link[ggdist:scales]{scale_slab_size_discrete}}, \code{\link[ggdist]{stat_ccdfinterval}}, \code{\link[ggdist]{stat_cdfinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_ccdfinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_cdfinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_dots}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_dotsinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_eye}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_gradientinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_halfeye}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_interval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_lineribbon}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_pointinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_slab}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_slabinterval}}, \code{\link[ggdist]{stat_dots}}, \code{\link[ggdist]{stat_dotsinterval}}, \code{\link[ggdist]{stat_eye}}, \code{\link[ggdist]{stat_gradientinterval}}, \code{\link[ggdist]{stat_halfeye}}, \code{\link[ggdist]{stat_histinterval}}, \code{\link[ggdist]{stat_interval}}, \code{\link[ggdist]{stat_lineribbon}}, \code{\link[ggdist]{stat_pointinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_sample_slabinterval}}, \code{\link[ggdist]{stat_slab}}, \code{\link[ggdist]{stat_slabinterval}}, \code{\link[ggdist:ggdist-deprecated]{StatDistSlabinterval}}, \code{\link[ggdist:ggdist-ggproto]{StatInterval}}, \code{\link[ggdist:ggdist-ggproto]{StatPointinterval}}, \code{\link[ggdist:ggdist-deprecated]{StatSampleSlabinterval}}, \code{\link[ggdist:ggdist-ggproto]{StatSlabinterval}}, \code{\link[ggdist]{theme_ggdist}}, 
\code{\link[ggdist:theme_ggdist]{theme_tidybayes}}, \code{\link[ggdist:tidy-format-translators]{to_broom_names}}, \code{\link[ggdist:tidy-format-translators]{to_ggmcmc_names}}} + \item{ggdist}{\code{\link[ggdist:theme_ggdist]{axis_titles_bottom_left}}, \code{\link[ggdist]{curve_interval}}, \code{\link[ggdist]{cut_cdf_qi}}, \code{\link[ggdist:lkjcorr_marginal]{dlkjcorr_marginal}}, \code{\link[ggdist:student_t]{dstudent_t}}, \code{\link[ggdist:theme_ggdist]{facet_title_left_horizontal}}, \code{\link[ggdist:theme_ggdist]{facet_title_right_horizontal}}, \code{\link[ggdist:tidy-format-translators]{from_broom_names}}, \code{\link[ggdist:tidy-format-translators]{from_ggmcmc_names}}, \code{\link[ggdist]{geom_dots}}, \code{\link[ggdist]{geom_dotsinterval}}, \code{\link[ggdist]{geom_interval}}, \code{\link[ggdist]{geom_lineribbon}}, \code{\link[ggdist]{geom_pointinterval}}, \code{\link[ggdist]{geom_slab}}, \code{\link[ggdist]{geom_slabinterval}}, \code{\link[ggdist:ggdist-ggproto]{GeomDots}}, \code{\link[ggdist:ggdist-ggproto]{GeomDotsinterval}}, \code{\link[ggdist:ggdist-ggproto]{GeomInterval}}, \code{\link[ggdist:ggdist-ggproto]{GeomLineribbon}}, \code{\link[ggdist:ggdist-ggproto]{GeomPointinterval}}, \code{\link[ggdist:ggdist-ggproto]{GeomSlab}}, \code{\link[ggdist:ggdist-ggproto]{GeomSlabinterval}}, \code{\link[ggdist:sub-geometry-scales]{guide_colorbar2}}, \code{\link[ggdist:sub-geometry-scales]{guide_colourbar2}}, \code{\link[ggdist:point_interval]{hdci}}, \code{\link[ggdist:point_interval]{hdi}}, \code{\link[ggdist]{marginalize_lkjcorr}}, \code{\link[ggdist:point_interval]{mean_hdci}}, \code{\link[ggdist:point_interval]{mean_hdi}}, \code{\link[ggdist:point_interval]{mean_qi}}, \code{\link[ggdist:point_interval]{median_hdci}}, \code{\link[ggdist:point_interval]{median_hdi}}, \code{\link[ggdist:point_interval]{median_qi}}, \code{\link[ggdist:point_interval]{Mode}}, \code{\link[ggdist:point_interval]{mode_hdci}}, \code{\link[ggdist:point_interval]{mode_hdi}}, 
\code{\link[ggdist:point_interval]{mode_qi}}, \code{\link[ggdist]{parse_dist}}, \code{\link[ggdist:lkjcorr_marginal]{plkjcorr_marginal}}, \code{\link[ggdist]{point_interval}}, \code{\link[ggdist:student_t]{pstudent_t}}, \code{\link[ggdist:point_interval]{qi}}, \code{\link[ggdist:lkjcorr_marginal]{qlkjcorr_marginal}}, \code{\link[ggdist:student_t]{qstudent_t}}, \code{\link[ggdist:parse_dist]{r_dist_name}}, \code{\link[ggdist:lkjcorr_marginal]{rlkjcorr_marginal}}, \code{\link[ggdist:student_t]{rstudent_t}}, \code{\link[ggdist:sub-geometry-scales]{scale_interval_alpha_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_interval_alpha_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_interval_color_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_interval_color_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_interval_colour_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_interval_colour_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_interval_linetype_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_interval_linetype_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_interval_size_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_interval_size_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_point_alpha_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_point_alpha_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_point_color_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_point_color_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_point_colour_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_point_colour_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_point_fill_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_point_fill_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_point_size_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_point_size_discrete}}, 
\code{\link[ggdist:sub-geometry-scales]{scale_slab_alpha_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_alpha_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_color_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_color_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_colour_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_colour_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_fill_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_fill_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_linetype_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_linetype_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_shape_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_shape_discrete}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_size_continuous}}, \code{\link[ggdist:sub-geometry-scales]{scale_slab_size_discrete}}, \code{\link[ggdist]{stat_ccdfinterval}}, \code{\link[ggdist]{stat_cdfinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_ccdfinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_cdfinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_dots}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_dotsinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_eye}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_gradientinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_halfeye}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_interval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_lineribbon}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_pointinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_slab}}, \code{\link[ggdist:ggdist-deprecated]{stat_dist_slabinterval}}, \code{\link[ggdist]{stat_dots}}, \code{\link[ggdist]{stat_dotsinterval}}, \code{\link[ggdist]{stat_eye}}, \code{\link[ggdist]{stat_gradientinterval}}, \code{\link[ggdist]{stat_halfeye}}, 
\code{\link[ggdist]{stat_histinterval}}, \code{\link[ggdist]{stat_interval}}, \code{\link[ggdist]{stat_lineribbon}}, \code{\link[ggdist]{stat_pointinterval}}, \code{\link[ggdist:ggdist-deprecated]{stat_sample_slabinterval}}, \code{\link[ggdist]{stat_slab}}, \code{\link[ggdist]{stat_slabinterval}}, \code{\link[ggdist:ggdist-deprecated]{StatDistSlabinterval}}, \code{\link[ggdist:ggdist-ggproto]{StatInterval}}, \code{\link[ggdist:ggdist-ggproto]{StatPointinterval}}, \code{\link[ggdist:ggdist-deprecated]{StatSampleSlabinterval}}, \code{\link[ggdist:ggdist-ggproto]{StatSlabinterval}}, \code{\link[ggdist]{theme_ggdist}}, \code{\link[ggdist:theme_ggdist]{theme_tidybayes}}, \code{\link[ggdist:tidy-format-translators]{to_broom_names}}, \code{\link[ggdist:tidy-format-translators]{to_ggmcmc_names}}} \item{posterior}{\code{\link[posterior:draws_summary]{summarise_draws}}} }} diff --git a/man/spread_draws.Rd b/man/spread_draws.Rd index 693b0063..584e38f3 100755 --- a/man/spread_draws.Rd +++ b/man/spread_draws.Rd @@ -12,6 +12,7 @@ gather_draws( sep = "[, ]", ndraws = NULL, seed = NULL, + draw_indices = c(".chain", ".iteration", ".draw"), n ) @@ -22,6 +23,7 @@ spread_draws( sep = "[, ]", ndraws = NULL, seed = NULL, + draw_indices = c(".chain", ".iteration", ".draw"), n ) } @@ -41,6 +43,12 @@ regular expression and number of dimensions are included in the output. Default \item{seed}{A seed to use when subsampling draws (i.e. when \code{ndraws} is not \code{NULL}).} +\item{draw_indices}{Character vector of column names that should be treated +as indices of draws. Operations are done within combinations of these values. +The default is \code{c(".chain", ".iteration", ".draw")}, which are the same names +used for chain, iteration, and draw indices returned by \code{\link[=tidy_draws]{tidy_draws()}}. +Names in \code{draw_indices} that are not found in the data are ignored.} + \item{n}{(Deprecated). 
Use \code{ndraws}.} } \value{ diff --git a/man/tidybayes-package.Rd b/man/tidybayes-package.Rd index a7b5e3c7..517d6190 100644 --- a/man/tidybayes-package.Rd +++ b/man/tidybayes-package.Rd @@ -32,3 +32,21 @@ For a list of supported models, see \link{tidybayes-models}. Wickham, Hadley. (2014). Tidy data. \emph{Journal of Statistical Software}, 59(10), 1-23. \doi{10.18637/jss.v059.i10}. } +\seealso{ +Useful links: +\itemize{ + \item \url{https://mjskay.github.io/tidybayes/} + \item \url{https://github.com/mjskay/tidybayes/} + \item Report bugs at \url{https://github.com/mjskay/tidybayes/issues/new} +} + +} +\author{ +\strong{Maintainer}: Matthew Kay \email{mjskay@northwestern.edu} + +Other contributors: +\itemize{ + \item Timothy Mastny \email{tim.mastny@gmail.com} [contributor] +} + +} diff --git a/man/unspread_draws.Rd b/man/unspread_draws.Rd index edb7a970..5a4e8cd0 100755 --- a/man/unspread_draws.Rd +++ b/man/unspread_draws.Rd @@ -31,10 +31,11 @@ unspread_draws( \item{value}{The name of the column in \code{data} that contains draws from the variables.} -\item{draw_indices}{Character vector of column names in \code{data} that -should be treated as indices of draws. The default is \code{c(".chain",".iteration",".draw")}, -which are the same names used for chain, iteration, and draw indices returned by -\code{\link[=spread_draws]{spread_draws()}} or \code{\link[=gather_draws]{gather_draws()}}.} +\item{draw_indices}{Character vector of column names that should be treated +as indices of draws. Operations are done within combinations of these values. +The default is \code{c(".chain", ".iteration", ".draw")}, which are the same names +used for chain, iteration, and draw indices returned by \code{\link[=tidy_draws]{tidy_draws()}}. +Names in \code{draw_indices} that are not found in the data are ignored.} \item{drop_indices}{Drop the columns specified by \code{draw_indices} from the resulting data frame. 
Default \code{FALSE}.} } diff --git a/tests/testthat/test.gather_draws.R b/tests/testthat/test.gather_draws.R index bc186b6b..d5100d76 100755 --- a/tests/testthat/test.gather_draws.R +++ b/tests/testthat/test.gather_draws.R @@ -7,8 +7,6 @@ library(dplyr) library(tidyr) - - test_that("regular expressions for parameter names work on non-indexed parameters", { data(RankCorr, package = "ggdist") @@ -42,3 +40,31 @@ test_that("gather_draws works on a combination of 0 and 1-dimensional values (wi expect_equivalent(result, ref) expect_equal(group_vars(result), group_vars(ref)) }) + +test_that("draw_indices works", { + df = data.frame( + .chain = rep(1:4, each = 4), + .iteration = rep(1:4, 4), + .draw = 1:16, + .warmup = rep(c(TRUE, TRUE, FALSE, FALSE), 4), + `x[1]` = 2:17, + `x[2]` = 3:18, + check.names = FALSE + ) + + ref = tibble( + i = rep(1:2, each = 16), + .chain = rep(1:4, each = 4, times = 2), + .iteration = rep(1:4, 8), + .draw = rep(1:16, 2), + .warmup = rep(c(TRUE, TRUE, FALSE, FALSE), 8), + .variable = "x", + .value = c(2:17, 3:18) + ) %>% + group_by(i, .variable) + + result = gather_draws(df, x[i], draw_indices = c(".chain", ".iteration", ".draw", ".warmup")) + + expect_equivalent(result, ref) + expect_equal(group_vars(result), group_vars(ref)) +}) diff --git a/tests/testthat/test.spread_draws.R b/tests/testthat/test.spread_draws.R index 280b0d75..df608665 100755 --- a/tests/testthat/test.spread_draws.R +++ b/tests/testthat/test.spread_draws.R @@ -335,6 +335,33 @@ test_that("variable names containing regex special chars work", { expect_equal(spread_draws(RankCorr_t, `(Intercept)`), ref) }) +test_that("draw_indices works", { + df = data.frame( + .chain = rep(1:4, each = 4), + .iteration = rep(1:4, 4), + .draw = 1:16, + .warmup = rep(c(TRUE, TRUE, FALSE, FALSE), 4), + `x[1]` = 2:17, + `x[2]` = 3:18, + check.names = FALSE + ) + + ref = tibble( + i = rep(1:2, each = 16), + x = c(2:17, 3:18), + .chain = rep(1:4, each = 4, times = 2), + .iteration = rep(1:4, 8), + 
.draw = rep(1:16, 2), + .warmup = rep(c(TRUE, TRUE, FALSE, FALSE), 8) + ) %>% + group_by(i) + + result = spread_draws(df, x[i], draw_indices = c(".chain", ".iteration", ".draw", ".warmup")) + + expect_equivalent(result, ref) + expect_equal(group_vars(result), group_vars(ref)) +}) + # tests for nested syntax -------------------------------------------------