diff --git a/docs/Quick feature selection through regression on Shapley values.ipynb b/docs/Quick feature selection through regression on Shapley values.ipynb
index d441efb..40edd2d 100644
--- a/docs/Quick feature selection through regression on Shapley values.ipynb
+++ b/docs/Quick feature selection through regression on Shapley values.ipynb
@@ -27,7 +27,16 @@
"execution_count": 1,
"id": "51cd6a7d",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\EgorKraev\\miniconda3\\envs\\shap-select3.10\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
"source": [
"import os, sys\n",
"from typing import List\n",
@@ -175,7 +184,8 @@
"[65]\tvalid-rmse:14.51705\n",
"[66]\tvalid-rmse:14.52365\n",
"[67]\tvalid-rmse:14.52792\n",
- "[68]\tvalid-rmse:14.53296\n"
+ "[68]\tvalid-rmse:14.53296\n",
+ "[69]\tvalid-rmse:14.53426\n"
]
}
],
@@ -208,7 +218,15 @@
"execution_count": 4,
"id": "8f403fc5",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Condition number: 67.24977\n"
+ ]
+ }
+ ],
"source": [
"selected_features_df = shap_select(model, X_val, y_val, task=\"regression\", threshold=0.05)"
]
@@ -223,172 +241,172 @@
"data": {
"text/html": [
"\n",
- "
\n",
+ "\n",
" \n",
" \n",
" | \n",
- " feature name | \n",
- " t-value | \n",
- " stat.significance | \n",
- " coefficient | \n",
- " selected | \n",
+ " feature name | \n",
+ " t-value | \n",
+ " stat.significance | \n",
+ " coefficient | \n",
+ " selected | \n",
"
\n",
" \n",
" \n",
" \n",
- " 0 | \n",
- " x5 | \n",
- " 20.211299 | \n",
- " 0.000000 | \n",
- " 1.052030 | \n",
- " 1 | \n",
+ " 0 | \n",
+ " x5 | \n",
+ " 20.211298 | \n",
+ " 0.000000 | \n",
+ " 1.052030 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 1 | \n",
- " x4 | \n",
- " 18.315144 | \n",
- " 0.000000 | \n",
- " 0.952416 | \n",
- " 1 | \n",
+ " 1 | \n",
+ " x4 | \n",
+ " 18.315144 | \n",
+ " 0.000000 | \n",
+ " 0.952416 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 2 | \n",
- " x3 | \n",
- " 6.835690 | \n",
- " 0.000000 | \n",
- " 1.098154 | \n",
- " 1 | \n",
+ " 2 | \n",
+ " x3 | \n",
+ " 6.835690 | \n",
+ " 0.000000 | \n",
+ " 1.098154 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 3 | \n",
- " x2 | \n",
- " 6.457140 | \n",
- " 0.000000 | \n",
- " 1.044842 | \n",
- " 1 | \n",
+ " 3 | \n",
+ " x2 | \n",
+ " 6.457140 | \n",
+ " 0.000000 | \n",
+ " 1.044842 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 4 | \n",
- " x1 | \n",
- " 5.530556 | \n",
- " 0.000000 | \n",
- " 0.917242 | \n",
- " 1 | \n",
+ " 4 | \n",
+ " x1 | \n",
+ " 5.530556 | \n",
+ " 0.000000 | \n",
+ " 0.917242 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 5 | \n",
- " x6 | \n",
- " 2.390868 | \n",
- " 0.016827 | \n",
- " 1.497983 | \n",
- " 1 | \n",
+ " 5 | \n",
+ " x6 | \n",
+ " 2.390868 | \n",
+ " 0.016827 | \n",
+ " 1.497983 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 6 | \n",
- " x7 | \n",
- " 0.901098 | \n",
- " 0.367558 | \n",
- " 2.865508 | \n",
- " 0 | \n",
+ " 6 | \n",
+ " x7 | \n",
+ " 0.901098 | \n",
+ " 0.367558 | \n",
+ " 2.865508 | \n",
+ " 0 | \n",
"
\n",
" \n",
- " 7 | \n",
- " x8 | \n",
- " 0.563214 | \n",
- " 0.573302 | \n",
- " 1.933632 | \n",
- " 0 | \n",
+ " 7 | \n",
+ " x8 | \n",
+ " 0.563214 | \n",
+ " 0.573302 | \n",
+ " 1.933632 | \n",
+ " 0 | \n",
"
\n",
" \n",
- " 8 | \n",
- " x9 | \n",
- " -1.607814 | \n",
- " 0.107908 | \n",
- " -4.537098 | \n",
- " -1 | \n",
+ " 8 | \n",
+ " x9 | \n",
+ " -1.607814 | \n",
+ " 0.107908 | \n",
+ " -4.537098 | \n",
+ " -1 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
- ""
+ ""
]
},
"execution_count": 5,
@@ -1008,14 +1026,14 @@
"[564]\tvalid-mlogloss:0.03021\n",
"[565]\tvalid-mlogloss:0.03016\n",
"[566]\tvalid-mlogloss:0.03011\n",
- "[567]\tvalid-mlogloss:0.03015\n"
+ "[567]\tvalid-mlogloss:0.03015\n",
+ "[568]\tvalid-mlogloss:0.03013\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "[568]\tvalid-mlogloss:0.03013\n",
"[569]\tvalid-mlogloss:0.03016\n",
"[570]\tvalid-mlogloss:0.03014\n",
"[571]\tvalid-mlogloss:0.03011\n",
@@ -1225,178 +1243,188 @@
"cell_type": "code",
"execution_count": 7,
"id": "743d6988",
- "metadata": {},
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [
{
- "name": "stderr",
+ "name": "stdout",
"output_type": "stream",
"text": [
- "C:\\Users\\EgorKraev\\miniconda3\\envs\\llm3.11\\Lib\\site-packages\\sklearn\\svm\\_base.py:1235: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
- " warnings.warn(\n"
+ "Optimization terminated successfully.\n",
+ " Current function value: 0.028663\n",
+ " Iterations 13\n",
+ "Optimization terminated successfully.\n",
+ " Current function value: 0.066662\n",
+ " Iterations 11\n",
+ "Optimization terminated successfully.\n",
+ " Current function value: 0.001348\n",
+ " Iterations 17\n",
+ "Condition number: 77.557655\n"
]
},
{
"data": {
"text/html": [
"\n",
- "\n",
+ "\n",
" \n",
" \n",
" | \n",
- " feature name | \n",
- " t-value | \n",
- " stat.significance | \n",
- " coefficient | \n",
- " selected | \n",
+ " feature name | \n",
+ " t-value | \n",
+ " stat.significance | \n",
+ " coefficient | \n",
+ " selected | \n",
"
\n",
" \n",
" \n",
" \n",
- " 0 | \n",
- " x4 | \n",
- " 25.927565 | \n",
- " 0.000000 | \n",
- " 1.559384 | \n",
- " 1 | \n",
+ " 0 | \n",
+ " x4 | \n",
+ " 25.927565 | \n",
+ " 0.000000 | \n",
+ " 1.559384 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 1 | \n",
- " x5 | \n",
- " 25.874027 | \n",
- " 0.000000 | \n",
- " 1.571661 | \n",
- " 1 | \n",
+ " 1 | \n",
+ " x5 | \n",
+ " 25.874027 | \n",
+ " 0.000000 | \n",
+ " 1.571661 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 2 | \n",
- " x6 | \n",
- " 25.782536 | \n",
- " 0.000000 | \n",
- " 1.561214 | \n",
- " 1 | \n",
+ " 2 | \n",
+ " x6 | \n",
+ " 25.782536 | \n",
+ " 0.000000 | \n",
+ " 1.561214 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 3 | \n",
- " x2 | \n",
- " 21.367053 | \n",
- " 0.000000 | \n",
- " 1.753463 | \n",
- " 1 | \n",
+ " 3 | \n",
+ " x2 | \n",
+ " 21.367053 | \n",
+ " 0.000000 | \n",
+ " 1.753463 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 4 | \n",
- " x3 | \n",
- " 21.330803 | \n",
- " 0.000000 | \n",
- " 1.792630 | \n",
- " 1 | \n",
+ " 4 | \n",
+ " x3 | \n",
+ " 21.330803 | \n",
+ " 0.000000 | \n",
+ " 1.792630 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 5 | \n",
- " x1 | \n",
- " 12.835856 | \n",
- " 0.000000 | \n",
- " 2.197310 | \n",
- " 1 | \n",
+ " 5 | \n",
+ " x1 | \n",
+ " 12.835856 | \n",
+ " 0.000000 | \n",
+ " 2.197310 | \n",
+ " 1 | \n",
"
\n",
" \n",
- " 6 | \n",
- " x7 | \n",
- " 0.773525 | \n",
- " 0.658817 | \n",
- " 1.901079 | \n",
- " 0 | \n",
+ " 6 | \n",
+ " x7 | \n",
+ " 0.773525 | \n",
+ " 0.658817 | \n",
+ " 1.901079 | \n",
+ " 0 | \n",
"
\n",
" \n",
- " 7 | \n",
- " x9 | \n",
- " -0.206328 | \n",
- " 1.745198 | \n",
- " -0.317295 | \n",
- " -1 | \n",
+ " 7 | \n",
+ " x9 | \n",
+ " -0.206328 | \n",
+ " 1.745198 | \n",
+ " -0.317295 | \n",
+ " -1 | \n",
"
\n",
" \n",
- " 8 | \n",
- " x8 | \n",
- " -0.636902 | \n",
- " 2.213717 | \n",
- " -1.259370 | \n",
- " -1 | \n",
+ " 8 | \n",
+ " x8 | \n",
+ " -0.636902 | \n",
+ " 2.213717 | \n",
+ " -1.259370 | \n",
+ " -1 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
- ""
+ ""
]
},
"execution_count": 7,
@@ -1409,21 +1437,13 @@
"\n",
"prettify(selected_features_df, exclude=[\"feature name\"])"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "60c4d878",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python [conda env:llm3.11]",
+ "display_name": "Python [conda env:shap-select3.10]",
"language": "python",
- "name": "conda-env-llm3.11-py"
+ "name": "conda-env-shap-select3.10-py"
},
"language_info": {
"codemirror_mode": {
@@ -1435,7 +1455,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.9"
+ "version": "3.10.15"
}
},
"nbformat": 4,
diff --git a/shap_select/balance.py b/shap_select/balance.py
new file mode 100644
index 0000000..8617f5f
--- /dev/null
+++ b/shap_select/balance.py
@@ -0,0 +1,54 @@
+import pandas as pd
+from sklearn.utils import resample
+
+
+def balance_dataset(X, y):
+ """
+ Balances an unbalanced dataset by oversampling the minority class.
+
+ Parameters:
+ X (pd.DataFrame): Feature DataFrame
+ y (pd.Series): Target Series
+
+ Returns:
+ X_balanced (pd.DataFrame): Balanced features DataFrame
+ y_balanced (pd.Series): Balanced target Series
+ """
+ # Combine features and target into a single DataFrame for easier manipulation
+ df = pd.concat([X, y], axis=1)
+
+ # Identify the name of the target column
+ target_name = y.name
+
+ # Separate the majority and minority classes
+ class_counts = y.value_counts()
+ majority_class_label = class_counts.idxmax() # Label of the majority class
+ minority_class_label = class_counts.idxmin() # Label of the minority class
+
+ majority_class = df[df[target_name] == majority_class_label]
+ minority_class = df[df[target_name] == minority_class_label]
+
+ # Calculate how many samples to add to balance the dataset
+ n_majority = majority_class.shape[0]
+ n_minority = minority_class.shape[0]
+ n_to_add = n_majority - n_minority
+
+ # Upsample the minority class (i.e., duplicate samples with replacement)
+ minority_upsampled = resample(
+ minority_class,
+ replace=True, # Sample with replacement
+ n_samples=n_to_add, # How many samples to add
+ random_state=42,
+ ) # Seed for reproducibility
+
+ # Combine the majority class with the upsampled minority class
+ df_balanced = pd.concat([majority_class, minority_class, minority_upsampled])
+
+ # Shuffle the dataset to mix the new minority samples with the majority class
+ df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
+
+ # Separate the features and target from the balanced DataFrame
+ X_balanced = df_balanced.drop(columns=target_name)
+ y_balanced = df_balanced[target_name]
+
+ return X_balanced, y_balanced
diff --git a/shap_select/select.py b/shap_select/select.py
index be092dd..4fe82a6 100644
--- a/shap_select/select.py
+++ b/shap_select/select.py
@@ -1,10 +1,13 @@
from typing import Any, Tuple, List, Dict
+import numpy as np
import pandas as pd
import statsmodels.api as sm
import scipy.stats as stats
import shap
+from shap_select.balance import balance_dataset
+
def create_shap_features(
tree_model: Any, validation_df: pd.DataFrame, classes: List | None = None
@@ -44,7 +47,7 @@ def create_shap_features(
def binary_classifier_significance(
- shap_features: pd.DataFrame, target: pd.Series, alpha: float
+ shap_features: pd.DataFrame, target: pd.Series, alpha: float, balance_ds=False
) -> pd.DataFrame:
"""
Fits a logistic regression model using the features from `shap_features` to predict the binary `target`.
@@ -61,13 +64,19 @@ def binary_classifier_significance(
- stderr: The standard error for each coefficient.
- stat.significance: The p-value (statistical significance) for each feature.
"""
+ if balance_ds:
+ shap_features, target = balance_dataset(shap_features, target)
# Add a constant to the features for the intercept in logistic regression
shap_features_with_constant = sm.add_constant(shap_features)
# Fit the logistic regression model that will generate confidence intervals
logit_model = sm.Logit(target, shap_features_with_constant)
- result = logit_model.fit_regularized(disp=False, alpha=alpha)
+ cond = np.linalg.cond(shap_features.values)
+ if cond > 1e3:
+ result = logit_model.fit_regularized(disp=False, alpha=alpha)
+ else:
+ result = logit_model.fit()
# Extract the results
summary_frame = result.summary2().tables[1]
@@ -158,7 +167,11 @@ def regression_significance(
"""
# Fit the linear regression model that will generate confidence intervals
ols_model = sm.OLS(target, shap_features)
- result = ols_model.fit_regularized(alpha=alpha, refit=True)
+ cond = np.linalg.cond(shap_features.values)
+ if cond > 1e3:
+ result = ols_model.fit_regularized(alpha=alpha, refit=True)
+ else:
+ result = ols_model.fit()
# Extract the results
summary_frame = result.summary2().tables[1]
@@ -227,6 +240,7 @@ def iterative_shap_feature_reduction(
target: pd.Series,
task: str,
alpha: float = 1e-6,
+ cond_threshold: float = 1e6,
) -> pd.DataFrame:
collected_rows = [] # List to store the rows we collect during each iteration
@@ -245,21 +259,37 @@ def iterative_shap_feature_reduction(
# Drop the feature corresponding to the lowest t-value from shap_features
feature_to_remove = min_t_value_row["feature name"]
+ cond = 1.0
if isinstance(shap_features, pd.DataFrame):
shap_features = shap_features.drop(columns=[feature_to_remove])
+ # Check for conditioning number
features_left = len(shap_features.columns)
+ if features_left:
+ cond = np.linalg.cond(shap_features.values)
else:
shap_features = {
k: v.drop(columns=[feature_to_remove]) for k, v in shap_features.items()
}
+
features_left = len(list(shap_features.values())[0].columns)
+ if features_left:
+ conds = {k: np.linalg.cond(v.values) for k, v in shap_features.items()}
+ cond = max(conds.values())
+
+ print("Condition number:", cond)
+ if (
+ cond < cond_threshold
+ ): # matrix is well-conditioned, don't need to remove more features
+ break
+
+ if features_left:
+ # The latest row is still contained in significance_df
+ collected_rows.pop()
# Convert collected rows back to a dataframe
- result_df = (
- pd.DataFrame(collected_rows)
- .sort_values(by="t-value", ascending=False)
- .reset_index()
- )
+ all_rows = pd.concat([pd.DataFrame(collected_rows), significance_df], axis=0)
+
+ result_df = all_rows.sort_values(by="t-value", ascending=False).reset_index()
return result_df