Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
PirateGrunt committed May 22, 2018
1 parent bfcf902 commit ddfbe23
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions R/buildExplainer.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#' xgb.test.data <- xgb.DMatrix(test.data)
#'
#' param <- list(objective = "binary:logistic")
#' xgb.model <- xgboost(param =param, data = xgb.train.data, nrounds=3)
#' xgb.model <- xgboost(param = param, data = xgb.train.data, nrounds=3)
#'
#' col_names = colnames(X)
#'
Expand All @@ -47,20 +47,20 @@
#' trees = xgb.model.dt.tree(col_names, model = xgb.model)
#'
#' #### The XGBoost Explainer
#' explainer = buildExplainer(xgb.model,xgb.train.data, type="binary", base_score = 0.5, n_first_tree = xgb.model$best_ntreelimit - 1)
#' explainer = buildExplainer(xgb.model, xgb.train.data, type="binary", base_score = 0.5)
#' pred.breakdown = explainPredictions(xgb.model, explainer, xgb.test.data)
#'
#' showWaterfall(xgb.model, explainer, xgb.test.data, test.data, 2, type = "binary")
#' showWaterfall(xgb.model, explainer, xgb.test.data, test.data, 8, type = "binary")


buildExplainer = function(xgb.model, trainingData, type = "binary", base_score = 0.5, n_first_tree = NULL){
buildExplainer = function(xgb.model, trainingData, type = "binary", base_score = 0.5, n_first_tree = 1){

col_names = attr(trainingData, ".Dimnames")[[2]]
cat('\nCreating the trees of the xgboost model...')
trees = xgb.model.dt.tree(col_names, model = xgb.model, n_first_tree = n_first_tree)
cat('\nGetting the leaf nodes for the training set observations...')
nodes.train = predict(xgb.model,trainingData,predleaf =TRUE)
nodes.train = predict(xgb.model, trainingData, predleaf = TRUE)

cat('\nBuilding the Explainer...')
cat('\nSTEP 1 of 2')
Expand Down

0 comments on commit ddfbe23

Please sign in to comment.