diff --git a/R/buildExplainer.R b/R/buildExplainer.R index 7c086bb..ccfdc77 100644 --- a/R/buildExplainer.R +++ b/R/buildExplainer.R @@ -58,7 +58,7 @@ buildExplainer = function(xgb.model, trainingData, type = "binary", base_score = col_names = attr(trainingData, ".Dimnames")[[2]] cat('\nCreating the trees of the xgboost model...') - trees = xgb.model.dt.tree(col_names, model = xgb.model, n_first_tree = n_first_tree) + trees = xgb.model.dt.tree(col_names, model = xgb.model, trees = c(0:n_first_tree)) cat('\nGetting the leaf nodes for the training set observations...') nodes.train = predict(xgb.model,trainingData,predleaf =TRUE) diff --git a/R/explainPredictions.R b/R/explainPredictions.R index a79800e..2145008 100644 --- a/R/explainPredictions.R +++ b/R/explainPredictions.R @@ -51,7 +51,7 @@ explainPredictions = function(xgb.model, explainer ,data){ #Accepts data table of the breakdown for each leaf of each tree and the node matrix #Returns the breakdown for each prediction as a data table - nodes = predict(xgb.model,data,predleaf =TRUE) + nodes = predict(xgb.model,data,predleaf =TRUE, ntreelimit = max(explainer$tree) + 1) colnames = names(explainer)[1:(ncol(explainer)-2)]