Skip to content

Commit

Permalink
Merge pull request #2 from transferwise/classification
Browse files Browse the repository at this point in the history
Proper tests and handling of classification
  • Loading branch information
EgorKraevTransferwise authored Sep 23, 2024
2 parents 50dcac3 + 00acb74 commit 1d0d519
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 39 deletions.
110 changes: 71 additions & 39 deletions shap_select/select.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Any, Tuple, List
from typing import Any, Tuple, List, Dict

import pandas as pd
import statsmodels.api as sm
import scipy.stats as stats
import shap


def create_shap_features(
    tree_model: Any, validation_df: pd.DataFrame, classes: List | None = None
) -> pd.DataFrame | Dict[Any, pd.DataFrame]:
    """
    Generates SHAP (SHapley Additive exPlanations) values for a given tree-based model on a validation dataset.

    Parameters:
    - tree_model (Any): A fitted tree-based model; it is handed to `shap.TreeExplainer` with
      `model_output="raw"`.
    - validation_df (pd.DataFrame): The feature columns to explain. Its columns and index are reused
      for the returned DataFrame(s).
    - classes (List | None): Class labels for multiclass models. Label number `i` is paired with slice
      `i` of the last axis of the SHAP array, so the labels are assumed to be given in the same order
      as the model's class outputs (NOTE(review): confirm against the model's class ordering).
      Must be None when the explainer returns a 2-D array.

    Returns:
    - pd.DataFrame: When the SHAP values are 2-D, a DataFrame containing the SHAP values for each feature
      in the `validation_df`, where each column corresponds to the SHAP values of a feature, and the rows
      match the index of the `validation_df`.
    - Dict[Any, pd.DataFrame]: When the SHAP values are 3-D (multiclass), one such DataFrame per class,
      keyed by the class label taken from `classes`.

    Raises:
    - AssertionError: If `classes` is given but the SHAP values are 2-D.
    - ValueError: If the SHAP values are 3-D and `classes` is None, or if the SHAP array is neither
      2-D nor 3-D.
    """
    explanation = shap.TreeExplainer(tree_model, model_output="raw")(validation_df)
    shap_values = explanation.values

    if shap_values.ndim == 2:
        # Kept as an assert (not a raise) so existing callers that expect
        # AssertionError for this misuse keep working.
        assert (
            classes is None
        ), "Don't specify classes for binary classification or regression"
        # Create a DataFrame with the SHAP values, with one column per feature
        return pd.DataFrame(
            shap_values, columns=validation_df.columns, index=validation_df.index
        )

    if shap_values.ndim == 3:  # multiclass classification
        if classes is None:
            raise ValueError(
                "`classes` must be provided when the model yields per-class SHAP values"
            )
        # Key by the class label, not by its position: downstream code compares
        # the dict key against the target values to build one-vs-all targets.
        return {
            label: pd.DataFrame(
                shap_values[:, :, position],
                columns=validation_df.columns,
                index=validation_df.index,
            )
            for position, label in enumerate(classes)
        }

    # Previously this case fell through and silently returned None.
    raise ValueError(
        f"Unexpected SHAP values shape {shap_values.shape}; expected a 2-D or 3-D array"
    )


def binary_classifier_significance(
Expand Down Expand Up @@ -66,12 +82,14 @@ def binary_classifier_significance(
"t-value": summary_frame["Coef."] / summary_frame["Std.Err."],
}
).reset_index(drop=True)
result_df["closeness to 1.0"] = closeness_to_one(result_df).abs()
result_df["closeness to 1.0"] = (result_df["coefficient"] - 1.0).abs()
return result_df


def multi_classifier_significance(
    shap_features: Dict[Any, pd.DataFrame],
    target: pd.Series,
    return_individual_significances: bool = False,
) -> pd.DataFrame | Tuple[pd.DataFrame, List[pd.DataFrame]]:
    """
    Fits a binary logistic regression model for each class, comparing each class against all others (one-vs-all).

    Parameters:
    - shap_features (Dict[Any, pd.DataFrame]): One DataFrame of SHAP features per class. Each dict key
      is compared against the values of `target` to build that class's one-vs-all target.
    - target (pd.Series): The multiclass target.
    - return_individual_significances (bool): If True, also return the per-class result DataFrames.

    Returns:
    - A DataFrame with one row per feature name holding the maximum t-value, the minimum
      "closeness to 1.0" and the maximum coefficient across all binary classifications, plus a
      "stat.significance" column: the one-sided normal tail probability of the maximum t-value,
      multiplied by the number of classes (Bonferroni correction) and capped at 1.
    - If `return_individual_significances` is True, a tuple of that DataFrame and a list of
      DataFrames, one for each binary classification, as returned by `binary_classifier_significance`.

    Raises:
    - ValueError: If `shap_features` is empty.
    """
    if not shap_features:
        raise ValueError("`shap_features` must contain at least one class")

    # One-vs-all: each class gets its own SHAP feature frame and a 0/1 target
    significance_dfs = [
        binary_classifier_significance(feature_df, (target == cls).astype(int))
        for cls, feature_df in shap_features.items()
    ]

    # Combine results into a single DataFrame with the strongest evidence per feature.
    # String aggregator names are used throughout; passing the builtin `max`
    # triggers a pandas FutureWarning.
    combined_df = pd.concat(significance_dfs)
    max_significance_df = (
        combined_df.groupby("feature name", as_index=False)
        .agg(
            {
                "t-value": "max",
                "closeness to 1.0": "min",
                "coefficient": "max",
            }
        )
        .reset_index(drop=True)
    )

    # The survival function is used instead of 1 - cdf: it stays accurate for
    # large t-values where 1 - cdf would round to exactly 0.
    tail_probability = pd.Series(
        stats.norm.sf(max_significance_df["t-value"].to_numpy()),
        index=max_significance_df.index,
    )
    # len(shap_features) multiplier is the Bonferroni correction; the result is
    # capped at 1 so it remains a valid probability.
    max_significance_df["stat.significance"] = (
        len(shap_features) * tail_probability
    ).clip(upper=1.0)

    if return_individual_significances:
        return max_significance_df, significance_dfs
    return max_significance_df


