Add catboost integration tests #17931

Open
wants to merge 18 commits into base: branch-25.04
@@ -76,6 +76,13 @@ files:
      - py_version
      - test_base
      - test_xgboost
  test_catboost:
    output: none
    includes:
      - cuda_version
      - py_version
      - test_base
      - test_catboost
  test_cuml:
    output: none
    includes:
@@ -243,6 +250,15 @@ dependencies:
          - scikit-learn
          - pip
          - xgboost>=2.0.1
  test_catboost:
    common:
      - output_types: conda
        packages:
          # TODO: Remove numpy pinning once https://github.com/catboost/catboost/issues/2671 is resolved
Contributor Author:
See this paragraph from the numpy 2 release notes:

Breaking changes to the NumPy ABI. As a result, binaries of packages
that use the NumPy C API and were built against a NumPy 1.xx release
will not work with NumPy 2.0. On import, such packages will see an
ImportError with a message about binary incompatibility.

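A minimal sketch (hypothetical, not part of this PR) of a guard that would surface this failure mode early, before catboost is imported:

# Hypothetical guard: fail fast if the environment resolved numpy 2.x,
# since catboost wheels built against the numpy 1.x ABI raise an
# ImportError about binary incompatibility on import.
import numpy as np

if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
    raise RuntimeError(
        "catboost requires numpy<2.0 until "
        "https://github.com/catboost/catboost/issues/2671 is resolved"
    )
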
          - numpy>=1.23,<2.0.0
          - scipy
          - scikit-learn
          - catboost
  test_cuml:
    common:
      - output_types: conda
@@ -0,0 +1,128 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
import numpy as np
import pandas as pd
import pytest
from catboost import CatBoostClassifier, CatBoostRegressor, Pool
from sklearn.datasets import make_classification, make_regression

rng = np.random.default_rng(seed=42)


def assert_catboost_equal(expect, got, rtol=1e-7, atol=0.0):
    if isinstance(expect, (tuple, list)):
        assert len(expect) == len(got)
        for e, g in zip(expect, got):
            assert_catboost_equal(e, g, rtol, atol)
    elif isinstance(expect, np.ndarray):
        np.testing.assert_allclose(expect, got, rtol=rtol, atol=atol)
    elif isinstance(expect, pd.DataFrame):
        pd.testing.assert_frame_equal(expect, got)
    elif isinstance(expect, pd.Series):
        pd.testing.assert_series_equal(expect, got)
    else:
        assert expect == got


pytestmark = pytest.mark.assert_eq(fn=assert_catboost_equal)


@pytest.fixture
def regression_data():
    X, y = make_regression(n_samples=100, n_features=10, random_state=42)
    return pd.DataFrame(X), pd.Series(y)


@pytest.fixture
def classification_data():
    X, y = make_classification(
        n_samples=100, n_features=10, n_classes=2, random_state=42
Member:

Suggested change:
-        n_samples=100, n_features=10, n_classes=2, random_state=42
+        n_samples=1_000, n_features=10, n_classes=2, random_state=42

You may want to use slightly more data, here and in regression_data(). There are some types of encoding and data access bugs that will only show up in certain codepaths in CatBoost that are exercised when there are enough splits per tree.

I've seen this before in LightGBM and XGBoost... someone will write a test that fits on a very small dataset and it'll look like nothing went wrong, only to later find that actually the dataset was so small that the model was just a collection of decision stumps (no splits), and so the test could never catch issues like "this encoding doesn't preserve NAs" or "these outputs are different because of numerical precision issues".
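
A hedged sketch of the kind of follow-up check this suggests (hypothetical, not part of this PR; it assumes CatBoost's tree_count_ attribute and get_tree_leaf_counts() method):

# Hypothetical check that a fitted model contains real splits, not just
# no-split stumps, so encoding/precision bugs have a chance to surface.
from catboost import CatBoostClassifier
from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=1_000, n_features=10, n_classes=2, random_state=42
)
model = CatBoostClassifier(iterations=10, verbose=0)
model.fit(X, y)

# A tree with no splits has exactly one leaf; require more than one leaf
# in at least one tree.
assert model.tree_count_ == 10
assert max(model.get_tree_leaf_counts()) > 1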

    )
Member @jameslamb · Feb 7, 2025:

make_classification() returns a dataset that has only continuous features.

from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=100, n_features=10, n_classes=2, random_state=42
)
X
array([[-1.14052601,  1.35970566,  0.86199147,  0.84609208,  0.60600995,
        -1.55662917,  1.75479418,  1.69645637, -1.28042935, -2.08192941],
...

For catboost in particular, I strongly suspect you'll get better effective test coverage of this integration by including some categorical features.

Encoding and decoding categorical features is critical to how CatBoost works (docs), and there are lots of things that have to go exactly right when providing pandas-like categorical input. Basically, everything here: https://pandas.pydata.org/docs/user_guide/categorical.html

I really think you should provide an input dataset that has some categorical features, ideally in 2 forms:

  • integer-type columns
  • pandas.Categorical-type columns

And ideally with varying cardinality.

You could consider adapting this code used in xgboost's tests: https://github.com/dmlc/xgboost/blob/105aa4247abb3ce787be2cef2f9beb4c24b30049/demo/guide-python/categorical.py#L29

And here are some docs on how to tell CatBoost which features are categorical: https://catboost.ai/docs/en/concepts/python-usages-examples#class-with-array-like-data-with-numerical,-categorical-and-embedding-features
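
A rough sketch of a fixture along these lines (hypothetical, not part of this PR; column names and cardinalities are illustrative):

import numpy as np
import pandas as pd
from catboost import CatBoostClassifier

rng = np.random.default_rng(seed=42)
n = 1_000

# Mix continuous features with categoricals of varying cardinality,
# provided both as integer codes and as a pandas.Categorical column.
X = pd.DataFrame(
    {
        "cont": rng.standard_normal(n),
        "cat_int": rng.integers(0, 3, size=n),  # low cardinality, integer-typed
        "cat_pd": pd.Categorical(
            rng.choice(list("ABCDEFGHIJ"), size=n)  # higher cardinality
        ),
    }
)
y = pd.Series(rng.integers(0, 2, size=n))

# cat_features tells CatBoost which columns to treat as categorical.
model = CatBoostClassifier(
    iterations=10, verbose=0, cat_features=["cat_int", "cat_pd"]
)
model.fit(X, y)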

    return pd.DataFrame(X), pd.Series(y)


def test_catboost_regressor_with_dataframe(regression_data):
    X, y = regression_data
    model = CatBoostRegressor(iterations=10, verbose=0)
    model.fit(X, y)
    predictions = model.predict(X)
    return predictions


def test_catboost_regressor_with_numpy(regression_data):
    X, y = regression_data
    model = CatBoostRegressor(iterations=10, verbose=0)
    model.fit(X.values, y.values)
    predictions = model.predict(X.values)
    return predictions
Member:

Sorry in advance, I'm not that familiar with these tests but... I'm surprised to see pytest test cases with a return statement. What is the interaction between these test cases and this line a few lines up?

pytestmark = pytest.mark.assert_eq(fn=assert_catboost_equal)

Did you mean for there to be some kind of testing assertion here? Or does that custom marker somehow end up invoking that function and comparing the output of the test case with pandas inputs to its output with cudf inputs?

Contributor Author:

The assertion function is used to check that results from "cudf.pandas on" and "cudf.pandas off" are equal. The logic to handle that is in the conftest file.

Member:

> the conftest file

Thanks, I wasn't sure which conftest.py to look in. I guess this is the relevant piece, yeah?

def pytest_pyfunc_call(pyfuncitem: _pytest.python.Function):
    if pyfuncitem.config.getoption("--compare"):
        # Load the pickled results from the "gold" (cudf.pandas off) and
        # "cudf" (cudf.pandas on) runs, keyed by test node ID.
        gold_results, cudf_results = pyfuncitem.config.stash[results]
        key = get_full_nodeid(pyfuncitem)
        try:
            gold = gold_results[key]
        except KeyError:
            assert False, "pickled gold result is not available"
        try:
            cudf = cudf_results[key]
        except KeyError:
            assert False, "pickled cudf result is not available"
        if gold is None and cudf is None:
            raise ValueError(f"Integration test {key} did not return a value")
        # Compare with the custom function from the assert_eq marker if
        # present, otherwise fall back to plain equality.
        asserter = pyfuncitem.get_closest_marker("assert_eq")
        if asserter is None:
            assert gold == cudf, "Test failed"
        else:
            asserter.kwargs["fn"](gold, cudf)

Contributor Author:

Yup!



def test_catboost_classifier_with_dataframe(classification_data):
    X, y = classification_data
    model = CatBoostClassifier(iterations=10, verbose=0)
    model.fit(X, y)
    predictions = model.predict(X)
    return predictions


def test_catboost_classifier_with_numpy(classification_data):
    X, y = classification_data
    model = CatBoostClassifier(iterations=10, verbose=0)
    model.fit(X.values, y.values)
    predictions = model.predict(X.values)
    return predictions


def test_catboost_with_pool_and_dataframe(regression_data):
    X, y = regression_data
    train_pool = Pool(X, y)
    model = CatBoostRegressor(iterations=10, verbose=0)
    model.fit(train_pool)
    predictions = model.predict(X)
    return predictions


def test_catboost_with_pool_and_numpy(regression_data):
    X, y = regression_data
    train_pool = Pool(X.values, y.values)
    model = CatBoostRegressor(iterations=10, verbose=0)
    model.fit(train_pool)
    predictions = model.predict(X.values)
    return predictions


def test_catboost_with_categorical_features():
    data = {
        "numerical_feature": rng.standard_normal(100),
        "categorical_feature": rng.choice(["A", "B", "C"], size=100),
        "target": rng.integers(0, 2, size=100),
    }
    df = pd.DataFrame(data)
    X = df[["numerical_feature", "categorical_feature"]]
    y = df["target"]
    cat_features = ["categorical_feature"]
    model = CatBoostClassifier(
        iterations=10, verbose=0, cat_features=cat_features
    )
    model.fit(X, y)
    predictions = model.predict(X)
    return predictions


@pytest.mark.parametrize(
    "X, y",
    [
        (
            pd.DataFrame(rng.standard_normal((100, 5))),
            pd.Series(rng.standard_normal(100)),
        ),
        (rng.standard_normal((100, 5)), rng.standard_normal(100)),
    ],
)
def test_catboost_train_test_split(X, y):
    from sklearn.model_selection import train_test_split

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
    model = CatBoostRegressor(iterations=10, verbose=0)
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    return len(X_train), len(X_test), len(y_train), len(y_test), predictions