Skip to content

Commit 527c48c

Browse files
Merge pull request #349 from stan-dev/fun-avg
add `stat` argument to `ppc_*_avg` functions.
2 parents 95a23b7 + 9cb807c commit 527c48c

18 files changed

+1483
-1239
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Title: Plotting for Bayesian Models
44
Version: 1.13.0.9000
55
Date: 2025-06-18
66
Authors@R: c(person("Jonah", "Gabry", role = c("aut", "cre"), email = "[email protected]"),
7-
person("Tristan", "Mahr", role = "aut"),
7+
person("Tristan", "Mahr", role = "aut", comment = c(ORCID = "0000-0002-8890-5116")),
88
person("Paul-Christian", "Bürkner", role = "ctb"),
99
person("Martin", "Modrák", role = "ctb"),
1010
person("Malcolm", "Barrett", role = "ctb"),
@@ -26,7 +26,7 @@ URL: https://mc-stan.org/bayesplot/
2626
BugReports: https://github.com/stan-dev/bayesplot/issues/
2727
SystemRequirements: pandoc (>= 1.12.3), pandoc-citeproc
2828
Depends:
29-
R (>= 3.1.0)
29+
R (>= 4.1.0)
3030
Imports:
3131
dplyr (>= 0.8.0),
3232
ggplot2 (>= 3.4.0),

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# bayesplot (development version)
22

3+
* PPC "avg" functions (`ppc_scatter_avg()`, `ppc_error_scatter_avg()`, etc.) gain a `stat` argument to set the averaging function. (Suggestion of #348, @kruschke).
4+
* `ppc_error_scatter_avg_vs_x(x = some_expression)` labels the *x* axis with `some_expression`.
5+
36
# bayesplot 1.13.0
47

58
* Add `ppc_loo_pit_ecdf()` by @TeemuSailynoja (#345)

R/bayesplot-helpers.R

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,53 @@ grid_lines_y <- function(color = "gray50", size = 0.2) {
469469
overlay_function <- function(...) {
470470
stat_function(..., inherit.aes = FALSE)
471471
}
472+
473+
474+
475+
# Resolve a function name and store the expression passed in by the user
476+
#' @noRd
477+
#' @param f a function-like thing: a string naming a function, a function
478+
#' object, an anonymous function object, a formula-based lambda, and `NULL`.
479+
#' @param fallback character string providing a fallback function name
480+
#' @return the function named in `f` with an added `"tagged_expr"` attribute
481+
#' containing the expression to represent the function name and an
482+
#' `"is_anonymous_function"` attribute to flag if the expression is a call to
483+
#' `function()`.
484+
as_tagged_function <- function(f = NULL, fallback = "func") {
485+
qf <- enquo(f)
486+
f <- eval_tidy(qf)
487+
if (!is.null(attr(f, "tagged_expr"))) return(f)
488+
489+
f_expr <- quo_get_expr(qf)
490+
f_fn <- f
491+
492+
if (is_character(f)) { # f = "mean"
493+
# using sym() on the evaluated `f` means that a variable that names a
494+
# function string `x <- "mean"; as_tagged_function(x)` will be lost
495+
# but that seems okay
496+
f_expr <- sym(f)
497+
f_fn <- match.fun(f)
498+
} else if (is_null(f)) { # f = NULL
499+
f_fn <- identity
500+
f_expr <- sym(fallback)
501+
} else if (is_callable(f)) { # f = mean or f = function(x) mean(x)
502+
f_expr <- f_expr # or f = ~mean(.x)
503+
f_fn <- as_function(f)
504+
}
505+
506+
# Setting attributes on primitive functions is deprecated, so wrap them
507+
# and then tag
508+
if (is_primitive(f_fn)) {
509+
f_fn_old <- f_fn
510+
f_factory <- function(f) { function(...) f(...) }
511+
f_fn <- f_factory(f_fn_old)
512+
}
513+
514+
attr(f_fn, "tagged_expr") <- f_expr
515+
attr(f_fn, "is_anonymous_function") <-
516+
is_call(f_expr, name = "function") || is_formula(f_expr)
517+
f_fn
518+
}
519+
520+
521+

R/bayesplot-package.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#' **bayesplot**: Plotting for Bayesian Models
22
#'
3-
#' @docType package
43
#' @name bayesplot-package
54
#' @aliases bayesplot
65
#'
@@ -96,7 +95,7 @@
9695
#' ppd_hist(ypred[1:8, ])
9796
#' }
9897
#'
99-
NULL
98+
"_PACKAGE"
10099

101100

102101
# internal ----------------------------------------------------------------

R/ppc-errors.R

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
#' @template args-group
1111
#' @template args-facet_args
1212
#' @param ... Currently unused.
13+
#' @param stat A function or a string naming a function for computing the
14+
#' posterior average. In both cases, the function should take a vector input and
15+
#' return a scalar statistic. The function name is displayed in the axis-label.
16+
#' Defaults to `"mean"`.
1317
#' @param size,alpha For scatterplots, arguments passed to
1418
#' [ggplot2::geom_point()] to control the appearance of the points. For the
1519
#' binned error plot, arguments controlling the size of the outline and
@@ -209,21 +213,26 @@ ppc_error_scatter_avg <-
209213
function(y,
210214
yrep,
211215
...,
216+
stat = "mean",
212217
size = 2.5,
213218
alpha = 0.8) {
214219
check_ignored_arguments(...)
215220

216221
y <- validate_y(y)
217222
yrep <- validate_predictions(yrep, length(y))
218223
errors <- compute_errors(y, yrep)
224+
225+
stat <- as_tagged_function({{ stat }})
226+
219227
ppc_scatter_avg(
220228
y = y,
221229
yrep = errors,
222230
size = size,
223231
alpha = alpha,
224-
ref_line = FALSE
232+
ref_line = FALSE,
233+
stat = stat
225234
) +
226-
labs(x = error_avg_label(), y = y_label())
235+
labs(x = error_avg_label(stat), y = y_label())
227236
}
228237

229238

@@ -234,13 +243,16 @@ ppc_error_scatter_avg_grouped <-
234243
yrep,
235244
group,
236245
...,
246+
stat = "mean",
237247
facet_args = list(),
238248
size = 2.5,
239249
alpha = 0.8) {
240250
check_ignored_arguments(...)
241251

242252
y <- validate_y(y)
243253
yrep <- validate_predictions(yrep, length(y))
254+
stat <- as_tagged_function({{ stat }})
255+
244256
errors <- compute_errors(y, yrep)
245257
ppc_scatter_avg_grouped(
246258
y = y,
@@ -249,9 +261,10 @@ ppc_error_scatter_avg_grouped <-
249261
size = size,
250262
alpha = alpha,
251263
facet_args = facet_args,
252-
ref_line = FALSE
264+
ref_line = FALSE,
265+
stat = stat
253266
) +
254-
labs(x = error_avg_label(), y = y_label())
267+
labs(x = error_avg_label(stat), y = y_label())
255268
}
256269

257270

@@ -260,29 +273,37 @@ ppc_error_scatter_avg_grouped <-
260273
#' @param x A numeric vector the same length as `y` to use as the x-axis
261274
#' variable.
262275
#'
263-
ppc_error_scatter_avg_vs_x <-
264-
function(y,
265-
yrep,
266-
x,
267-
...,
268-
size = 2.5,
269-
alpha = 0.8) {
270-
check_ignored_arguments(...)
276+
ppc_error_scatter_avg_vs_x <- function(
277+
y,
278+
yrep,
279+
x,
280+
...,
281+
stat = "mean",
282+
size = 2.5,
283+
alpha = 0.8
284+
) {
285+
check_ignored_arguments(...)
271286

272-
y <- validate_y(y)
273-
yrep <- validate_predictions(yrep, length(y))
274-
x <- validate_x(x, y)
275-
errors <- compute_errors(y, yrep)
276-
ppc_scatter_avg(
277-
y = x,
278-
yrep = errors,
279-
size = size,
280-
alpha = alpha,
281-
ref_line = FALSE
287+
y <- validate_y(y)
288+
yrep <- validate_predictions(yrep, length(y))
289+
qx <- enquo(x)
290+
x <- validate_x(x, y)
291+
stat <- as_tagged_function({{ stat }})
292+
errors <- compute_errors(y, yrep)
293+
ppc_scatter_avg(
294+
y = x,
295+
yrep = errors,
296+
size = size,
297+
alpha = alpha,
298+
ref_line = FALSE,
299+
stat = stat
300+
) +
301+
labs(
302+
x = error_avg_label(stat),
303+
y = as_label((qx))
282304
) +
283-
labs(x = error_avg_label(), y = expression(italic(x))) +
284-
coord_flip()
285-
}
305+
coord_flip()
306+
}
286307

287308

288309
#' @rdname PPC-errors
@@ -414,8 +435,21 @@ error_hist_facets <-
414435
error_label <- function() {
415436
expression(italic(y) - italic(y)[rep])
416437
}
417-
error_avg_label <- function() {
418-
expression(paste("Average ", italic(y) - italic(y)[rep]))
438+
439+
error_avg_label <- function(stat = NULL) {
440+
stat <- as_tagged_function({{ stat }}, fallback = "stat")
441+
e <- attr(stat, "tagged_expr")
442+
if (attr(stat, "is_anonymous_function")) {
443+
e <- sym("stat")
444+
}
445+
de <- deparse1(e)
446+
447+
# create some dummy variables to pass the R package check for
448+
# global variables in the expression below
449+
italic <- sym("italic")
450+
y <- sym("y")
451+
452+
expr(paste((!!de))*(italic(y) - italic(y)[rep]))
419453
}
420454

421455

R/ppc-scatterplots.R

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
#' @template args-group
1212
#' @template args-facet_args
1313
#' @param ... Currently unused.
14+
#' @param stat A function or a string naming a function for computing the
15+
#' posterior average. In both cases, the function should take a vector input
16+
#' and return a scalar statistic. The function name is displayed in the
17+
#' axis-label, and the underlying `$rep_label` for `ppc_scatter_avg_data()`
18+
#' includes the function name. Defaults to `"mean"`.
1419
#' @param size,alpha Arguments passed to [ggplot2::geom_point()] to control the
1520
#' appearance of the points.
1621
#' @param ref_line If `TRUE` (the default) a dashed line with intercept 0 and
@@ -31,10 +36,10 @@
3136
#' }
3237
#' \item{`ppc_scatter_avg()`}{
3338
#' A single scatterplot of `y` against the average values of `yrep`, i.e.,
34-
#' the points `(x,y) = (mean(yrep[, n]), y[n])`, where each `yrep[, n]` is
35-
#' a vector of length equal to the number of posterior draws. Unlike
36-
#' for `ppc_scatter()`, for `ppc_scatter_avg()` `yrep` should contain many
37-
#' draws (rows).
39+
#' the points `(x,y) = (average(yrep[, n]), y[n])`, where each `yrep[, n]` is
40+
#' a vector of length equal to the number of posterior draws and `average()`
41+
#' is a summary statistic. Unlike for `ppc_scatter()`, for
42+
#' `ppc_scatter_avg()` `yrep` should contain many draws (rows).
3843
#' }
3944
#' \item{`ppc_scatter_avg_grouped()`}{
4045
#' The same as `ppc_scatter_avg()`, but a separate plot is generated for
@@ -59,6 +64,9 @@
5964
#' p1 + lims
6065
#' p2 + lims
6166
#'
67+
#' # "average" function is customizable
68+
#' ppc_scatter_avg(y, yrep, stat = "median", ref_line = FALSE)
69+
#'
6270
#' # for ppc_scatter_avg_grouped the default is to allow the facets
6371
#' # to have different x and y axes
6472
#' group <- example_group_data()
@@ -116,16 +124,19 @@ ppc_scatter_avg <-
116124
function(y,
117125
yrep,
118126
...,
127+
stat = "mean",
119128
size = 2.5,
120129
alpha = 0.8,
121130
ref_line = TRUE) {
122131
dots <- list(...)
132+
stat <- as_tagged_function({{ stat }})
133+
123134
if (!from_grouped(dots)) {
124135
check_ignored_arguments(...)
125136
dots$group <- NULL
126137
}
127138

128-
data <- ppc_scatter_avg_data(y, yrep, group = dots$group)
139+
data <- ppc_scatter_avg_data(y, yrep, group = dots$group, stat = stat)
129140
if (is.null(dots$group) && nrow(yrep) == 1) {
130141
inform(
131142
"With only 1 row in 'yrep' ppc_scatter_avg is the same as ppc_scatter."
@@ -143,7 +154,7 @@ ppc_scatter_avg <-
143154
# ppd instead of ppc (see comment in ppc_scatter)
144155
scale_color_ppd() +
145156
scale_fill_ppd() +
146-
labs(x = yrep_avg_label(), y = y_label()) +
157+
labs(x = yrep_avg_label(stat), y = y_label()) +
147158
bayesplot_theme_get()
148159
}
149160

@@ -155,6 +166,7 @@ ppc_scatter_avg_grouped <-
155166
yrep,
156167
group,
157168
...,
169+
stat = "mean",
158170
facet_args = list(),
159171
size = 2.5,
160172
alpha = 0.8,
@@ -184,16 +196,19 @@ ppc_scatter_data <- function(y, yrep) {
184196

185197
#' @rdname PPC-scatterplots
186198
#' @export
187-
ppc_scatter_avg_data <- function(y, yrep, group = NULL) {
199+
ppc_scatter_avg_data <- function(y, yrep, group = NULL, stat = "mean") {
188200
y <- validate_y(y)
189201
yrep <- validate_predictions(yrep, length(y))
190202
if (!is.null(group)) {
191203
group <- validate_group(group, length(y))
192204
}
205+
stat <- as_tagged_function({{ stat }})
193206

194-
data <- ppc_scatter_data(y = y, yrep = t(colMeans(yrep)))
207+
data <- ppc_scatter_data(y = y, yrep = t(apply(yrep, 2, FUN = stat)))
195208
data$rep_id <- NA_integer_
196-
levels(data$rep_label) <- "mean(italic(y)[rep]))"
209+
levels(data$rep_label) <- yrep_avg_label(stat) |>
210+
as.expression() |>
211+
as.character()
197212

198213
if (!is.null(group)) {
199214
data <- tibble::add_column(data,
@@ -206,7 +221,22 @@ ppc_scatter_avg_data <- function(y, yrep, group = NULL) {
206221
}
207222

208223
# internal ----------------------------------------------------------------
209-
yrep_avg_label <- function() expression(paste("Average ", italic(y)[rep]))
224+
225+
yrep_avg_label <- function(stat = NULL) {
226+
stat <- as_tagged_function({{ stat }}, fallback = "stat")
227+
e <- attr(stat, "tagged_expr")
228+
if (attr(stat, "is_anonymous_function")) {
229+
e <- sym("stat")
230+
}
231+
de <- deparse1(e)
232+
233+
# create some dummy variables to pass the R package check for
234+
# global variables in the expression below
235+
italic <- sym("italic")
236+
y <- sym("y")
237+
238+
expr(paste((!!de))*(italic(y)[rep]))
239+
}
210240

211241
scatter_aes <- function(...) {
212242
aes(x = .data$value, y = .data$y_obs, ...)

0 commit comments

Comments
 (0)