diff --git a/docs/Quick feature selection through regression on Shapley values.ipynb b/docs/Quick feature selection through regression on Shapley values.ipynb index d441efb..40edd2d 100644 --- a/docs/Quick feature selection through regression on Shapley values.ipynb +++ b/docs/Quick feature selection through regression on Shapley values.ipynb @@ -27,7 +27,16 @@ "execution_count": 1, "id": "51cd6a7d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\EgorKraev\\miniconda3\\envs\\shap-select3.10\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import os, sys\n", "from typing import List\n", @@ -175,7 +184,8 @@ "[65]\tvalid-rmse:14.51705\n", "[66]\tvalid-rmse:14.52365\n", "[67]\tvalid-rmse:14.52792\n", - "[68]\tvalid-rmse:14.53296\n" + "[68]\tvalid-rmse:14.53296\n", + "[69]\tvalid-rmse:14.53426\n" ] } ], @@ -208,7 +218,15 @@ "execution_count": 4, "id": "8f403fc5", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Condition number: 67.24977\n" + ] + } + ], "source": [ "selected_features_df = shap_select(model, X_val, y_val, task=\"regression\", threshold=0.05)" ] @@ -223,172 +241,172 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 feature namet-valuestat.significancecoefficientselectedfeature namet-valuestat.significancecoefficientselected
0x520.2112990.0000001.05203010x520.2112980.0000001.0520301
1x418.3151440.0000000.95241611x418.3151440.0000000.9524161
2x36.8356900.0000001.09815412x36.8356900.0000001.0981541
3x26.4571400.0000001.04484213x26.4571400.0000001.0448421
4x15.5305560.0000000.91724214x15.5305560.0000000.9172421
5x62.3908680.0168271.49798315x62.3908680.0168271.4979831
6x70.9010980.3675582.86550806x70.9010980.3675582.8655080
7x80.5632140.5733021.93363207x80.5632140.5733021.9336320
8x9-1.6078140.107908-4.537098-18x9-1.6078140.107908-4.537098-1
\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -1008,14 +1026,14 @@ "[564]\tvalid-mlogloss:0.03021\n", "[565]\tvalid-mlogloss:0.03016\n", "[566]\tvalid-mlogloss:0.03011\n", - "[567]\tvalid-mlogloss:0.03015\n" + "[567]\tvalid-mlogloss:0.03015\n", + "[568]\tvalid-mlogloss:0.03013\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "[568]\tvalid-mlogloss:0.03013\n", "[569]\tvalid-mlogloss:0.03016\n", "[570]\tvalid-mlogloss:0.03014\n", "[571]\tvalid-mlogloss:0.03011\n", @@ -1225,178 +1243,188 @@ "cell_type": "code", "execution_count": 7, "id": "743d6988", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "C:\\Users\\EgorKraev\\miniconda3\\envs\\llm3.11\\Lib\\site-packages\\sklearn\\svm\\_base.py:1235: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n", - " warnings.warn(\n" + "Optimization terminated successfully.\n", + " Current function value: 0.028663\n", + " Iterations 13\n", + "Optimization terminated successfully.\n", + " Current function value: 0.066662\n", + " Iterations 11\n", + "Optimization terminated successfully.\n", + " Current function value: 0.001348\n", + " Iterations 17\n", + "Condition number: 77.557655\n" ] }, { "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 feature namet-valuestat.significancecoefficientselectedfeature namet-valuestat.significancecoefficientselected
0x425.9275650.0000001.55938410x425.9275650.0000001.5593841
1x525.8740270.0000001.57166111x525.8740270.0000001.5716611
2x625.7825360.0000001.56121412x625.7825360.0000001.5612141
3x221.3670530.0000001.75346313x221.3670530.0000001.7534631
4x321.3308030.0000001.79263014x321.3308030.0000001.7926301
5x112.8358560.0000002.19731015x112.8358560.0000002.1973101
6x70.7735250.6588171.90107906x70.7735250.6588171.9010790
7x9-0.2063281.745198-0.317295-17x9-0.2063281.745198-0.317295-1
8x8-0.6369022.213717-1.259370-18x8-0.6369022.213717-1.259370-1
\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 7, @@ -1409,21 +1437,13 @@ "\n", "prettify(selected_features_df, exclude=[\"feature name\"])" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "60c4d878", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:llm3.11]", + "display_name": "Python [conda env:shap-select3.10]", "language": "python", - "name": "conda-env-llm3.11-py" + "name": "conda-env-shap-select3.10-py" }, "language_info": { "codemirror_mode": { @@ -1435,7 +1455,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/shap_select/balance.py b/shap_select/balance.py new file mode 100644 index 0000000..8617f5f --- /dev/null +++ b/shap_select/balance.py @@ -0,0 +1,54 @@ +import pandas as pd +from sklearn.utils import resample + + +def balance_dataset(X, y): + """ + Balances an unbalanced dataset by oversampling the minority class. + + Parameters: + X (pd.DataFrame): Feature DataFrame + y (pd.Series): Target Series + + Returns: + X_balanced (pd.DataFrame): Balanced features DataFrame + y_balanced (pd.Series): Balanced target Series + """ + # Combine features and target into a single DataFrame for easier manipulation + df = pd.concat([X, y], axis=1) + + # Identify the name of the target column + target_name = y.name + + # Separate the majority and minority classes + class_counts = y.value_counts() + majority_class_label = class_counts.idxmax() # Label of the majority class + minority_class_label = class_counts.idxmin() # Label of the minority class + + majority_class = df[df[target_name] == majority_class_label] + minority_class = df[df[target_name] == minority_class_label] + + # Calculate how many samples to add to balance the dataset + n_majority = majority_class.shape[0] + n_minority = minority_class.shape[0] + n_to_add = n_majority - n_minority + + # Upsample the minority class (i.e., duplicate samples with replacement) + minority_upsampled = resample( + minority_class, + replace=True, # Sample with replacement + n_samples=n_to_add, # How many samples to add + random_state=42, + ) # Seed for reproducibility + + # Combine the majority class with the upsampled minority class + df_balanced = pd.concat([majority_class, minority_class, minority_upsampled]) + + # Shuffle the dataset to mix the new minority samples with the majority class + df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True) + + # Separate the features and target from the balanced DataFrame + X_balanced = df_balanced.drop(columns=target_name) + y_balanced = df_balanced[target_name] + + return X_balanced, y_balanced diff --git a/shap_select/select.py b/shap_select/select.py index be092dd..4fe82a6 100644 --- a/shap_select/select.py +++ b/shap_select/select.py @@ -1,10 +1,13 @@ from typing import Any, Tuple, List, Dict +import numpy as np import pandas as pd import statsmodels.api as sm import scipy.stats as stats import shap +from shap_select.balance import balance_dataset + def create_shap_features( tree_model: Any, validation_df: pd.DataFrame, classes: List | None = None @@ -44,7 +47,7 @@ def create_shap_features( def binary_classifier_significance( - shap_features: pd.DataFrame, target: pd.Series, alpha: float + shap_features: pd.DataFrame, target: pd.Series, alpha: float, balance_ds=False ) -> pd.DataFrame: """ Fits a logistic regression model using the features from `shap_features` to predict the binary `target`. @@ -61,13 +64,19 @@ def binary_classifier_significance( - stderr: The standard error for each coefficient. - stat.significance: The p-value (statistical significance) for each feature. """ + if balance_ds: + shap_features, target = balance_dataset(shap_features, target) # Add a constant to the features for the intercept in logistic regression shap_features_with_constant = sm.add_constant(shap_features) # Fit the logistic regression model that will generate confidence intervals logit_model = sm.Logit(target, shap_features_with_constant) - result = logit_model.fit_regularized(disp=False, alpha=alpha) + cond = np.linalg.cond(shap_features.values) + if cond > 1e3: + result = logit_model.fit_regularized(disp=False, alpha=alpha) + else: + result = logit_model.fit() # Extract the results summary_frame = result.summary2().tables[1] @@ -158,7 +167,11 @@ def regression_significance( """ # Fit the linear regression model that will generate confidence intervals ols_model = sm.OLS(target, shap_features) - result = ols_model.fit_regularized(alpha=alpha, refit=True) + cond = np.linalg.cond(shap_features.values) + if cond > 1e3: + result = ols_model.fit_regularized(alpha=alpha, refit=True) + else: + result = ols_model.fit() # Extract the results summary_frame = result.summary2().tables[1] @@ -227,6 +240,7 @@ def iterative_shap_feature_reduction( target: pd.Series, task: str, alpha: float = 1e-6, + cond_threshold: float = 1e6, ) -> pd.DataFrame: collected_rows = [] # List to store the rows we collect during each iteration @@ -245,21 +259,37 @@ def iterative_shap_feature_reduction( # Drop the feature corresponding to the lowest t-value from shap_features feature_to_remove = min_t_value_row["feature name"] + cond = 1.0 if isinstance(shap_features, pd.DataFrame): shap_features = shap_features.drop(columns=[feature_to_remove]) + # Check for conditioning number features_left = len(shap_features.columns) + if features_left: + cond = np.linalg.cond(shap_features.values) else: shap_features = { k: v.drop(columns=[feature_to_remove]) for k, v in shap_features.items() } + features_left = len(list(shap_features.values())[0].columns) + if features_left: + conds = {k: np.linalg.cond(v.values) for k, v in shap_features.items()} + cond = max(conds.values()) + + print("Condition number:", cond) + if ( + cond < cond_threshold + ): # matrix is well-conditioned, don't need to remove more features + break + + if features_left: + # The latest row is still contained in significance_df + collected_rows.pop() # Convert collected rows back to a dataframe - result_df = ( - pd.DataFrame(collected_rows) - .sort_values(by="t-value", ascending=False) - .reset_index() - ) + all_rows = pd.concat([pd.DataFrame(collected_rows), significance_df], axis=0) + + result_df = all_rows.sort_values(by="t-value", ascending=False).reset_index() return result_df