
Commit 712e39d

Cleanup CPU predict function. (#11139)
1 parent b4a7cd1 commit 712e39d

19 files changed: +324, -465 lines changed

include/xgboost/gbm.h

Lines changed: 5 additions & 22 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2014-2023 by XGBoost Contributors
+ * Copyright 2014-2025, XGBoost Contributors
  * \file gbm.h
  * \brief Interface of gradient booster,
  * that learns through gradient statistics.
@@ -15,10 +15,8 @@
 #include <xgboost/model.h>

 #include <vector>
-#include <utility>
 #include <string>
 #include <functional>
-#include <unordered_map>
 #include <memory>

 namespace xgboost {
@@ -42,13 +40,13 @@ class GradientBooster : public Model, public Configurable {
  public:
   /*! \brief virtual destructor */
   ~GradientBooster() override = default;
-  /*!
-   * \brief Set the configuration of gradient boosting.
+  /**
+   * @brief Set the configuration of gradient boosting.
    * User must call configure once before InitModel and Training.
    *
-   * \param cfg configurations on both training and model parameters.
+   * @param cfg configurations on both training and model parameters.
    */
-  virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0;
+  virtual void Configure(Args const& cfg) = 0;
   /*!
    * \brief load model from stream
    * \param fi input stream.
@@ -117,21 +115,6 @@ class GradientBooster : public Model, public Configurable {
                           bst_layer_t) const {
     LOG(FATAL) << "Inplace predict is not supported by the current booster.";
   }
-  /*!
-   * \brief online prediction function, predict score for one instance at a time
-   * NOTE: use the batch prediction interface if possible, batch prediction is usually
-   * more efficient than online prediction
-   * This function is NOT threadsafe, make sure you only call from one thread
-   *
-   * \param inst the instance you want to predict
-   * \param out_preds output vector to hold the predictions
-   * \param layer_begin Beginning of boosted tree layer used for prediction.
-   * \param layer_end End of booster layer. 0 means do not limit trees.
-   * \sa Predict
-   */
-  virtual void PredictInstance(const SparsePage::Inst& inst,
-                               std::vector<bst_float>* out_preds,
-                               unsigned layer_begin, unsigned layer_end) = 0;
   /*!
    * \brief predict the leaf index of each tree, the output will be nsample * ntree vector
    * this is only valid in gbtree predictor
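
The interface change of note here is that Configure now takes XGBoost's Args alias instead of spelling out the vector-of-pairs type, and the online PredictInstance hook is gone from the booster interface. The sketch below illustrates the Args-based configuration shape using only the standard library; the Args alias is written out locally under the assumption that it matches the one in xgboost/base.h, and DemoBooster is a hypothetical stand-in, not the real GradientBooster.

#include <string>
#include <utility>
#include <vector>

// Assumption: XGBoost's Args alias boils down to a vector of key/value string pairs.
using Args = std::vector<std::pair<std::string, std::string>>;

// Hypothetical stand-in showing the new Configure(Args const&) shape.
struct DemoBooster {
  void Configure(Args const& cfg) {
    for (auto const& [key, value] : cfg) {
      // A real booster would dispatch each key to the matching parameter here.
      (void)key;
      (void)value;
    }
  }
};

int main() {
  DemoBooster booster;
  booster.Configure({{"booster", "gbtree"}, {"tree_method", "hist"}});
  return 0;
}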

include/xgboost/predictor.h

Lines changed: 19 additions & 39 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2024, XGBoost Contributors
+ * Copyright 2017-2025, XGBoost Contributors
  * \file predictor.h
  * \brief Interface of predictor,
  * performs predictions for a gradient booster.
@@ -28,7 +28,7 @@ namespace xgboost {
  */
 struct PredictionCacheEntry {
   // A storage for caching prediction values
-  HostDeviceVector<bst_float> predictions;
+  HostDeviceVector<float> predictions;
   // The version of current cache, corresponding number of layers of trees
   std::uint32_t version{0};

@@ -91,7 +91,7 @@ class Predictor {
   * \param out_predt Prediction vector to be initialized.
   * \param model Tree model used for prediction.
   */
-  virtual void InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_predt,
+  virtual void InitOutPredictions(const MetaInfo& info, HostDeviceVector<float>* out_predt,
                                   const gbm::GBTreeModel& model) const;

   /**
@@ -105,8 +105,8 @@ class Predictor {
   * \param tree_end The tree end index.
   */
   virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
-                            const gbm::GBTreeModel& model, uint32_t tree_begin,
-                            uint32_t tree_end = 0) const = 0;
+                            gbm::GBTreeModel const& model, bst_tree_t tree_begin,
+                            bst_tree_t tree_end = 0) const = 0;

   /**
    * \brief Inplace prediction.
@@ -123,25 +123,7 @@ class Predictor {
   */
   virtual bool InplacePredict(std::shared_ptr<DMatrix> p_fmat, const gbm::GBTreeModel& model,
                               float missing, PredictionCacheEntry* out_preds,
-                              uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
-  /**
-   * \brief online prediction function, predict score for one instance at a time
-   * NOTE: use the batch prediction interface if possible, batch prediction is
-   * usually more efficient than online prediction This function is NOT
-   * threadsafe, make sure you only call from one thread.
-   *
-   * \param inst The instance to predict.
-   * \param [in,out] out_preds The output preds.
-   * \param model The model to predict from
-   * \param tree_end (Optional) The tree end index.
-   * \param is_column_split (Optional) If the data is split column-wise.
-   */
-
-  virtual void PredictInstance(const SparsePage::Inst& inst,
-                               std::vector<bst_float>* out_preds,
-                               const gbm::GBTreeModel& model,
-                               unsigned tree_end = 0,
-                               bool is_column_split = false) const = 0;
+                              bst_tree_t tree_begin = 0, bst_tree_t tree_end = 0) const = 0;

   /**
    * \brief predict the leaf index of each tree, the output will be nsample *
@@ -153,9 +135,8 @@ class Predictor {
   * \param tree_end (Optional) The tree end index.
   */

-  virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
-                           const gbm::GBTreeModel& model,
-                           unsigned tree_end = 0) const = 0;
+  virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<float>* out_preds,
+                           gbm::GBTreeModel const& model, bst_tree_t tree_end = 0) const = 0;

   /**
    * \brief feature contributions to individual predictions; the output will be
@@ -172,18 +153,17 @@ class Predictor {
   * \param condition_feature Feature to condition on (i.e. fix) during calculations.
   */

-  virtual void
-  PredictContribution(DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
-                      const gbm::GBTreeModel &model, unsigned tree_end = 0,
-                      std::vector<bst_float> const *tree_weights = nullptr,
-                      bool approximate = false, int condition = 0,
-                      unsigned condition_feature = 0) const = 0;
-
-  virtual void PredictInteractionContributions(
-      DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
-      const gbm::GBTreeModel &model, unsigned tree_end = 0,
-      std::vector<bst_float> const *tree_weights = nullptr,
-      bool approximate = false) const = 0;
+  virtual void PredictContribution(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
+                                   gbm::GBTreeModel const& model, bst_tree_t tree_end = 0,
+                                   std::vector<float> const* tree_weights = nullptr,
+                                   bool approximate = false, int condition = 0,
+                                   unsigned condition_feature = 0) const = 0;
+
+  virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
+                                               gbm::GBTreeModel const& model,
+                                               bst_tree_t tree_end = 0,
+                                               std::vector<float> const* tree_weights = nullptr,
+                                               bool approximate = false) const = 0;

   /**
    * \brief Creates a new Predictor*.
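
With PredictInstance removed from the predictor interface, single-instance scoring goes through the batch path as a one-row batch, and tree ranges are now typed bst_tree_t with the half-open [tree_begin, tree_end) convention where tree_end = 0 means all trees. The standard-library sketch below shows that range convention; bst_tree_t is assumed to be a 32-bit signed index, and DemoModel/PredictBatchRange are illustrative names, not XGBoost's API.

#include <cstdint>
#include <vector>

// Assumption: bst_tree_t is a 32-bit signed tree index.
using bst_tree_t = std::int32_t;

// Hypothetical per-tree scores for one batch of rows; model internals are elided.
struct DemoModel {
  std::vector<std::vector<float>> tree_scores;  // [tree][row]
};

// Accumulate scores over the half-open range [tree_begin, tree_end); 0 means "all trees".
std::vector<float> PredictBatchRange(DemoModel const& model, std::size_t n_rows,
                                     bst_tree_t tree_begin, bst_tree_t tree_end) {
  if (tree_end == 0) {
    tree_end = static_cast<bst_tree_t>(model.tree_scores.size());
  }
  std::vector<float> out(n_rows, 0.0f);
  for (bst_tree_t t = tree_begin; t < tree_end; ++t) {
    for (std::size_t r = 0; r < n_rows; ++r) {
      out[r] += model.tree_scores[t][r];
    }
  }
  return out;
}

int main() {
  DemoModel model{{{0.5f, 1.0f}, {0.25f, -0.5f}}};  // 2 trees, 2 rows
  auto full = PredictBatchRange(model, 2, 0, 0);     // all trees
  auto first = PredictBatchRange(model, 2, 0, 1);    // first tree only
  return full[0] == 0.75f && first[1] == 1.0f ? 0 : 1;
}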

include/xgboost/tree_model.h

Lines changed: 23 additions & 36 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2014-2024, XGBoost Contributors
+ * Copyright 2014-2025, XGBoost Contributors
  * \file tree_model.h
  * \brief model structure for tree
  * \author Tianqi Chen
@@ -23,7 +23,6 @@
 #include <memory>  // for make_unique
 #include <stack>
 #include <string>
-#include <tuple>
 #include <vector>

 namespace xgboost {
@@ -562,7 +561,7 @@ class RegTree : public Model {
   * \brief fill the vector with sparse vector
   * \param inst The sparse instance to fill.
   */
-  void Fill(const SparsePage::Inst& inst);
+  void Fill(SparsePage::Inst const& inst);

   /*!
    * \brief drop the trace after fill, must be called after fill.
@@ -587,18 +586,17 @@ class RegTree : public Model {
   */
   [[nodiscard]] bool IsMissing(size_t i) const;
   [[nodiscard]] bool HasMissing() const;
+  void HasMissing(bool has_missing) { this->has_missing_ = has_missing; }

+  [[nodiscard]] common::Span<float> Data() { return data_; }

  private:
-  /*!
-   * \brief a union value of value and flag
-   * when flag == -1, this indicate the value is missing
+  /**
+   * @brief A dense vector for a single sample.
+   *
+   * It's nan if the value is missing.
    */
-  union Entry {
-    bst_float fvalue;
-    int flag;
-  };
-  std::vector<Entry> data_;
+  std::vector<float> data_;
   bool has_missing_;
 };

@@ -793,46 +791,35 @@ class RegTree : public Model {
 };

 inline void RegTree::FVec::Init(size_t size) {
-  Entry e; e.flag = -1;
   data_.resize(size);
-  std::fill(data_.begin(), data_.end(), e);
+  std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
   has_missing_ = true;
 }

-inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
-  size_t feature_count = 0;
-  for (auto const& entry : inst) {
-    if (entry.index >= data_.size()) {
-      continue;
-    }
-    data_[entry.index].fvalue = entry.fvalue;
-    ++feature_count;
+inline void RegTree::FVec::Fill(SparsePage::Inst const& inst) {
+  auto p_data = inst.data();
+  auto p_out = data_.data();
+
+  for (std::size_t i = 0, n = inst.size(); i < n; ++i) {
+    auto const& entry = p_data[i];
+    p_out[entry.index] = entry.fvalue;
   }
-  has_missing_ = data_.size() != feature_count;
+  has_missing_ = data_.size() != inst.size();
 }

-inline void RegTree::FVec::Drop() {
-  Entry e{};
-  e.flag = -1;
-  std::fill_n(data_.data(), data_.size(), e);
-  has_missing_ = true;
-}
+inline void RegTree::FVec::Drop() { this->Init(this->Size()); }

 inline size_t RegTree::FVec::Size() const {
   return data_.size();
 }

-inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
-  return data_[i].fvalue;
+inline float RegTree::FVec::GetFvalue(size_t i) const {
+  return data_[i];
 }

-inline bool RegTree::FVec::IsMissing(size_t i) const {
-  return data_[i].flag == -1;
-}
+inline bool RegTree::FVec::IsMissing(size_t i) const { return std::isnan(data_[i]); }

-inline bool RegTree::FVec::HasMissing() const {
-  return has_missing_;
-}
+inline bool RegTree::FVec::HasMissing() const { return has_missing_; }

 // Multi-target tree not yet implemented error
 inline StringView MTNotImplemented() {
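
The FVec rewrite replaces the value/flag union with a plain float buffer where NaN marks a missing feature: IsMissing becomes std::isnan, Drop simply re-runs Init, and Fill writes through raw pointers without the old out-of-range guard, so it now assumes every entry index fits the vector. A minimal standard-library sketch of the same scheme follows; DemoFVec and SparseEntry are illustrative names, not XGBoost's types.

#include <cmath>    // std::isnan
#include <cstddef>
#include <limits>   // std::numeric_limits<float>::quiet_NaN
#include <vector>

// Illustrative sparse entry: a feature index paired with its value.
struct SparseEntry {
  std::size_t index;
  float fvalue;
};

class DemoFVec {
 public:
  void Init(std::size_t size) {
    // Every slot starts out as "missing".
    data_.assign(size, std::numeric_limits<float>::quiet_NaN());
    has_missing_ = true;
  }
  void Fill(std::vector<SparseEntry> const& inst) {
    for (auto const& e : inst) {
      data_[e.index] = e.fvalue;  // assumes e.index < data_.size(), as the new Fill does
    }
    // Dense only when the sparse row covers every feature slot.
    has_missing_ = data_.size() != inst.size();
  }
  void Drop() { Init(data_.size()); }  // reset all slots back to NaN, mirroring the new Drop
  bool IsMissing(std::size_t i) const { return std::isnan(data_[i]); }
  float GetFvalue(std::size_t i) const { return data_[i]; }

 private:
  std::vector<float> data_;
  bool has_missing_{true};
};

int main() {
  DemoFVec fvec;
  fvec.Init(4);
  fvec.Fill({{0, 1.5f}, {2, -3.0f}});  // features 1 and 3 stay NaN (missing)
  return fvec.IsMissing(1) && !fvec.IsMissing(2) ? 0 : 1;
}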

plugin/sycl/predictor/predictor.cc

Lines changed: 7 additions & 15 deletions
@@ -201,8 +201,8 @@ class Predictor : public xgboost::Predictor {
   }

   void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
-                    const gbm::GBTreeModel &model, uint32_t tree_begin,
-                    uint32_t tree_end = 0) const override {
+                    const gbm::GBTreeModel &model, bst_tree_t tree_begin,
+                    bst_tree_t tree_end = 0) const override {
     auto* out_preds = &predts->predictions;
     out_preds->SetDevice(ctx_->Device());
     if (tree_end == 0) {
@@ -221,28 +221,20 @@ class Predictor : public xgboost::Predictor {

   bool InplacePredict(std::shared_ptr<DMatrix> p_m,
                       const gbm::GBTreeModel &model, float missing,
-                      PredictionCacheEntry *out_preds, uint32_t tree_begin,
-                      unsigned tree_end) const override {
+                      PredictionCacheEntry *out_preds, bst_tree_t tree_begin,
+                      bst_tree_t tree_end) const override {
     LOG(WARNING) << "InplacePredict is not yet implemented for SYCL. CPU Predictor is used.";
     return cpu_predictor->InplacePredict(p_m, model, missing, out_preds, tree_begin, tree_end);
   }

-  void PredictInstance(const SparsePage::Inst& inst,
-                       std::vector<bst_float>* out_preds,
-                       const gbm::GBTreeModel& model, unsigned ntree_limit,
-                       bool is_column_split) const override {
-    LOG(WARNING) << "PredictInstance is not yet implemented for SYCL. CPU Predictor is used.";
-    cpu_predictor->PredictInstance(inst, out_preds, model, ntree_limit, is_column_split);
-  }
-
   void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
-                   const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
+                   const gbm::GBTreeModel& model, bst_tree_t ntree_limit) const override {
     LOG(WARNING) << "PredictLeaf is not yet implemented for SYCL. CPU Predictor is used.";
     cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit);
   }

   void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
-                           const gbm::GBTreeModel& model, uint32_t ntree_limit,
+                           const gbm::GBTreeModel& model, bst_tree_t ntree_limit,
                            const std::vector<bst_float>* tree_weights,
                            bool approximate, int condition,
                            unsigned condition_feature) const override {
@@ -252,7 +244,7 @@ class Predictor : public xgboost::Predictor {
   }

   void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
-                                       const gbm::GBTreeModel& model, unsigned ntree_limit,
+                                       const gbm::GBTreeModel& model, bst_tree_t ntree_limit,
                                        const std::vector<bst_float>* tree_weights,
                                        bool approximate) const override {
     LOG(WARNING) << "PredictInteractionContributions is not yet implemented for SYCL. "

src/common/column_matrix.h

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2024, XGBoost Contributors
+ * Copyright 2017-2025, XGBoost Contributors
  * \file column_matrix.h
  * \brief Utility for fast column-wise access
  * \author Philip Cho
@@ -45,15 +45,15 @@ class Column {
   virtual ~Column() = default;

   [[nodiscard]] bst_bin_t GetGlobalBinIdx(size_t idx) const {
-    return index_base_ + static_cast<bst_bin_t>(index_[idx]);
+    return index_base_ + static_cast<bst_bin_t>(index_.data()[idx]);
   }

   /* returns number of elements in column */
   [[nodiscard]] size_t Size() const { return index_.size(); }

  private:
   /* bin indexes in range [0, max_bins - 1] */
-  common::Span<const BinIdxType> index_;
+  common::Span<BinIdxType const> index_;
   /* bin index offset for specific feature */
   bst_bin_t const index_base_;
 };
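
Switching GetGlobalBinIdx from index_[idx] to index_.data()[idx] reads through the raw pointer instead of the span's operator[], which in XGBoost's common::Span can carry a bounds assertion; on the per-element prediction path the caller has already validated idx, so the check is redundant. A tiny sketch of the two access styles, with a hypothetical CheckedSpan standing in for the real class:

#include <cassert>
#include <cstddef>

// Hypothetical bounds-checked view, standing in for a checked span type.
template <typename T>
class CheckedSpan {
 public:
  CheckedSpan(T const* ptr, std::size_t n) : ptr_{ptr}, n_{n} {}
  T const& operator[](std::size_t i) const {
    assert(i < n_);  // checked access: one extra branch per element
    return ptr_[i];
  }
  T const* data() const { return ptr_; }  // raw access: the caller guarantees the bound

 private:
  T const* ptr_;
  std::size_t n_;
};

int main() {
  int bins[] = {10, 20, 30};
  CheckedSpan<int> column{bins, 3};
  int checked = column[1];     // goes through the assertion
  int raw = column.data()[1];  // skips it, as the diff does in GetGlobalBinIdx
  return checked == raw ? 0 : 1;
}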

src/common/hist_util.h

Lines changed: 1 addition & 4 deletions
@@ -83,10 +83,7 @@ class HistogramCuts {
   [[nodiscard]] bst_bin_t FeatureBins(bst_feature_t feature) const {
     return cut_ptrs_.ConstHostVector().at(feature + 1) - cut_ptrs_.ConstHostVector()[feature];
   }
-  [[nodiscard]] bst_feature_t NumFeatures() const {
-    CHECK_EQ(this->min_vals_.Size(), this->cut_ptrs_.Size() - 1);
-    return this->min_vals_.Size();
-  }
+  [[nodiscard]] bst_feature_t NumFeatures() const { return this->cut_ptrs_.Size() - 1; }

   std::vector<uint32_t> const& Ptrs() const { return cut_ptrs_.ConstHostVector(); }
   std::vector<float> const& Values() const { return cut_values_.ConstHostVector(); }
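
NumFeatures now reads the count straight off the cut-pointer array: cut_ptrs_ is a CSR-style offset vector with one entry per feature plus a trailing sentinel, so its size minus one is the feature count, and the old cross-check against min_vals_ is dropped. A small plain-vector sketch of that layout; DemoCuts is an illustrative name, not the real HistogramCuts.

#include <cstdint>
#include <vector>

// CSR-style cut layout: cut_ptrs has n_features + 1 offsets into cut_values.
struct DemoCuts {
  std::vector<std::uint32_t> cut_ptrs;  // offsets, size = n_features + 1
  std::vector<float> cut_values;        // all bin boundaries, concatenated

  std::uint32_t NumFeatures() const { return static_cast<std::uint32_t>(cut_ptrs.size()) - 1; }
  std::uint32_t FeatureBins(std::uint32_t f) const { return cut_ptrs[f + 1] - cut_ptrs[f]; }
};

int main() {
  // Two features: feature 0 has 3 cuts, feature 1 has 2 cuts.
  DemoCuts cuts{{0, 3, 5}, {0.1f, 0.5f, 0.9f, 1.5f, 2.5f}};
  return cuts.NumFeatures() == 2 && cuts.FeatureBins(1) == 2 ? 0 : 1;
}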

src/common/ref_resource_view.h

Lines changed: 9 additions & 1 deletion
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023-2024, XGBoost Contributors
+ * Copyright 2023-2025, XGBoost Contributors
  */
 #ifndef XGBOOST_COMMON_REF_RESOURCE_VIEW_H_
 #define XGBOOST_COMMON_REF_RESOURCE_VIEW_H_
@@ -88,6 +88,14 @@ class RefResourceView {

   [[nodiscard]] value_type& operator[](size_type i) { return ptr_[i]; }
   [[nodiscard]] value_type const& operator[](size_type i) const { return ptr_[i]; }
+  [[nodiscard]] value_type& at(size_type i) {  // NOLINT
+    SPAN_LT(i, this->size_);
+    return ptr_[i];
+  }
+  [[nodiscard]] value_type const& at(size_type i) const {  // NOLINT
+    SPAN_LT(i, this->size_);
+    return ptr_[i];
+  }

   /**
    * @brief Get the underlying resource.
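
The new at() overloads give RefResourceView a bounds-asserted accessor next to the unchecked operator[], following the std::vector::at naming (hence the NOLINT on the lowercase method name). A minimal standalone sketch of the same pattern; DemoView and the assert stand in for the real class and its SPAN_LT macro.

#include <cassert>
#include <cstddef>

// Hypothetical non-owning view with both unchecked and checked element access.
template <typename T>
class DemoView {
 public:
  DemoView(T* ptr, std::size_t n) : ptr_{ptr}, size_{n} {}

  T& operator[](std::size_t i) { return ptr_[i]; }  // unchecked, hot path
  T& at(std::size_t i) {                            // checked, stands in for SPAN_LT
    assert(i < size_);
    return ptr_[i];
  }

 private:
  T* ptr_;
  std::size_t size_;
};

int main() {
  float buf[] = {1.0f, 2.0f, 3.0f};
  DemoView<float> view{buf, 3};
  view.at(2) = 4.0f;  // bounds-asserted write
  return view[2] == 4.0f ? 0 : 1;
}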
