Skip to content

Commit

Permalink
support eval_time in fit_best() and rank_results() (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Feb 28, 2024
1 parent 8013f0a commit 34b673f
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 14 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: workflowsets
Title: Create a Collection of 'tidymodels' Workflows
Version: 1.0.1.9001
Version: 1.0.1.9002
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut"),
comment = c(ORCID = "0000-0003-2402-136X")),
Expand Down Expand Up @@ -54,6 +54,8 @@ Suggests:
testthat (>= 3.0.0),
tidyclust,
yardstick (>= 1.3.0)
Remotes:
tidymodels/tune
VignetteBuilder:
knitr
Config/Needs/website:
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# workflowsets (development version)

* Enabled evaluating censored regression models (#139).
* Enabled evaluating censored regression models (#139, #144). The placement of
the new `eval_time` argument to `rank_results()` breaks passing-by-position
for the `select_best` argument.
* Added a `collect_notes()` method for workflow sets (#135).
* Added methods to improve error messages when workflow sets are mistakenly
passed to unsupported functions like `fit()` and `predict()` (#137).
Expand Down
35 changes: 29 additions & 6 deletions R/fit_best.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tune::fit_best
#' with [workflow_map()]. Note that the workflow set must have been fitted with
#' the [control option][option_add] `save_workflow = TRUE`.
#' @param metric A character string giving the metric to rank results by.
#' @inheritParams tune::fit_best.tune_results
#' @param ... Additional options to pass to
#' [tune::fit_best][tune::fit_best.tune_results].
#'
Expand Down Expand Up @@ -71,18 +72,40 @@ tune::fit_best
#' fit_best(chi_features_res_new, metric = "rmse")
#' @name fit_best.workflow_set
#' @export
fit_best.workflow_set <- function(x, metric = NULL, ...) {
fit_best.workflow_set <- function(x, metric = NULL, eval_time = NULL, ...) {
check_string(metric, allow_null = TRUE)
result_1 <- extract_workflow_set_result(x, id = x$wflow_id[[1]])
met_set <- tune::.get_tune_metrics(result_1)

if (is.null(metric)) {
result_1 <- extract_workflow_set_result(x, id = x$wflow_id[[1]])
metric <- .get_tune_metric_names(result_1)[1]
metric <- .get_tune_metric_names(result_1)[1]
} else {
tune::check_metric_in_tune_results(tibble::as_tibble(met_set), metric)
}

if (is.null(eval_time) & is_dyn(met_set, metric)) {
eval_time <- tune::.get_tune_eval_times(result_1)[1]
}

rankings <- rank_results(x, rank_metric = metric, select_best = TRUE)
rankings <-
rank_results(
x,
rank_metric = metric,
select_best = TRUE,
eval_time = eval_time
)

tune_res <- extract_workflow_set_result(x, id = rankings$wflow_id[1])

best_params <- select_best(tune_res, metric = metric)
best_params <- select_best(tune_res, metric = metric, eval_time = eval_time)

fit_best(tune_res, parameters = best_params, ...)
}

fit_best(tune_res, metric = metric, parameters = best_params, ...)
# from unexported
# https://github.com/tidymodels/tune/blob/5b0e10fac559f18c075eb4bd7020e217c6174e66/R/metric-selection.R#L137-L141
is_dyn <- function(mtr_set, metric) {
mtr_info <- tibble::as_tibble(mtr_set)
mtr_cls <- mtr_info$class[mtr_info$metric == metric]
mtr_cls == "dynamic_survival_metric"
}
17 changes: 15 additions & 2 deletions R/rank_results.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#'
#' @inheritParams collect_metrics.workflow_set
#' @param rank_metric A character string for a metric.
#' @inheritParams tune::fit_best.tune_results
#' @param select_best A logical giving whether the results should only contain
#' the numerically best submodel per workflow.
#' @details
Expand All @@ -26,21 +27,33 @@
#' rank_results(chi_features_res, select_best = TRUE)
#' rank_results(chi_features_res, rank_metric = "rsq")
#' @export
rank_results <- function(x, rank_metric = NULL, select_best = FALSE) {
rank_results <- function(x, rank_metric = NULL, eval_time = NULL, select_best = FALSE) {
check_wf_set(x)
check_string(rank_metric, allow_null = TRUE)
check_bool(select_best)
result_1 <- extract_workflow_set_result(x, id = x$wflow_id[[1]])
met_set <- tune::.get_tune_metrics(result_1)
if (!is.null(rank_metric)) {
tune::check_metric_in_tune_results(tibble::as_tibble(met_set), rank_metric)
}

eval_time <- tune::choose_eval_time(result_1, rank_metric, eval_time)

metric_info <- pick_metric(x, rank_metric)
metric <- metric_info$metric
direction <- metric_info$direction
wflow_info <- dplyr::bind_cols(purrr::map_dfr(x$info, I), dplyr::select(x, wflow_id))

results <- collect_metrics(x) %>%
dplyr::select(wflow_id, .config, .metric, mean, std_err, n) %>%
dplyr::select(wflow_id, .config, .metric, mean, std_err, n,
dplyr::any_of(".eval_time")) %>%
dplyr::full_join(wflow_info, by = "wflow_id") %>%
dplyr::select(-comment, -workflow)

if (".eval_time" %in% names(results)) {
results <- results[results$.eval_time == eval_time, ]
}

types <- x %>%
dplyr::full_join(wflow_info, by = "wflow_id") %>%
dplyr::mutate(
Expand Down
7 changes: 6 additions & 1 deletion man/fit_best.workflow_set.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion man/rank_results.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/_snaps/fit_best.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
Code
fit_best(chi_features_map, metric = "boop")
Condition
Error in `halt()`:
! Metric 'boop' was not in the results.
Error in `fit_best()`:
! "boop" was not in the metric set. Please choose from: "rmse" and "rsq".

---

Expand Down

0 comments on commit 34b673f

Please sign in to comment.