def regression_significance(
Expand Down Expand Up @@ -141,7 +170,7 @@ def regression_significance(
"t-value": summary_frame["Coef."] / summary_frame["Std.Err."],
}
).reset_index(drop=True)
result_df["closeness to 1.0"] = closeness_to_one(result_df).abs()
result_df["closeness to 1.0"] = (result_df["coefficient"] - 1.0).abs()

return result_df

Expand All @@ -151,7 +180,9 @@ def closeness_to_one(df: pd.DataFrame) -> pd.Series:


def shap_features_to_significance(
shap_features: pd.DataFrame, target: pd.Series, task: str | None = None
shap_features: pd.DataFrame | List[pd.DataFrame],
target: pd.Series,
task: str,
) -> pd.DataFrame:
"""
Determines the task (regression, binary, or multi-class classification) based on the target and calls the appropriate
Expand All @@ -160,8 +191,7 @@ def shap_features_to_significance(
Parameters:
shap_features (pd.DataFrame): A DataFrame containing the features used for prediction.
target (pd.Series): The target series for prediction (either continuous or categorical).
task (str | None): The type of task to perform. If None, the function will infer the task automatically.
The options are "regression", "binary", or "multi".
task (str): The type of task to perform: "regression", "binary", or "multiclass".
Returns:
pd.DataFrame: A DataFrame containing:
Expand All @@ -170,27 +200,15 @@ def shap_features_to_significance(
Sorted in descending order of significance (ascending p-value).
"""

# Infer the task if not provided
if task is None:
if pd.api.types.is_numeric_dtype(target) and target.nunique() > 10:
task = "regression"
elif target.nunique() == 2:
task = "binary"
else:
task = "multi"

# Call the appropriate function based on the task
if task == "regression":
result_df = regression_significance(shap_features, target)
elif task == "binary":
result_df = binary_classifier_significance(shap_features, target)
elif task == "multi":
max_significance_df, _ = multi_classifier_significance(shap_features, target)
result_df = max_significance_df.rename(
columns={"max significance value": "stat.significance"}
)
elif task == "multiclass":
result_df = multi_classifier_significance(shap_features, target)
else:
raise ValueError("`task` must be 'regression', 'binary', 'multi' or None.")
raise ValueError("`task` must be 'regression', 'binary', 'multiclass' or None.")

# Sort the result by statistical significance in ascending order (more significant features first)
result_df_sorted = result_df.sort_values(by="t-value", ascending=False).reset_index(
Expand Down Expand Up @@ -227,8 +245,22 @@ def score_features(
if isinstance(target, str):
target = validation_df[target]

# Generate SHAP values for the validation dataset
shap_features = create_shap_features(tree_model, validation_df[feature_names])
# Infer the task if not provided
if task is None:
if pd.api.types.is_numeric_dtype(target) and target.nunique() > 10:
task = "regression"
elif target.nunique() == 2:
task = "binary"
else:
task = "multiclass"

if task == "multiclass":
unique_classes = sorted(list(target.unique()))
shap_features = create_shap_features(
tree_model, validation_df[feature_names], unique_classes
)
else:
shap_features = create_shap_features(tree_model, validation_df[feature_names])

# Compute statistical significance of each feature
significance_df = shap_features_to_significance(shap_features, target, task)
Expand All @@ -237,6 +269,6 @@ def score_features(
significance_df["Selected"] = (
significance_df["stat.significance"] < threshold
).astype(int)
significance_df.loc[significance_df["coefficient"] < 0, "Selected"] = -1
significance_df.loc[significance_df["t-value"] < 0, "Selected"] = -1

return significance_df, shap_features
Empty file added tests/__init__.py
Empty file.
Loading

0 comments on commit 1d0d519

Please sign in to comment.