Skip to content

Commit

Permalink
consolidate imports
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jul 16, 2019
1 parent cefae1e commit 091ce4e
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 24 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ Imports:
parsnip,
utils,
yardstick,
partykit
partykit,
tidypredict
URL: https://github.com/topepo/baguette
BugReports: https://github.com/topepo/baguette/issues
RoxygenNote: 6.1.1
31 changes: 17 additions & 14 deletions R/C5.0.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
#' @importFrom C50 C5.0 C5.0Control C5imp as.party.C5.0
#' @importFrom rsample analysis
#' @importFrom purrr map map2 map_df
#' @importFrom tibble tibble
#' @importFrom parsnip decision_tree
#' @importFrom furrr future_map

c5_bagger <- function(rs, opt, control, extract, ...) {

Expand All @@ -13,7 +7,14 @@ c5_bagger <- function(rs, opt, control, extract, ...) {

rs <-
rs %>%
dplyr::mutate(model = iter(fit_seed, splits, seed_fit, .fn = c5_fit, spec = mod_spec))
dplyr::mutate(model = iter(
fit_seed,
splits,
seed_fit,
.fn = c5_fit,
spec = mod_spec,
control = control
))

rs <- check_for_disaster(rs)

Expand Down Expand Up @@ -74,15 +75,17 @@ make_c5_spec <- function(opt) {
c5_spec
}

#' @importFrom stats complete.cases

c5_fit <- function(split, spec) {
c5_fit <- function(split, spec, control = bag_control()) {
ctrl <- parsnip::fit_control(catch = TRUE)
mod <-
parsnip::fit.model_spec(spec,
.outcome ~ .,
data = rsample::analysis(split),
control = ctrl)

dat <- rsample::analysis(split)

if (control$sampling == "down") {
dat <- down_sampler(dat)
}

mod <- parsnip::fit.model_spec(spec, .outcome ~ ., data = dat, control = ctrl)
mod
}

Expand Down
18 changes: 16 additions & 2 deletions R/aaa.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
#' @import rlang
#' @import dplyr
#' @import hardhat
#' @importFrom parsnip set_engine fit fit_xy fit_control
#'
#' @importFrom parsnip set_engine fit fit_xy fit_control mars decision_tree
#' @importFrom utils globalVariables
#' @importFrom earth earth evimp
#' @importFrom rsample analysis bootstraps assessment
#' @importFrom purrr map map2 map_df map_dfr map_lgl
#' @importFrom tibble tibble as_tibble is_tibble
#' @importFrom furrr future_map future_map2
#' @importFrom stats setNames sd predict complete.cases
#' @importFrom C50 C5.0 C5.0Control C5imp as.party.C5.0
#' @importFrom rpart rpart
#' @importFrom partykit as.party.rpart
#' @importFrom Cubist cubist cubistControl
#' @importFrom withr with_seed
#' @importFrom tidypredict tidypredict_fit

# ------------------------------------------------------------------------------

Expand All @@ -21,6 +34,7 @@ utils::globalVariables(
".pred",
".pred_class",
"mod",
"value"
"value",
".outcome"
)
)
1 change: 0 additions & 1 deletion R/constructor.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#' @importFrom tibble is_tibble

new_bagger <- function(model_df, imp, oob, control, opt, model, blueprint) {

Expand Down
5 changes: 0 additions & 5 deletions R/cubist.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
#' @importFrom Cubist cubist cubistControl
#' @importFrom rsample analysis
#' @importFrom purrr map_lgl map2 map_df
#' @importFrom tibble as_tibble
#' @importFrom furrr future_map2

cubist_bagger <- function(rs, opt, control, extract, ...) {

Expand Down
1 change: 0 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#'
#' predict(mod, two_class_dat[1:6,], type = "class")
#' predict(mod, two_class_dat[1:6,], type = "prob")
#' @importFrom purrr map_dfr
#' @export
predict.bagger <- function(object, new_data, type = NULL, ...) {
type <- check_type(object, type)
Expand Down

0 comments on commit 091ce4e

Please sign in to comment.