Commit 73e0df6

Move feature weight to skl parameters. (#9506)

trivialfis authored Feb 24, 2025
1 parent 82bba31, commit 73e0df6

Showing 6 changed files with 63 additions and 16 deletions.
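In short, the commit promotes `feature_weights` from a `fit()` keyword argument to a regular scikit-learn estimator parameter. A minimal before/after sketch (the data shapes and parameter values here are illustrative, not from the commit):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(256, 8))
    y = X[:, 0] + rng.normal(size=256)

    # Feature weights bias column sampling, so they only have an effect
    # when a colsample_* parameter is below 1.0.
    fw = np.arange(1.0, 9.0)

    # New style introduced by this commit: a constructor parameter.
    reg = xgb.XGBRegressor(colsample_bynode=0.5, feature_weights=fw)
    reg.fit(X, y)

    # Old style, now deprecated: a fit() keyword argument.
    # reg.fit(X, y, feature_weights=fw)

Because it is now a regular estimator parameter, it should also survive `get_params`/`set_params` round trips and sklearn utilities such as `clone`.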
python-package/xgboost/dask/__init__.py (11 additions, 3 deletions)

@@ -1639,6 +1639,10 @@ async def _fit_async(
         feature_weights: Optional[_DaskCollection],
     ) -> _DaskCollection:
         params = self.get_xgb_params()
+        model, metric, params, feature_weights = self._configure_fit(
+            xgb_model, params, feature_weights
+        )
+
         dtrain, evals = await _async_wrap_evaluation_matrices(
             client=self.client,
             device=self.device,
@@ -1665,7 +1669,6 @@ async def _fit_async(
             obj: Optional[Callable] = _objective_decorator(self.objective)
         else:
             obj = None
-        model, metric, params = self._configure_fit(xgb_model, params)
         results = await self.client.sync(
             _train_async,
             asynchronous=True,
@@ -1729,6 +1732,10 @@ async def _fit_async(
         feature_weights: Optional[_DaskCollection],
     ) -> "DaskXGBClassifier":
         params = self.get_xgb_params()
+        model, metric, params, feature_weights = self._configure_fit(
+            xgb_model, params, feature_weights
+        )
+
         dtrain, evals = await _async_wrap_evaluation_matrices(
             self.client,
             device=self.device,
@@ -1773,7 +1780,6 @@ async def _fit_async(
             obj: Optional[Callable] = _objective_decorator(self.objective)
         else:
             obj = None
-        model, metric, params = self._configure_fit(xgb_model, params)
         results = await self.client.sync(
             _train_async,
             asynchronous=True,
@@ -1953,6 +1959,9 @@ async def _fit_async(
         feature_weights: Optional[_DaskCollection],
     ) -> "DaskXGBRanker":
         params = self.get_xgb_params()
+        model, metric, params, feature_weights = self._configure_fit(
+            xgb_model, params, feature_weights
+        )
         dtrain, evals = await _async_wrap_evaluation_matrices(
             self.client,
             device=self.device,
@@ -1974,7 +1983,6 @@ async def _fit_async(
             enable_categorical=self.enable_categorical,
             feature_types=self.feature_types,
         )
-        model, metric, params = self._configure_fit(xgb_model, params)
         results = await self.client.sync(
             _train_async,
             asynchronous=True,
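The Dask estimators pick up the same behaviour through the shared `_configure_fit` helper, which now runs before the evaluation matrices are built so the resolved weights can flow into them. A usage sketch, assuming the Dask wrappers accept the constructor parameter exactly like the single-node estimators (the cluster setup is illustrative):

    import numpy as np
    from dask import array as da
    from dask.distributed import Client, LocalCluster

    from xgboost.dask import DaskXGBRegressor

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 4), chunks=(250, 4))
        y = X[:, 0] + da.random.random(1000, chunks=250)
        reg = DaskXGBRegressor(
            colsample_bynode=0.5, feature_weights=np.ones(4)
        )
        reg.client = client
        reg.fit(X, y)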
python-package/xgboost/sklearn.py (40 additions, 11 deletions)

@@ -389,7 +389,13 @@ def task(i: int) -> float:
         Used for specifying feature types without constructing a dataframe. See
         :py:class:`DMatrix` for details.
 
-    max_cat_to_onehot : {Optional[int]}
+    feature_weights : Optional[ArrayLike]
+
+        Weight for each feature, defines the probability of each feature being selected
+        when colsample is being used. All values must be greater than 0, otherwise a
+        `ValueError` is thrown.
+
+    max_cat_to_onehot : Optional[int]
 
         .. versionadded:: 1.6.0
@@ -607,7 +613,7 @@ def _wrap_evaluation_matrices(
     qid: Optional[Any],
     sample_weight: Optional[Any],
     base_margin: Optional[Any],
-    feature_weights: Optional[Any],
+    feature_weights: Optional[ArrayLike],
     eval_set: Optional[Sequence[Tuple[Any, Any]]],
     sample_weight_eval_set: Optional[Sequence[Any]],
     base_margin_eval_set: Optional[Sequence[Any]],
@@ -753,6 +759,7 @@ def __init__(
         validate_parameters: Optional[bool] = None,
         enable_categorical: bool = False,
         feature_types: Optional[FeatureTypes] = None,
+        feature_weights: Optional[ArrayLike] = None,
         max_cat_to_onehot: Optional[int] = None,
         max_cat_threshold: Optional[int] = None,
         multi_strategy: Optional[str] = None,
@@ -799,6 +806,7 @@ def __init__(
         self.validate_parameters = validate_parameters
         self.enable_categorical = enable_categorical
         self.feature_types = feature_types
+        self.feature_weights = feature_weights
         self.max_cat_to_onehot = max_cat_to_onehot
         self.max_cat_threshold = max_cat_threshold
         self.multi_strategy = multi_strategy
@@ -895,6 +903,7 @@ def _wrapper_params(self) -> Set[str]:
             "early_stopping_rounds",
             "callbacks",
             "feature_types",
+            "feature_weights",
         }
         return wrapper_specific
@@ -1065,10 +1074,12 @@ def _configure_fit(
         self,
         booster: Optional[Union[Booster, "XGBModel", str]],
         params: Dict[str, Any],
+        feature_weights: Optional[ArrayLike],
     ) -> Tuple[
         Optional[Union[Booster, str, "XGBModel"]],
         Optional[Metric],
         Dict[str, Any],
+        Optional[ArrayLike],
     ]:
         """Configure parameters for :py:meth:`fit`."""
         if isinstance(booster, XGBModel):
@@ -1101,13 +1112,23 @@ def _duplicated(parameter: str) -> None:
         else:
             params.update({"eval_metric": self.eval_metric})
 
+        if feature_weights is not None:
+            _deprecated("feature_weights")
+        if feature_weights is not None and self.feature_weights is not None:
+            _duplicated("feature_weights")
+        feature_weights = (
+            self.feature_weights
+            if self.feature_weights is not None
+            else feature_weights
+        )
+
         tree_method = params.get("tree_method", None)
         if self.enable_categorical and tree_method == "exact":
             raise ValueError(
                 "Experimental support for categorical data is not implemented for"
                 " current tree method yet."
             )
-        return model, metric, params
+        return model, metric, params, feature_weights
@@ -1184,12 +1205,19 @@ def fit(
             A list of the form [M_1, M_2, ..., M_n], where each M_i is an array like
             object storing base margin for the i-th validation set.
         feature_weights :
             Weight for each feature, defines the probability of each feature being
             selected when colsample is being used. All values must be greater than 0,
             otherwise a `ValueError` is thrown.
+
+            .. deprecated:: 3.0.0
+               Use `feature_weights` in :py:meth:`__init__` or :py:meth:`set_params`
+               instead.
         """
         with config_context(verbosity=self.verbosity):
             params = self.get_xgb_params()
+            model, metric, params, feature_weights = self._configure_fit(
+                xgb_model, params, feature_weights
+            )
+
             evals_result: TrainingCallback.EvalsLog = {}
             train_dmatrix, evals = _wrap_evaluation_matrices(
                 missing=self.missing,
@@ -1209,15 +1237,13 @@ def fit(
                 enable_categorical=self.enable_categorical,
                 feature_types=self.feature_types,
             )
-            params = self.get_xgb_params()
 
             if callable(self.objective):
                 obj: Optional[Objective] = _objective_decorator(self.objective)
                 params["objective"] = "reg:squarederror"
             else:
                 obj = None
 
-            model, metric, params = self._configure_fit(xgb_model, params)
             self._Booster = train(
                 params,
                 train_dmatrix,
@@ -1631,7 +1657,9 @@ def fit(
             params["objective"] = "multi:softprob"
             params["num_class"] = self.n_classes_
 
-            model, metric, params = self._configure_fit(xgb_model, params)
+            model, metric, params, feature_weights = self._configure_fit(
+                xgb_model, params, feature_weights
+            )
             train_dmatrix, evals = _wrap_evaluation_matrices(
                 missing=self.missing,
                 X=X,
@@ -2148,8 +2176,9 @@
             evals_result: TrainingCallback.EvalsLog = {}
             params = self.get_xgb_params()
 
-            model, metric, params = self._configure_fit(xgb_model, params)
-
+            model, metric, params, feature_weights = self._configure_fit(
+                xgb_model, params, feature_weights
+            )
             self._Booster = train(
                 params,
                 train_dmatrix,
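The heart of the change is the resolution rule added to `_configure_fit`: a `fit()`-time value still works but is flagged as deprecated, supplying the value in both places is an error, and otherwise the constructor value wins. A standalone paraphrase of that rule (the function name and messages here are mine; the real implementation uses the `_deprecated`/`_duplicated` helpers shown above):

    from typing import Optional
    import warnings

    import numpy as np

    def resolve_feature_weights(
        ctor_value: Optional[np.ndarray], fit_value: Optional[np.ndarray]
    ) -> Optional[np.ndarray]:
        # Mirror of the checks added to _configure_fit, in the same order:
        # warn on the deprecated fit() argument, then reject duplicates.
        if fit_value is not None:
            warnings.warn(
                "`feature_weights` in fit() is deprecated. Use the one in "
                "the constructor or set_params() instead.",
                FutureWarning,
            )
        if fit_value is not None and ctor_value is not None:
            raise ValueError(
                "feature_weights is set twice. Use the one in the constructor."
            )
        # Past the raise, at most one value is set; prefer the constructor's.
        return ctor_value if ctor_value is not None else fit_value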
python-package/xgboost/spark/core.py (1 addition, 0 deletions)

@@ -641,6 +641,7 @@ def __init__(self) -> None:
             repartition_random_shuffle=False,
             feature_names=None,
             feature_types=None,
+            feature_weights=None,
             arbitrary_params_dict={},
             launch_tracker_on_driver=True,
         )
python-package/xgboost/spark/data.py (1 addition, 0 deletions)

@@ -352,6 +352,7 @@ def pred_contribs(
         missing=model.missing,
         nthread=model.n_jobs,
         feature_types=model.feature_types,
+        feature_weights=model.feature_weights,
         enable_categorical=model.enable_categorical,
     )
     return model.get_booster().predict(
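On the PySpark side, the new parameter gets a default in the estimator (`core.py`) and is forwarded when the booster's DMatrix is rebuilt for contribution prediction (`data.py`). A sketch of how this could surface to users, assuming the Spark estimator exposes `feature_weights` like its other sklearn-mirrored parameters (`train_df` is a placeholder for a DataFrame prepared elsewhere):

    from xgboost.spark import SparkXGBRegressor

    reg = SparkXGBRegressor(
        features_col="features",
        label_col="label",
        colsample_bynode=0.5,
        feature_weights=[1.0, 2.0, 3.0, 4.0],
    )
    model = reg.fit(train_df)  # train_df: a prepared Spark DataFrame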
python-package/xgboost/testing/shared.py (6 additions, 2 deletions)

@@ -63,9 +63,13 @@ def get_feature_weights(
    """Get feature weights using the demo parser."""
    with tempfile.TemporaryDirectory() as tmpdir:
        colsample_bynode = 0.5
-        reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode)
+        reg = model(
+            tree_method=tree_method,
+            colsample_bynode=colsample_bynode,
+            feature_weights=fw,
+        )
 
-        reg.fit(X, y, feature_weights=fw)
+        reg.fit(X, y)
        model_path = os.path.join(tmpdir, "model.json")
        reg.save_model(model_path)
        with open(model_path, "r", encoding="utf-8") as fd:
tests/python/test_with_sklearn.py (4 additions, 0 deletions)

@@ -1212,6 +1212,10 @@ def test_feature_weights(tree_method):
     assert poly_increasing[0] > 0.08
     assert poly_decreasing[0] < -0.08
 
+    reg = xgb.XGBRegressor(feature_weights=np.ones((kCols, )))
+    with pytest.raises(ValueError, match="Use the one in"):
+        reg.fit(X, y, feature_weights=np.ones((kCols, )))
+
 
 def run_boost_from_prediction_binary(tree_method, X, y, as_frame: Optional[Callable]):
     """
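The updated test pins down both the error path and its message. For completeness, the migration for code that used the old keyword, reusing the test's `X`, `y`, and `kCols` names:

    # Either configure at construction time...
    reg = xgb.XGBRegressor(colsample_bynode=0.5, feature_weights=np.ones((kCols,)))
    reg.fit(X, y)

    # ...or afterwards via the standard scikit-learn setter.
    reg = xgb.XGBRegressor(colsample_bynode=0.5)
    reg.set_params(feature_weights=np.ones((kCols,)))
    reg.fit(X, y)  # no fit-time argument, so no deprecation warning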
