Skip to content

Commit

Permalink
add reorder_classes function
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Mar 6, 2025
1 parent ccd122c commit 9561285
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 0 deletions.
57 changes: 57 additions & 0 deletions python/interpret-core/interpret/glassbox/_ebm/_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3084,6 +3084,63 @@ def predict(self, X, init_score=None):
# multiclass
return self.classes_[np.argmax(scores, axis=1)]

def reorder_classes(self, classes):
    """Re-order the class positions in a classification EBM.

    ``self.classes_`` is permuted to match ``classes`` and the fitted
    score attributes are updated so they follow the new class positions.

    Args:
        classes: The new class order. It is converted to an array with the
            dtype of ``self.classes_`` and must contain every class of the
            EBM exactly once.

    Returns:
        Itself.

    Raises:
        ValueError: If ``classes`` does not have the same number of items
            as ``self.classes_``, contains a class that is not present in
            the EBM, or contains duplicates.
    """
    check_is_fitted(self, "has_fitted_")

    classes = np.asarray(classes, dtype=self.classes_.dtype)

    if len(classes) != len(self.classes_):
        # This message must be an f-string; otherwise the placeholders are
        # emitted literally instead of the actual counts.
        raise ValueError(
            f"The EBM contains {len(self.classes_)} classes, but the 'classes' parameter contains {len(classes)} items."
        )

    # Map each existing class label to its current position ...
    mapping = dict(zip(self.classes_, count()))
    try:
        # ... then turn it into an index array where mapping[i] is the old
        # position of the class that ends up at new position i.
        mapping = np.fromiter(
            map(mapping.__getitem__, classes),
            np.uint64,
            count=len(mapping),
        )
    except KeyError as e:
        raise ValueError(
            f"The 'classes' parameter contains a class '{e.args[0]}' not present in the EBM."
        ) from e

    # Lengths already match, so any repeated index means a repeated class.
    if len(mapping) != len(set(mapping)):
        raise ValueError("The 'classes' parameter contains duplicates.")

    self.classes_ = self.classes_[mapping]

    if len(mapping) == 2:
        # Binary case: only act when the two classes were actually swapped.
        if mapping[0] == 1:
            # Scores are negated in place rather than re-indexed
            # (presumably binary models store a single score per bin whose
            # sign selects the class -- confirm against the fit code).
            # standard_deviations_ is left untouched in this branch.
            np.negative(self.intercept_, out=self.intercept_)
            np.negative(self.bagged_intercept_, out=self.bagged_intercept_)
            for scores in self.bagged_scores_:
                np.negative(scores, out=scores)
            for scores in self.term_scores_:
                np.negative(scores, out=scores)
    elif 3 <= len(mapping):
        # Multiclass case: permute the class axis (the last axis of each
        # per-term array, the second axis of bagged_intercept_).
        self.intercept_ = self.intercept_[mapping]
        self.bagged_intercept_ = self.bagged_intercept_[:, mapping]
        self.bagged_scores_ = [
            scores[..., mapping] for scores in self.bagged_scores_
        ]
        self.term_scores_ = [scores[..., mapping] for scores in self.term_scores_]
        self.standard_deviations_ = [
            scores[..., mapping] for scores in self.standard_deviations_
        ]

    return self

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.estimator_type = "classifier"
Expand Down
59 changes: 59 additions & 0 deletions python/interpret-core/tests/glassbox/ebm/test_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,3 +1243,62 @@ def test_replicatability_classification():
if total1 != total2:
assert total1 == total2
break


def test_reorder_classes_binary_nochange():
    """Reordering a binary EBM to ``[0, 1]`` must leave predict_proba unchanged."""
    X, y, names, types = make_synthetic(classes=2, output_type="float", n_samples=250)

    model = ExplainableBoostingClassifier(names, types, max_rounds=10)
    model.fit(X, y)

    # Capture probabilities on either side of the (expected no-op) reorder.
    proba_before = model.predict_proba(X)
    model.reorder_classes([0, 1])
    proba_after = model.predict_proba(X)

    assert np.allclose(proba_before, proba_after)


def test_reorder_classes_binary_flip():
    """Reordering a binary EBM to ``[1, 0]`` must swap the predict_proba columns."""
    X, y, names, types = make_synthetic(classes=2, output_type="float", n_samples=250)

    model = ExplainableBoostingClassifier(names, types, max_rounds=10)
    model.fit(X, y)

    proba_before = model.predict_proba(X)
    model.reorder_classes([1, 0])
    proba_after = model.predict_proba(X)

    # The original columns, taken in the new class order, are the expectation.
    expected = proba_before[:, [1, 0]]
    assert np.allclose(expected, proba_after)


def test_reorder_classes_multiclass():
    """Reordering a 3-class EBM to ``[1, 2, 0]`` must permute predict_proba columns alike."""
    X, y, names, types = make_synthetic(classes=3, output_type="float", n_samples=250)

    model = ExplainableBoostingClassifier(names, types, max_rounds=10)
    model.fit(X, y)

    proba_before = model.predict_proba(X)
    model.reorder_classes([1, 2, 0])
    proba_after = model.predict_proba(X)

    # The original columns, taken in the new class order, are the expectation.
    expected = proba_before[:, [1, 2, 0]]
    assert np.allclose(expected, proba_after)


def test_reorder_classes_strings():
    """Reordering must also work when the class labels are strings."""
    X, y, names, types = make_synthetic(classes=3, output_type="float", n_samples=250)

    # Replace the numeric targets with string labels before fitting.
    label_names = {0: "cats", 1: "dogs", 2: "elephants"}
    to_label = np.vectorize(lambda code: label_names[code])
    y = to_label(y)

    model = ExplainableBoostingClassifier(names, types, max_rounds=10)
    model.fit(X, y)

    proba_before = model.predict_proba(X)
    model.reorder_classes(["dogs", "elephants", "cats"])
    proba_after = model.predict_proba(X)

    # Column order [1, 2, 0] corresponds to dogs, elephants, cats.
    expected = proba_before[:, [1, 2, 0]]
    assert np.allclose(expected, proba_after)

0 comments on commit 9561285

Please sign in to comment.