Feature: collect_race() or collect_metrics() for the race #99

Closed
jrosell opened this issue Jan 18, 2024 · 2 comments


jrosell commented Jan 18, 2024

Feature

When analyzing the intermediate results of a race, one shouldn't need to know the internal data structure of the tune package. Instead, one should be able to call a function such as collect_race() or similar.

Here's an example of what we get with the current functions and what I expect to get.

project_name <- "sliced-s01e09-playoffs-1"
output_dir <- here::here(project_name, "data")
dir.create(file.path(output_dir), showWarnings = FALSE, recursive = TRUE)
kaggler::kgl_competitions_data_download_all(project_name, output_dir = output_dir)
library(tidyverse)
library(tidymodels)
library(finetune)
options(readr.show_col_types = FALSE)
theme_set(theme_light())

train_raw <- read_csv(here::here(output_dir, "train.csv"))

set.seed(123)
bb_split <- train_raw %>%
    mutate(
        is_home_run = if_else(as.logical(is_home_run), "HR", "no"),
        is_home_run = factor(is_home_run)
    ) %>%
    na.omit() %>% 
    sample_n(5000) %>% 
    initial_split(strata = is_home_run)
bb_train <- training(bb_split)
bb_test <- testing(bb_split)
set.seed(234)
bb_folds <- vfold_cv(bb_train, strata = is_home_run, v = 10)

bb_rec <-
    recipe(is_home_run ~ launch_angle + launch_speed + plate_x + plate_z +
               bb_type + bearing + pitch_mph +
               is_pitcher_lefty + is_batter_lefty +
               inning + balls + strikes + game_date,
           data = bb_train
    ) %>%
    step_date(game_date, features = c("week"), keep_original_cols = FALSE) %>%
    step_unknown(all_nominal_predictors()) %>%
    step_dummy(all_nominal_predictors())

xgb_wf <- bb_rec %>% 
    workflow(
        boost_tree(
            mode = "classification",
            trees = tune(),
            min_n = tune(),
            mtry = tune(),
            learn_rate = tune(),
            tree_depth = tune(),
            loss_reduction = tune(),
            sample_size = tune()
        ) %>%
        set_engine("xgboost", counts = FALSE)
    )

set.seed(123)
xgb_grid <- xgb_wf %>% 
    extract_parameter_set_dials() %>% 
    update(
        trees = trees(c(100, 100)),
        min_n = min_n(c(1, 300)),
        mtry = mtry_prop(c(0.1, 0.4)),
        learn_rate = learn_rate(c(0.3, 0.3)),
        tree_depth = tree_depth(c(2, 6)),
        loss_reduction = loss_reduction(c(0, 0), trans = NULL),
        sample_size = sample_prop(c(0.4, 0.8))
    ) %>% 
    grid_max_entropy(size = 10)

cores <- parallelly::availableCores(omit = 15)
if(cores > 1) {
    print(paste("Using", cores, "cores"))
    doParallel::registerDoParallel(cores)
}
#> [1] "Using 5 cores"
set.seed(345)
xgb_rs <- tune_race_anova(
    xgb_wf,
    resamples = bb_folds,
    grid = xgb_grid,
    metrics = metric_set(mn_log_loss),
    control = control_race(verbose_elim = TRUE)
)
#> ℹ Racing will minimize the mn_log_loss metric.
#> ℹ Resamples are analyzed in a random order.
#> ℹ Fold10: 8 eliminated; 2 candidates remain.
#> 
#> ℹ Fold07: All but one parameter combination were eliminated.
if(cores > 1) {
    doParallel::stopImplicitCluster()
}
    
xgb_rs %>% show_best(metric = "mn_log_loss")
#> # A tibble: 1 × 13
#>    mtry trees min_n tree_depth learn_rate loss_reduction sample_size .metric    
#>   <dbl> <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>      
#> 1 0.277   100    47          5       2.00              0       0.632 mn_log_loss
#> # ℹ 5 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>,
#> #   .config <chr>
# # A tibble: 1 × 13
# mtry trees min_n tree_depth learn_rate loss_reduction sample_size .metric     .estimator  mean     n std_err .config              
# <dbl> <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>       <chr>      <dbl> <int>   <dbl> <chr>                
# 1 0.277   100    47          5       2.00              0       0.632 mn_log_loss binary     0.107    10 0.00630 Preprocessor1_Model06

xgb_rs %>% collect_metrics(summarize = FALSE) %>% arrange(.estimate)
#> # A tibble: 10 × 12
#>    id      mtry trees min_n tree_depth learn_rate loss_reduction sample_size
#>    <chr>  <dbl> <int> <int>      <int>      <dbl>          <dbl>       <dbl>
#>  1 Fold07 0.277   100    47          5       2.00              0       0.632
#>  2 Fold04 0.277   100    47          5       2.00              0       0.632
#>  3 Fold01 0.277   100    47          5       2.00              0       0.632
#>  4 Fold08 0.277   100    47          5       2.00              0       0.632
#>  5 Fold10 0.277   100    47          5       2.00              0       0.632
#>  6 Fold09 0.277   100    47          5       2.00              0       0.632
#>  7 Fold05 0.277   100    47          5       2.00              0       0.632
#>  8 Fold02 0.277   100    47          5       2.00              0       0.632
#>  9 Fold06 0.277   100    47          5       2.00              0       0.632
#> 10 Fold03 0.277   100    47          5       2.00              0       0.632
#> # ℹ 4 more variables: .metric <chr>, .estimator <chr>, .estimate <dbl>,
#> #   .config <chr>

xgb_rs %>%
    dplyr::select(id, .order, .metrics) %>%
    tidyr::unnest(cols = .metrics) %>% 
    dplyr::group_by(!!!rlang::syms(attributes(xgb_rs)$parameters$id), .metric, .estimator) %>%
    dplyr::summarize(
        mean = mean(.estimate, na.rm = TRUE),
        n = sum(!is.na(.estimate)),
        std_err = sd(.estimate, na.rm = TRUE) / sqrt(n),
        .groups = "drop"
    ) %>% 
    arrange(mean) %>% 
    print(n = Inf)
#> # A tibble: 10 × 12
#>     mtry trees min_n tree_depth learn_rate loss_reduction sample_size .metric   
#>    <dbl> <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>     
#>  1 0.277   100    47          5       2.00              0       0.632 mn_log_lo…
#>  2 0.310   100    68          4       2.00              0       0.754 mn_log_lo…
#>  3 0.256   100   107          2       2.00              0       0.527 mn_log_lo…
#>  4 0.353   100    63          3       2.00              0       0.540 mn_log_lo…
#>  5 0.343   100    67          5       2.00              0       0.615 mn_log_lo…
#>  6 0.304   100   264          4       2.00              0       0.751 mn_log_lo…
#>  7 0.207   100   158          4       2.00              0       0.418 mn_log_lo…
#>  8 0.115   100   120          4       2.00              0       0.629 mn_log_lo…
#>  9 0.198   100    98          2       2.00              0       0.534 mn_log_lo…
#> 10 0.137   100   209          4       2.00              0       0.725 mn_log_lo…
#> # ℹ 4 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>

Created on 2024-01-18 with reprex v2.1.0.9000

jrosell changed the title from "Feature: collect_race()" to "Feature: collect_race() or collect_metrics() for the race" on Jan 18, 2024

simonpcouch (Contributor) commented

Thanks for the issue, @jrosell!

It seems like the all_configs argument to the tune_race collect_metrics() method might be helpful for you!

library(tidymodels)
library(finetune)
library(discrim)
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness

data(two_class_dat, package = "modeldata")

set.seed(6376)
rs <- bootstraps(two_class_dat, times = 10)

# optimize a regularized discriminant analysis model
rda_spec <-
  discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>%
  set_engine("klaR")

ctrl <- control_race(verbose_elim = TRUE)
set.seed(11)
grid_anova <-
  rda_spec %>%
  tune_race_anova(Class ~ ., resamples = rs, grid = 10, control = ctrl)
#> ℹ Racing will maximize the roc_auc metric.
#> ℹ Resamples are analyzed in a random order.
#> ℹ Bootstrap05: All but one parameter combination were eliminated.

plot_race(grid_anova)

A quick visual of the racing process:

This is reasonably well-reflected in collect_metrics() output:

collect_metrics(grid_anova, all_configs = TRUE)
#> # A tibble: 20 × 8
#>    frac_common_cov frac_identity .metric  .estimator  mean     n std_err .config
#>              <dbl>         <dbl> <chr>    <chr>      <dbl> <int>   <dbl> <chr>  
#>  1          0.0691        0.0437 accuracy binary     0.811    10 0.00578 Prepro…
#>  2          0.0691        0.0437 roc_auc  binary     0.886    10 0.00513 Prepro…
#>  3          0.199         0.595  accuracy binary     0.733     3 0.0139  Prepro…
#>  4          0.199         0.595  roc_auc  binary     0.825     3 0.00535 Prepro…
#>  5          0.962         0.716  accuracy binary     0.719     3 0.0118  Prepro…
#>  6          0.962         0.716  roc_auc  binary     0.814     3 0.00526 Prepro…
#>  7          0.271         0.910  accuracy binary     0.709     3 0.0130  Prepro…
#>  8          0.271         0.910  roc_auc  binary     0.798     3 0.00501 Prepro…
#>  9          0.781         0.666  accuracy binary     0.726     3 0.0126  Prepro…
#> 10          0.781         0.666  roc_auc  binary     0.818     3 0.00529 Prepro…
#> 11          0.481         0.453  accuracy binary     0.751     3 0.0117  Prepro…
#> 12          0.481         0.453  roc_auc  binary     0.839     3 0.00524 Prepro…
#> 13          0.837         0.824  accuracy binary     0.711     3 0.0142  Prepro…
#> 14          0.837         0.824  roc_auc  binary     0.805     3 0.00519 Prepro…
#> 15          0.605         0.385  accuracy binary     0.754     3 0.0117  Prepro…
#> 16          0.605         0.385  roc_auc  binary     0.846     3 0.00522 Prepro…
#> 17          0.555         0.293  accuracy binary     0.774     3 0.0159  Prepro…
#> 18          0.555         0.293  roc_auc  binary     0.856     3 0.00485 Prepro…
#> 19          0.392         0.154  accuracy binary     0.790     3 0.0133  Prepro…
#> 20          0.392         0.154  roc_auc  binary     0.871     3 0.00459 Prepro…

Created on 2024-01-18 with reprex v2.1.0

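Note the n column in particular: the winning configuration was evaluated on all 10 resamples, while the eliminated ones have n = 3. Setting summarize = FALSE gives the performance metrics for each configuration by resample; with all_configs = TRUE, that includes configurations that weren't fully resampled.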
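
To make the role of n concrete, here is a minimal base-R sketch with made-up numbers. The configurations "A" and "B", their estimates, and the helper summarize_config() are hypothetical and are not taken from the race above or from finetune itself. It mirrors the mean / n / std_err summary from the original post and shows how a configuration that was eliminated early ends up with a smaller n.

```r
# Toy per-resample results (made-up numbers, not taken from the race above).
# Config "A" survived all five resamples; "B" was eliminated after three.
per_resample <- data.frame(
  .config   = c(rep("A", 5), rep("B", 3)),
  id        = c(paste0("Fold", 1:5), paste0("Fold", 1:3)),
  .estimate = c(0.10, 0.12, 0.11, 0.09, 0.13, 0.20, 0.22, 0.21)
)

# Hypothetical helper: the same summary the original post computes with dplyr
# (mean, number of non-missing estimates, and standard error).
summarize_config <- function(x) {
  n <- sum(!is.na(x))
  c(mean = mean(x, na.rm = TRUE), n = n, std_err = sd(x, na.rm = TRUE) / sqrt(n))
}

# One row per configuration, sorted so the lowest mean comes first.
by_config <- do.call(
  rbind,
  lapply(split(per_resample$.estimate, per_resample$.config), summarize_config)
)
race_summary <- data.frame(.config = rownames(by_config), by_config, row.names = NULL)
race_summary <- race_summary[order(race_summary$mean), ]
print(race_summary)
```

In this toy data, "A" is summarized over 5 resamples and "B" over only 3, which is the same pattern as the n = 10 versus n = 3 rows in the collect_metrics() output above.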

jrosell commented Jan 19, 2024

My fault. I was checking tune docs instead of finetune docs. https://finetune.tidymodels.org/reference/collect_predictions.html

jrosell closed this as completed Jan 19, 2024