Skip to content

Commit

Permalink
added code to expose c_api_pred_contrib in the R package (microsoft#1259
Browse files Browse the repository at this point in the history
)

* added code to expose c_api_pred_contrib in the R package

* removed Rprintf

* reverted to previous version of install.libs.R
  • Loading branch information
gravesee authored and guolinke committed Mar 11, 2018
1 parent 48ff86e commit cb9fabd
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,4 @@ lightgbm.model

# duplicate version file
python-package/lightgbm/VERSION.txt
.Rproj.user
3 changes: 3 additions & 0 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ Booster <- R6Class(
#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for
#' logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether predict leaf index instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case.
Expand Down Expand Up @@ -655,6 +656,7 @@ predict.lgb.Booster <- function(object, data,
num_iteration = NULL,
rawscore = FALSE,
predleaf = FALSE,
predcontrib = FALSE,
header = FALSE,
reshape = FALSE, ...) {

Expand All @@ -668,6 +670,7 @@ predict.lgb.Booster <- function(object, data,
num_iteration,
rawscore,
predleaf,
predcontrib,
header,
reshape, ...)
}
Expand Down
8 changes: 7 additions & 1 deletion R-package/R/lgb.Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Predictor <- R6Class(
num_iteration = NULL,
rawscore = FALSE,
predleaf = FALSE,
predcontrib = FALSE,
header = FALSE,
reshape = FALSE) {

Expand All @@ -86,6 +87,7 @@ Predictor <- R6Class(
as.integer(header),
as.integer(rawscore),
as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration),
private$params,
lgb.c_str(tmp_filename))
Expand All @@ -99,6 +101,7 @@ Predictor <- R6Class(

# Not a file, we need to predict from R object
num_row <- nrow(data)

npred <- 0L

# Check number of predictions to do
Expand All @@ -108,6 +111,7 @@ Predictor <- R6Class(
as.integer(num_row),
as.integer(rawscore),
as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration))

# Pre-allocate empty vector
Expand All @@ -123,6 +127,7 @@ Predictor <- R6Class(
as.integer(ncol(data)),
as.integer(rawscore),
as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration),
private$params)

Expand All @@ -142,6 +147,7 @@ Predictor <- R6Class(
nrow(data),
as.integer(rawscore),
as.integer(predleaf),
as.integer(predcontrib),
as.integer(num_iteration),
private$params)

Expand All @@ -165,7 +171,7 @@ Predictor <- R6Class(

# Data reshaping

if (predleaf) {
if (predleaf | predcontrib) {

# Predict leaves only, reshaping is mandatory
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
Expand Down
4 changes: 4 additions & 0 deletions include/LightGBM/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE data_has_header,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename,
Expand All @@ -407,6 +408,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
LGBM_SE num_row,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE out_len,
LGBM_SE call_state);
Expand Down Expand Up @@ -438,6 +440,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE num_row,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
Expand All @@ -464,6 +467,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE ncol,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
Expand Down
17 changes: 12 additions & 5 deletions src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,14 +479,17 @@ LGBM_SE LGBM_BoosterGetPredict_R(LGBM_SE handle,
R_API_END();
}

int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx) {
int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx, LGBM_SE is_predcontrib) {
int pred_type = C_API_PREDICT_NORMAL;
if (R_AS_INT(is_rawscore)) {
pred_type = C_API_PREDICT_RAW_SCORE;
}
if (R_AS_INT(is_leafidx)) {
pred_type = C_API_PREDICT_LEAF_INDEX;
}
if (R_AS_INT(is_predcontrib)) {
pred_type = C_API_PREDICT_CONTRIB;
}
return pred_type;
}

Expand All @@ -495,12 +498,13 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE data_has_header,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename,
LGBM_SE call_state) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
R_CHAR_PTR(result_filename)));
Expand All @@ -511,11 +515,12 @@ LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
LGBM_SE num_row,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE out_len,
LGBM_SE call_state) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
pred_type, R_AS_INT(num_iteration), &len));
Expand All @@ -532,13 +537,14 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE num_row,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
LGBM_SE call_state) {

R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

const int* p_indptr = R_INT_PTR(indptr);
const int* p_indices = R_INT_PTR(indices);
Expand All @@ -562,13 +568,14 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE num_col,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE is_predcontrib,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
LGBM_SE call_state) {

R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

int32_t nrow = R_AS_INT(num_row);
int32_t ncol = R_AS_INT(num_col);
Expand Down

0 comments on commit cb9fabd

Please sign in to comment.