Skip to content

Commit

Permalink
changes for #2
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jul 16, 2019
1 parent b52009d commit cefae1e
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 8 deletions.
6 changes: 4 additions & 2 deletions R/C5.0.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
c5_bagger <- function(rs, opt, control, extract, ...) {

mod_spec <- make_c5_spec(opt)

iter <- get_iterator(control)

rs <-
rs %>%
dplyr::mutate(model = furrr::future_map2(fit_seed, splits, seed_fit,
.fn = c5_fit, spec = mod_spec))
dplyr::mutate(model = iter(fit_seed, splits, seed_fit, .fn = c5_fit, spec = mod_spec))

rs <- check_for_disaster(rs)

Expand Down
5 changes: 3 additions & 2 deletions R/cart.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ cart_bagger <- function(rs, opt, control, extract, ...) {
is_classif <- is.factor(rs$splits[[1]]$data$.outcome)
mod_spec <- make_cart_spec(is_classif, opt)

iter <- get_iterator(control)

rs <-
rs %>%
dplyr::mutate(model = furrr::future_map2(fit_seed, splits, seed_fit,
.fn = cart_fit, spec = mod_spec))
dplyr::mutate(model = iter(fit_seed, splits, seed_fit, .fn = cart_fit, spec = mod_spec))

rs <- check_for_disaster(rs)

Expand Down
5 changes: 3 additions & 2 deletions R/cubist.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

cubist_bagger <- function(rs, opt, control, extract, ...) {

iter <- get_iterator(control)

rs <-
rs %>%
dplyr::mutate(model = furrr::future_map2(fit_seed, splits, seed_fit,
.fn = cubist_fit, opt = opt))
dplyr::mutate(model = iter(fit_seed, splits, seed_fit, .fn = cubist_fit, opt = opt))

rs <- check_for_disaster(rs)

Expand Down
5 changes: 3 additions & 2 deletions R/mars.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ mars_bagger <- function(rs, opt, control, extract, ...) {
is_classif <- is.factor(rs$splits[[1]]$data$.outcome)
mod_spec <- make_mars_spec(is_classif, opt)

iter <- get_iterator(control)

rs <-
rs %>%
dplyr::mutate(model = furrr::future_map2(fit_seed, splits, seed_fit,
.fn = mars_fit, spec = mod_spec))
dplyr::mutate(model = iter(fit_seed, splits, seed_fit, .fn = mars_fit, spec = mod_spec))

rs <- check_for_disaster(rs)

Expand Down
12 changes: 12 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,15 @@ seed_fit <- function(seed, split, .fn, ...) {
}


# ------------------------------------------------------------------------------

get_iterator <- function(control) {
if (control$allow_parallel) {
iter <- furrr::future_map2
} else {
iter <- purrr::map2
}
iter
}


0 comments on commit cefae1e

Please sign in to comment.