From 091ce4e36d254de6a423b75dcd1967d373a9b64d Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 16 Jul 2019 19:50:42 -0400 Subject: [PATCH] consolidate imports --- DESCRIPTION | 3 ++- R/C5.0.R | 31 +++++++++++++++++-------------- R/aaa.R | 18 ++++++++++++++++-- R/constructor.R | 1 - R/cubist.R | 5 ----- R/predict.R | 1 - 6 files changed, 35 insertions(+), 24 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index b00de79..667577c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -36,7 +36,8 @@ Imports: parsnip, utils, yardstick, - partykit + partykit, + tidypredict URL: https://github.com/topepo/baguette BugReports: https://github.com/topepo/baguette/issues RoxygenNote: 6.1.1 diff --git a/R/C5.0.R b/R/C5.0.R index 66b0d25..2cdd6dc 100644 --- a/R/C5.0.R +++ b/R/C5.0.R @@ -1,9 +1,3 @@ -#' @importFrom C50 C5.0 C5.0Control C5imp as.party.C5.0 -#' @importFrom rsample analysis -#' @importFrom purrr map map2 map_df -#' @importFrom tibble tibble -#' @importFrom parsnip decision_tree -#' @importFrom furrr future_map c5_bagger <- function(rs, opt, control, extract, ...) { @@ -13,7 +7,14 @@ c5_bagger <- function(rs, opt, control, extract, ...) { rs <- rs %>% - dplyr::mutate(model = iter(fit_seed, splits, seed_fit, .fn = c5_fit, spec = mod_spec)) + dplyr::mutate(model = iter( + fit_seed, + splits, + seed_fit, + .fn = c5_fit, + spec = mod_spec, + control = control + )) rs <- check_for_disaster(rs) @@ -74,15 +75,17 @@ make_c5_spec <- function(opt) { c5_spec } -#' @importFrom stats complete.cases -c5_fit <- function(split, spec) { +c5_fit <- function(split, spec, control = bag_control()) { ctrl <- parsnip::fit_control(catch = TRUE) - mod <- - parsnip::fit.model_spec(spec, - .outcome ~ ., - data = rsample::analysis(split), - control = ctrl) + + dat <- rsample::analysis(split) + + if (control$sampling == "down") { + dat <- down_sampler(dat) + } + + mod <- parsnip::fit.model_spec(spec, .outcome ~ ., data = dat, control = ctrl) mod } diff --git a/R/aaa.R b/R/aaa.R index 56ec8a5..2ffc303 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -1,8 +1,21 @@ #' @import rlang #' @import dplyr #' @import hardhat -#' @importFrom parsnip set_engine fit fit_xy fit_control +#' +#' @importFrom parsnip set_engine fit fit_xy fit_control mars decision_tree #' @importFrom utils globalVariables +#' @importFrom earth earth evimp +#' @importFrom rsample analysis bootstraps assessment +#' @importFrom purrr map map2 map_df map_dfr map_lgl +#' @importFrom tibble tibble as_tibble is_tibble +#' @importFrom furrr future_map future_map2 +#' @importFrom stats setNames sd predict complete.cases +#' @importFrom C50 C5.0 C5.0Control C5imp as.party.C5.0 +#' @importFrom rpart rpart +#' @importFrom partykit as.party.rpart +#' @importFrom Cubist cubist cubistControl +#' @importFrom withr with_seed +#' @importFrom tidypredict tidypredict_fit # ------------------------------------------------------------------------------ @@ -21,6 +34,7 @@ utils::globalVariables( ".pred", ".pred_class", "mod", - "value" + "value", + ".outcome" ) ) diff --git a/R/constructor.R b/R/constructor.R index bc3b323..48c35da 100644 --- a/R/constructor.R +++ b/R/constructor.R @@ -1,4 +1,3 @@ -#' @importFrom tibble is_tibble new_bagger <- function(model_df, imp, oob, control, opt, model, blueprint) { diff --git a/R/cubist.R b/R/cubist.R index 5d03a26..4690c90 100644 --- a/R/cubist.R +++ b/R/cubist.R @@ -1,8 +1,3 @@ -#' @importFrom Cubist cubist cubistControl -#' @importFrom rsample analysis -#' @importFrom purrr map_lgl map2 map_df -#' @importFrom tibble as_tibble -#' @importFrom furrr future_map2 cubist_bagger <- function(rs, opt, control, extract, ...) { diff --git a/R/predict.R b/R/predict.R index 78e7ebf..a8940ac 100644 --- a/R/predict.R +++ b/R/predict.R @@ -18,7 +18,6 @@ #' #' predict(mod, two_class_dat[1:6,], type = "class") #' predict(mod, two_class_dat[1:6,], type = "prob") -#' @importFrom purrr map_dfr #' @export predict.bagger <- function(object, new_data, type = NULL, ...) { type <- check_type(object, type)