From 34b673faa48492399fce1a717c9445c8b876eeaa Mon Sep 17 00:00:00 2001 From: "Simon P. Couch" Date: Wed, 28 Feb 2024 10:20:09 -0600 Subject: [PATCH] support `eval_time` in `fit_best()` and `rank_results()` (#144) --- DESCRIPTION | 4 +++- NEWS.md | 4 +++- R/fit_best.R | 35 +++++++++++++++++++++++++------ R/rank_results.R | 17 +++++++++++++-- man/fit_best.workflow_set.Rd | 7 ++++++- man/rank_results.Rd | 7 ++++++- tests/testthat/_snaps/fit_best.md | 4 ++-- 7 files changed, 64 insertions(+), 14 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index be11820..57792c9 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: workflowsets Title: Create a Collection of 'tidymodels' Workflows -Version: 1.0.1.9001 +Version: 1.0.1.9002 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut"), comment = c(ORCID = "0000-0003-2402-136X")), @@ -54,6 +54,8 @@ Suggests: testthat (>= 3.0.0), tidyclust, yardstick (>= 1.3.0) +Remotes: + tidymodels/tune VignetteBuilder: knitr Config/Needs/website: diff --git a/NEWS.md b/NEWS.md index 2c90b9d..9a862e3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,8 @@ # workflowsets (development version) -* Enabled evaluating censored regression models (#139). +* Enabled evaluating censored regression models (#139, #144). The placement of + the new `eval_time` argument to `rank_results()` breaks passing-by-position + for the `select_best` argument. * Added a `collect_notes()` method for workflow sets (#135). * Added methods to improve error messages when workflow sets are mistakenly passed to unsupported functions like `fit()` and `predict()` (#137). diff --git a/R/fit_best.R b/R/fit_best.R index 55e6b6a..ab32815 100644 --- a/R/fit_best.R +++ b/R/fit_best.R @@ -11,6 +11,7 @@ tune::fit_best #' with [workflow_map()]. Note that the workflow set must have been fitted with #' the [control option][option_add] `save_workflow = TRUE`. #' @param metric A character string giving the metric to rank results by. +#' @inheritParams tune::fit_best.tune_results #' @param ... Additional options to pass to #' [tune::fit_best][tune::fit_best.tune_results]. #' @@ -71,18 +72,40 @@ tune::fit_best #' fit_best(chi_features_res_new, metric = "rmse") #' @name fit_best.workflow_set #' @export -fit_best.workflow_set <- function(x, metric = NULL, ...) { +fit_best.workflow_set <- function(x, metric = NULL, eval_time = NULL, ...) { check_string(metric, allow_null = TRUE) + result_1 <- extract_workflow_set_result(x, id = x$wflow_id[[1]]) + met_set <- tune::.get_tune_metrics(result_1) + if (is.null(metric)) { - result_1 <- extract_workflow_set_result(x, id = x$wflow_id[[1]]) - metric <- .get_tune_metric_names(result_1)[1] + metric <- .get_tune_metric_names(result_1)[1] + } else { + tune::check_metric_in_tune_results(tibble::as_tibble(met_set), metric) + } + + if (is.null(eval_time) & is_dyn(met_set, metric)) { + eval_time <- tune::.get_tune_eval_times(result_1)[1] } - rankings <- rank_results(x, rank_metric = metric, select_best = TRUE) + rankings <- + rank_results( + x, + rank_metric = metric, + select_best = TRUE, + eval_time = eval_time + ) tune_res <- extract_workflow_set_result(x, id = rankings$wflow_id[1]) - best_params <- select_best(tune_res, metric = metric) + best_params <- select_best(tune_res, metric = metric, eval_time = eval_time) + + fit_best(tune_res, parameters = best_params, ...) +} - fit_best(tune_res, metric = metric, parameters = best_params, ...) +# from unexported +# https://github.com/tidymodels/tune/blob/5b0e10fac559f18c075eb4bd7020e217c6174e66/R/metric-selection.R#L137-L141 +is_dyn <- function(mtr_set, metric) { + mtr_info <- tibble::as_tibble(mtr_set) + mtr_cls <- mtr_info$class[mtr_info$metric == metric] + mtr_cls == "dynamic_survival_metric" } diff --git a/R/rank_results.R b/R/rank_results.R index e7228ec..a7912c4 100644 --- a/R/rank_results.R +++ b/R/rank_results.R @@ -4,6 +4,7 @@ #' #' @inheritParams collect_metrics.workflow_set #' @param rank_metric A character string for a metric. +#' @inheritParams tune::fit_best.tune_results #' @param select_best A logical giving whether the results should only contain #' the numerically best submodel per workflow. #' @details @@ -26,10 +27,17 @@ #' rank_results(chi_features_res, select_best = TRUE) #' rank_results(chi_features_res, rank_metric = "rsq") #' @export -rank_results <- function(x, rank_metric = NULL, select_best = FALSE) { +rank_results <- function(x, rank_metric = NULL, eval_time = NULL, select_best = FALSE) { check_wf_set(x) check_string(rank_metric, allow_null = TRUE) check_bool(select_best) + result_1 <- extract_workflow_set_result(x, id = x$wflow_id[[1]]) + met_set <- tune::.get_tune_metrics(result_1) + if (!is.null(rank_metric)) { + tune::check_metric_in_tune_results(tibble::as_tibble(met_set), rank_metric) + } + + eval_time <- tune::choose_eval_time(result_1, rank_metric, eval_time) metric_info <- pick_metric(x, rank_metric) metric <- metric_info$metric @@ -37,10 +45,15 @@ rank_results <- function(x, rank_metric = NULL, select_best = FALSE) { wflow_info <- dplyr::bind_cols(purrr::map_dfr(x$info, I), dplyr::select(x, wflow_id)) results <- collect_metrics(x) %>% - dplyr::select(wflow_id, .config, .metric, mean, std_err, n) %>% + dplyr::select(wflow_id, .config, .metric, mean, std_err, n, + dplyr::any_of(".eval_time")) %>% dplyr::full_join(wflow_info, by = "wflow_id") %>% dplyr::select(-comment, -workflow) + if (".eval_time" %in% names(results)) { + results <- results[results$.eval_time == eval_time, ] + } + types <- x %>% dplyr::full_join(wflow_info, by = "wflow_id") %>% dplyr::mutate( diff --git a/man/fit_best.workflow_set.Rd b/man/fit_best.workflow_set.Rd index 7b39143..7a16cc1 100644 --- a/man/fit_best.workflow_set.Rd +++ b/man/fit_best.workflow_set.Rd @@ -4,7 +4,7 @@ \alias{fit_best.workflow_set} \title{Fit a model to the numerically optimal configuration} \usage{ -\method{fit_best}{workflow_set}(x, metric = NULL, ...) +\method{fit_best}{workflow_set}(x, metric = NULL, eval_time = NULL, ...) } \arguments{ \item{x}{A \code{\link[=workflow_set]{workflow_set}} object that has been evaluated @@ -13,6 +13,11 @@ the \link[=option_add]{control option} \code{save_workflow = TRUE}.} \item{metric}{A character string giving the metric to rank results by.} +\item{eval_time}{A single numeric time point where dynamic event time +metrics should be chosen (e.g., the time-dependent ROC curve, etc). The +values should be consistent with the values used to create \code{x}. The \code{NULL} +default will automatically use the first evaluation time used by \code{x}.} + \item{...}{Additional options to pass to \link[tune:fit_best]{tune::fit_best}.} } diff --git a/man/rank_results.Rd b/man/rank_results.Rd index 0416656..1631f9b 100644 --- a/man/rank_results.Rd +++ b/man/rank_results.Rd @@ -4,7 +4,7 @@ \alias{rank_results} \title{Rank the results by a metric} \usage{ -rank_results(x, rank_metric = NULL, select_best = FALSE) +rank_results(x, rank_metric = NULL, eval_time = NULL, select_best = FALSE) } \arguments{ \item{x}{A \code{\link[=workflow_set]{workflow_set}} object that has been evaluated @@ -12,6 +12,11 @@ with \code{\link[=workflow_map]{workflow_map()}}.} \item{rank_metric}{A character string for a metric.} +\item{eval_time}{A single numeric time point where dynamic event time +metrics should be chosen (e.g., the time-dependent ROC curve, etc). The +values should be consistent with the values used to create \code{x}. The \code{NULL} +default will automatically use the first evaluation time used by \code{x}.} + \item{select_best}{A logical giving whether the results should only contain the numerically best submodel per workflow.} } diff --git a/tests/testthat/_snaps/fit_best.md b/tests/testthat/_snaps/fit_best.md index 948da08..aad52ef 100644 --- a/tests/testthat/_snaps/fit_best.md +++ b/tests/testthat/_snaps/fit_best.md @@ -11,8 +11,8 @@ Code fit_best(chi_features_map, metric = "boop") Condition - Error in `halt()`: - ! Metric 'boop' was not in the results. + Error in `fit_best()`: + ! "boop" was not in the metric set. Please choose from: "rmse" and "rsq". ---