show_best tests for censored regression models #156
Merged
Commits
10 commits
19e963e initial tests for #118
64c6f19 return 5 configs
15081f2 missing %>%
18c4bce reduce changes of ties affecting results
c147241 trying to avoid irreproducible results across runs and OS'es
30b3916 Apply suggestions from code review (topepo)
69043ac update for latest tune version
ce40180 Merge branch 'main' into survival-show-best (topepo)
6e0e77d updated snapshots (topepo)
4fdbeeb updated snapshots with new current package versions (topepo)
@@ -0,0 +1,90 @@
# show_best with censored data - integrated metric - grid

    No value of `metric` was given; "brier_survival_integrated" will be used.

# show_best with censored data - dynamic metric - bayes

    No value of `metric` was given; "brier_survival" will be used.

---

    Code
      show_best(bayes_dyn_res, metric = "brier_survival", eval_time = 1)
    Condition
      Error in `show_best()`:
      ! Evaluation time 1 is not in the results.

---

    Code
      show_best(bayes_dyn_res, metric = "brier_survival_integrated")
    Condition
      Error in `show_best()`:
      ! "brier_survival_integrated" was not in the metric set. Please choose from: "brier_survival".

# show_best with censored data - static metric - anova racing

    No value of `metric` was given; "concordance_survival" will be used.

---

    Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric (and will be ignored).

---

    Code
      show_best(race_stc_res, metric = "brier_survival_integrated")
    Condition
      Warning:
      Metric "concordance_survival" was used to evaluate model candidates in the race but "brier_survival_integrated" has been chosen to rank the candidates. These results may not agree with the race.
      Error in `show_best()`:
      ! "brier_survival_integrated" was not in the metric set. Please choose from: "concordance_survival".

# show_best with censored data - static metric (+dyn) - W/L racing

    Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric (and will be ignored).

---

    No value of `metric` was given; "concordance_survival" will be used.

---

    Evaluation times are only required when dynmanic or integrated metrics are selected as the primary metric (and will be ignored).

---

    Code
      show_best(race_stc_res, metric = "brier_survival_integrated")
    Condition
      Warning:
      Metric "concordance_survival" was used to evaluate model candidates in the race but "brier_survival_integrated" has been chosen to rank the candidates. These results may not agree with the race.
      Error in `show_best()`:
      ! "brier_survival_integrated" was not in the metric set. Please choose from: "concordance_survival" and "brier_survival".

# show_best with censored data - dyn metric (+stc) - W/L racing

    No value of `metric` was given; "brier_survival" will be used.

---

    Metric "brier_survival" was used to evaluate model candidates in the race but "concordance_survival" has been chosen to rank the candidates. These results may not agree with the race.

---

    Code
      show_best(race_dyn_res, metric = "brier_survival", eval_time = 1)
    Condition
      Error in `show_best()`:
      ! Evaluation time 1 is not in the results.

---

    Code
      show_best(race_dyn_res, metric = "brier_survival_integrated")
    Condition
      Warning:
      Metric "brier_survival" was used to evaluate model candidates in the race but "brier_survival_integrated" has been chosen to rank the candidates. These results may not agree with the race.
      Error in `show_best()`:
      ! "brier_survival_integrated" was not in the metric set. Please choose from: "brier_survival" and "concordance_survival".
@@ -0,0 +1,33 @@

make_churn_cens_objects <- function(x) {
  suppressPackageStartupMessages(require("tidymodels"))
  suppressPackageStartupMessages(require("censored"))

  data("mlc_churn")

  mlc_churn <-
    mlc_churn %>%
    mutate(
      churned = ifelse(churn == "yes", 1, 0),
      event_time = Surv(account_length, churned)
    ) %>%
    select(event_time, account_length, area_code, total_eve_calls)

  set.seed(6941)
  churn_split <- initial_split(mlc_churn)
  churn_tr <- training(churn_split)
  churn_te <- testing(churn_split)
  churn_rs <- vfold_cv(churn_tr)

  eval_times <- c(50, 100, 150)

  churn_rec <-
    recipe(event_time ~ ., data = churn_tr) %>%
    step_dummy(area_code) %>%
    step_normalize(all_predictors())

  list(split = churn_split, train = churn_tr, test = churn_te,
       rs = churn_rs, times = eval_times, rec = churn_rec)

}
@@ -0,0 +1,269 @@

test_that("show_best with censored data - integrated metric - grid", {

  skip_if_not_installed("parsnip", minimum_version = "1.1.1.9007")
  skip_if_not_installed("tune", minimum_version = "1.1.2.9005")

  obj <- make_churn_cens_objects()

  g_ctrl <- control_grid(save_pred = TRUE)

  tree_spec <-
    decision_tree(cost_complexity = tune(), min_n = 2) %>%
    set_mode("censored regression")

  int_met <- metric_set(brier_survival_integrated)

  set.seed(1)
  grid_int_res <-
    tree_spec %>%
    tune_grid(
      event_time ~ .,
      resamples = obj$rs,
      grid = tibble(cost_complexity = 10^seq(-4, -2, by = .1)),
      control = g_ctrl,
      metrics = int_met,
      eval_time = obj$times
    )

  expect_equal(
    show_best(grid_int_res, metric = "brier_survival_integrated"),
    grid_int_res %>%
      collect_metrics() %>%
      arrange(mean) %>%
      slice_min(mean, n = 5)
  )
  expect_snapshot_warning(
    show_best(grid_int_res)
  )

})


test_that("show_best with censored data - dynamic metric - bayes", {

  skip_if_not_installed("parsnip", minimum_version = "1.1.1.9007")
  skip_if_not_installed("tune", minimum_version = "1.1.2.9005")

  obj <- make_churn_cens_objects()

  tree_spec <-
    decision_tree(cost_complexity = tune(), min_n = 2) %>%
    set_mode("censored regression")

  dyn_met <- metric_set(brier_survival)

  set.seed(611)
  bayes_dyn_res <-
    tree_spec %>%
    tune_bayes(
      event_time ~ .,
      resamples = obj$rs,
      initial = 4,
      iter = 3,
      metrics = dyn_met,
      eval_time = 100
    )

  expect_equal(
    show_best(bayes_dyn_res, metric = "brier_survival", eval_time = 100, n = 2),
    bayes_dyn_res %>%
      collect_metrics() %>%
      arrange(mean) %>%
      slice(1:2)
  )
  expect_snapshot_warning(
    show_best(bayes_dyn_res)
  )
  expect_snapshot(
    show_best(bayes_dyn_res, metric = "brier_survival", eval_time = 1),
    error = TRUE
  )
  expect_snapshot(
    show_best(bayes_dyn_res, metric = "brier_survival_integrated"),
    error = TRUE
  )

})


test_that("show_best with censored data - static metric - anova racing", {

  skip_if_not_installed("parsnip", minimum_version = "1.1.1.9007")
  skip_if_not_installed("tune", minimum_version = "1.1.2.9005")
  skip_if_not_installed("finetune", minimum_version = "1.1.0.9004")

  obj <- make_churn_cens_objects()
  suppressPackageStartupMessages(library("finetune"))

  tree_spec <-
    decision_tree(cost_complexity = tune(), min_n = 2) %>%
    set_mode("censored regression")

  stc_met <- metric_set(concordance_survival)

  set.seed(22)
  race_stc_res <-
    tree_spec %>%
    tune_race_anova(
      event_time ~ .,
      resamples = obj$rs,
      grid = tibble(cost_complexity = 10^c(-1.4, -2.5, -3, -5)),
      metrics = stc_met
    )

  num_rs <- nrow(obj$rs)
  winners <-
    race_stc_res %>%
    collect_metrics(summarize = FALSE) %>%
    count(.config) %>%
    filter(n == num_rs) %>%
    arrange(.config) %>%
    slice(1) %>%
    pluck(".config")

  expect_equal(
    sort(show_best(race_stc_res, metric = "concordance_survival", n = 1)$.config),
    winners
  )
  expect_snapshot_warning(
    show_best(race_stc_res)
  )
  expect_snapshot_warning(
    show_best(race_stc_res, metric = "concordance_survival", eval_time = 1)
  )
  expect_snapshot(
    show_best(race_stc_res, metric = "brier_survival_integrated"),
    error = TRUE
  )

})


test_that("show_best with censored data - static metric (+dyn) - W/L racing", {

  skip_if_not_installed("parsnip", minimum_version = "1.1.1.9007")
  skip_if_not_installed("tune", minimum_version = "1.1.2.9005")
  skip_if_not_installed("finetune", minimum_version = "1.1.0.9004")

  obj <- make_churn_cens_objects()
  suppressPackageStartupMessages(library("finetune"))

  tree_spec <-
    decision_tree(cost_complexity = tune(), min_n = 2) %>%
    set_mode("censored regression")

  tree_param <-
    tree_spec %>%
    extract_parameter_set_dials() %>%
    update(cost_complexity = cost_complexity(c(-5, -1)))

  surv_met <- metric_set(concordance_survival, brier_survival)

  expect_snapshot_warning({
    set.seed(326)
    race_stc_res <-
      tree_spec %>%
      tune_race_win_loss(
        event_time ~ .,
        resamples = obj$rs,
        grid = 10,
        metrics = surv_met,
        eval_time = 100,
        param_info = tree_param
      )
  })

  num_rs <- nrow(obj$rs)
  winners <-
    race_stc_res %>%
    collect_metrics() %>%
    filter(.metric == "concordance_survival" & n == num_rs) %>%
    arrange(desc(mean)) %>%
    slice(1:5) %>%
    pluck(".config")

  expect_equal(
    show_best(race_stc_res, metric = "concordance_survival")$.config,
    winners
  )
  expect_snapshot_warning(
    show_best(race_stc_res)
  )
  expect_snapshot_warning(
    show_best(race_stc_res, metric = "concordance_survival", eval_time = 1)
  )
  expect_snapshot(
    show_best(race_stc_res, metric = "brier_survival_integrated"),
    error = TRUE
  )

})


test_that("show_best with censored data - dyn metric (+stc) - W/L racing", {
  skip_if_not_installed("parsnip", minimum_version = "1.1.1.9007")
  skip_if_not_installed("tune", minimum_version = "1.1.2.9005")
  skip_if_not_installed("finetune", minimum_version = "1.1.0.9004")

  obj <- make_churn_cens_objects()
  suppressPackageStartupMessages(library("finetune"))

  boost_spec <-
    boost_tree(trees = tune()) %>%
    set_engine("mboost") %>%
    set_mode("censored regression")

  tree_spec <-
    decision_tree(cost_complexity = tune(), min_n = 2) %>%
    set_mode("censored regression")

  tree_param <-
    tree_spec %>%
    extract_parameter_set_dials() %>%
    update(cost_complexity = cost_complexity(c(-5, -1)))

  surv_met <- metric_set(brier_survival, concordance_survival)

  set.seed(326)
  race_dyn_res <-
    tree_spec %>%
    tune_race_win_loss(
      event_time ~ .,
      resamples = obj$rs,
      grid = 10,
      metrics = surv_met,
      eval_time = 100,
      param_info = tree_param
    )

  num_rs <- nrow(obj$rs)
  winners <-
    race_dyn_res %>%
    collect_metrics() %>%
    filter(.metric == "brier_survival" & n == num_rs) %>%
    arrange(mean) %>%
    slice(1:5) %>%
    pluck(".config")

  expect_equal(
    show_best(race_dyn_res, metric = "brier_survival")$.config,
    winners
  )
  expect_snapshot_warning(
    show_best(race_dyn_res)
  )
  expect_snapshot_warning(
    show_best(race_dyn_res, metric = "concordance_survival")
  )
  expect_snapshot(
    show_best(race_dyn_res, metric = "brier_survival", eval_time = 1),
    error = TRUE
  )
  expect_snapshot(
    show_best(race_dyn_res, metric = "brier_survival_integrated"),
    error = TRUE
  )

})
Let's just error here instead of warning about the metric and then erroring anyway. I'll open an issue if you agree.
Opened the issue on finetune (tidymodels/finetune#89). This PR is ready to merge!
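
For reference, a minimal sketch of the check ordering suggested in this thread, in base R. The helper `check_rank_metric()` and its arguments are hypothetical and are not finetune's or tune's actual implementation; the message text is only modeled on the snapshots in this PR. The idea is to validate the requested metric first, so an unknown metric produces a single error without a preceding warning.

```r
# Hypothetical helper (an assumption for illustration, not finetune's code):
# check that the metric exists before comparing it with the race metric.
check_rank_metric <- function(metric, metric_set_names, race_metric) {
  # 1. Error immediately if the requested metric was never computed.
  if (!metric %in% metric_set_names) {
    stop(
      sprintf(
        "\"%s\" was not in the metric set. Please choose from: %s.",
        metric,
        paste0("\"", metric_set_names, "\"", collapse = ", ")
      ),
      call. = FALSE
    )
  }
  # 2. Only warn when the metric is valid but differs from the race metric.
  if (!identical(metric, race_metric)) {
    warning(
      sprintf(
        "Metric \"%s\" was used to evaluate model candidates in the race but \"%s\" has been chosen to rank the candidates.",
        race_metric, metric
      ),
      call. = FALSE
    )
  }
  invisible(metric)
}

# Record which condition is signalled first in each case.
first_condition <- function(expr) {
  tryCatch(
    { expr; "none" },
    warning = function(w) "warning",
    error = function(e) "error"
  )
}

# Unknown metric: a single error, no warning beforehand.
res <- first_condition(
  check_rank_metric("brier_survival_integrated", "concordance_survival", "concordance_survival")
)

# Valid metric that differs from the race metric: a warning only.
res2 <- first_condition(
  check_rank_metric("concordance_survival", c("brier_survival", "concordance_survival"), "brier_survival")
)

# Valid metric that matches the race metric: no condition at all.
res3 <- first_condition(
  check_rank_metric("brier_survival", c("brier_survival", "concordance_survival"), "brier_survival")
)
```

With this ordering, the `expect_snapshot(..., error = TRUE)` cases for `"brier_survival_integrated"` on racing results would record only the error, under the assumptions above.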