Skip to content

Commit

Permalink
tidy up requirements, add setup.py, prettify README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorKraevTransferwise committed Oct 3, 2024
1 parent 925a196 commit 81ef84b
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 13 deletions.
179 changes: 175 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,190 @@
## Overview
`shap-select` implements a heuristic to do fast feature selection for tabular regression and classification models.
`shap-select` implements a heuristic for fast feature selection, for tabular regression and classification models.

The basic idea is running a linear or logistic regression of the target on the Shapley values on the validation set,
The basic idea is running a linear or logistic regression of the target on the Shapley values of
the original features, on the validation set,
discarding the features with negative coefficients, and ranking/filtering the rest according to their
statistical significance. For motivation and details, see the [example notebook](https://github.com/transferwise/shap-select/blob/main/docs/Quick%20feature%20selection%20through%20regression%20on%20Shapley%20values.ipynb)

Earlier packages using Shapley values for feature selection exist, the advantages of this one are
* Regression on the **validation set** to combat overfitting
* A single pass regression, not an iterative approach
* Only a single fit of the original model needed
* A single intuitive hyperparameter for feature selection: statistical significance
* Bonferroni correction for multiclass classification
* Address collinearity of (Shapley value) features by repeated (linear/logistic) regression

## Usage
```python
from shap_select import shap_select
# Here model is any model supported by the shap library, fitted on a different (train) dataset
# Task can be regression, binary, or multiclass
selected_features_df = shap_select(model, X_val, y_val, task="multiclass", threshold=0.05)
```
```


<style type="text/css">
#T_694ab_row0_col1, #T_694ab_row0_col4, #T_694ab_row1_col4, #T_694ab_row2_col4, #T_694ab_row3_col4, #T_694ab_row4_col4, #T_694ab_row5_col4, #T_694ab_row6_col3, #T_694ab_row7_col2 {
background-color: #b40426;
color: #f1f1f1;
}
#T_694ab_row0_col2, #T_694ab_row1_col2, #T_694ab_row2_col2, #T_694ab_row3_col2, #T_694ab_row4_col2, #T_694ab_row8_col1, #T_694ab_row8_col3, #T_694ab_row8_col4 {
background-color: #3b4cc0;
color: #f1f1f1;
}
#T_694ab_row0_col3, #T_694ab_row3_col3 {
background-color: #f39778;
color: #000000;
}
#T_694ab_row1_col1 {
background-color: #d24b40;
color: #f1f1f1;
}
#T_694ab_row1_col3 {
background-color: #f59d7e;
color: #000000;
}
#T_694ab_row2_col1 {
background-color: #bcd2f7;
color: #000000;
}
#T_694ab_row2_col3 {
background-color: #f39577;
color: #000000;
}
#T_694ab_row3_col1 {
background-color: #b6cefa;
color: #000000;
}
#T_694ab_row4_col1 {
background-color: #a7c5fe;
color: #000000;
}
#T_694ab_row4_col3 {
background-color: #f59f80;
color: #000000;
}
#T_694ab_row5_col1 {
background-color: #7597f6;
color: #f1f1f1;
}
#T_694ab_row5_col2 {
background-color: #4358cb;
color: #f1f1f1;
}
#T_694ab_row5_col3 {
background-color: #eb7d62;
color: #f1f1f1;
}
#T_694ab_row6_col1 {
background-color: #5e7de7;
color: #f1f1f1;
}
#T_694ab_row6_col2 {
background-color: #f6bfa6;
color: #000000;
}
#T_694ab_row6_col4, #T_694ab_row7_col4 {
background-color: #dddcdc;
color: #000000;
}
#T_694ab_row7_col1 {
background-color: #5977e3;
color: #f1f1f1;
}
#T_694ab_row7_col3 {
background-color: #de614d;
color: #f1f1f1;
}
#T_694ab_row8_col2 {
background-color: #779af7;
color: #f1f1f1;
}
</style>
<table id="T_694ab">
<thead>
<tr>
<th class="blank level0" >&nbsp;</th>
<th id="T_694ab_level0_col0" class="col_heading level0 col0" >feature name</th>
<th id="T_694ab_level0_col1" class="col_heading level0 col1" >t-value</th>
<th id="T_694ab_level0_col2" class="col_heading level0 col2" >stat.significance</th>
<th id="T_694ab_level0_col3" class="col_heading level0 col3" >coefficient</th>
<th id="T_694ab_level0_col4" class="col_heading level0 col4" >selected</th>
</tr>
</thead>
<tbody>
<tr>
<th id="T_694ab_level0_row0" class="row_heading level0 row0" >0</th>
<td id="T_694ab_row0_col0" class="data row0 col0" >x5</td>
<td id="T_694ab_row0_col1" class="data row0 col1" >20.211299</td>
<td id="T_694ab_row0_col2" class="data row0 col2" >0.000000</td>
<td id="T_694ab_row0_col3" class="data row0 col3" >1.052030</td>
<td id="T_694ab_row0_col4" class="data row0 col4" >1</td>
</tr>
<tr>
<th id="T_694ab_level0_row1" class="row_heading level0 row1" >1</th>
<td id="T_694ab_row1_col0" class="data row1 col0" >x4</td>
<td id="T_694ab_row1_col1" class="data row1 col1" >18.315144</td>
<td id="T_694ab_row1_col2" class="data row1 col2" >0.000000</td>
<td id="T_694ab_row1_col3" class="data row1 col3" >0.952416</td>
<td id="T_694ab_row1_col4" class="data row1 col4" >1</td>
</tr>
<tr>
<th id="T_694ab_level0_row2" class="row_heading level0 row2" >2</th>
<td id="T_694ab_row2_col0" class="data row2 col0" >x3</td>
<td id="T_694ab_row2_col1" class="data row2 col1" >6.835690</td>
<td id="T_694ab_row2_col2" class="data row2 col2" >0.000000</td>
<td id="T_694ab_row2_col3" class="data row2 col3" >1.098154</td>
<td id="T_694ab_row2_col4" class="data row2 col4" >1</td>
</tr>
<tr>
<th id="T_694ab_level0_row3" class="row_heading level0 row3" >3</th>
<td id="T_694ab_row3_col0" class="data row3 col0" >x2</td>
<td id="T_694ab_row3_col1" class="data row3 col1" >6.457140</td>
<td id="T_694ab_row3_col2" class="data row3 col2" >0.000000</td>
<td id="T_694ab_row3_col3" class="data row3 col3" >1.044842</td>
<td id="T_694ab_row3_col4" class="data row3 col4" >1</td>
</tr>
<tr>
<th id="T_694ab_level0_row4" class="row_heading level0 row4" >4</th>
<td id="T_694ab_row4_col0" class="data row4 col0" >x1</td>
<td id="T_694ab_row4_col1" class="data row4 col1" >5.530556</td>
<td id="T_694ab_row4_col2" class="data row4 col2" >0.000000</td>
<td id="T_694ab_row4_col3" class="data row4 col3" >0.917242</td>
<td id="T_694ab_row4_col4" class="data row4 col4" >1</td>
</tr>
<tr>
<th id="T_694ab_level0_row5" class="row_heading level0 row5" >5</th>
<td id="T_694ab_row5_col0" class="data row5 col0" >x6</td>
<td id="T_694ab_row5_col1" class="data row5 col1" >2.390868</td>
<td id="T_694ab_row5_col2" class="data row5 col2" >0.016827</td>
<td id="T_694ab_row5_col3" class="data row5 col3" >1.497983</td>
<td id="T_694ab_row5_col4" class="data row5 col4" >1</td>
</tr>
<tr>
<th id="T_694ab_level0_row6" class="row_heading level0 row6" >6</th>
<td id="T_694ab_row6_col0" class="data row6 col0" >x7</td>
<td id="T_694ab_row6_col1" class="data row6 col1" >0.901098</td>
<td id="T_694ab_row6_col2" class="data row6 col2" >0.367558</td>
<td id="T_694ab_row6_col3" class="data row6 col3" >2.865508</td>
<td id="T_694ab_row6_col4" class="data row6 col4" >0</td>
</tr>
<tr>
<th id="T_694ab_level0_row7" class="row_heading level0 row7" >7</th>
<td id="T_694ab_row7_col0" class="data row7 col0" >x8</td>
<td id="T_694ab_row7_col1" class="data row7 col1" >0.563214</td>
<td id="T_694ab_row7_col2" class="data row7 col2" >0.573302</td>
<td id="T_694ab_row7_col3" class="data row7 col3" >1.933632</td>
<td id="T_694ab_row7_col4" class="data row7 col4" >0</td>
</tr>
<tr>
<th id="T_694ab_level0_row8" class="row_heading level0 row8" >8</th>
<td id="T_694ab_row8_col0" class="data row8 col0" >x9</td>
<td id="T_694ab_row8_col1" class="data row8 col1" >-1.607814</td>
<td id="T_694ab_row8_col2" class="data row8 col2" >0.107908</td>
<td id="T_694ab_row8_col3" class="data row8 col3" >-4.537098</td>
<td id="T_694ab_row8_col4" class="data row8 col4" >-1</td>
</tr>
</tbody>
</table>


2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
pandas
scikit_learn
scipy
shap
statsmodels
numpy
38 changes: 38 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from setuptools import find_packages, setup

with open("README.md") as f:
long_description = f.read()

setup(
name="shap-select",
version="0.1.0",
description="Heuristic for quick feature selection for tabular regression/classification using shapley values",
long_description=long_description,
long_description_content_type="text/markdown",
author="Wise Plc",
url="https://github.com/transferwise/shap-select",
classifiers=[
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
install_requires=[
"pandas",
"scipy>=1.8.0",
"shap",
"statsmodels",
],
extras_require={
"test": ["flake8", "pytest", "pytest-cov"],
},
packages=find_packages(
include=["shap_select", "shap_select.*"],
exclude=["tests*"],
),
include_package_data=True,
keywords="shap-select",
)
17 changes: 10 additions & 7 deletions shap_select/select.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from typing import Any, Tuple, List, Dict

import pandas as pd
import numpy as np
import statsmodels.api as sm
from sklearn.linear_model import Lasso, LogisticRegression
from sklearn.preprocessing import StandardScaler
import scipy.stats as stats
import shap

Expand Down Expand Up @@ -112,7 +109,9 @@ def multi_classifier_significance(
# Iterate through each class and perform binary classification (one-vs-all)
for cls, feature_df in shap_features.items():
binary_target = (target == cls).astype(int)
significance_df = binary_classifier_significance(feature_df, binary_target, alpha)
significance_df = binary_classifier_significance(
feature_df, binary_target, alpha
)
significance_dfs.append(significance_df)

# Combine results into a single DataFrame with the max significance value for each feature
Expand Down Expand Up @@ -227,14 +226,16 @@ def iterative_shap_feature_reduction(
shap_features: pd.DataFrame | List[pd.DataFrame],
target: pd.Series,
task: str,
alpha: float=1e-6,
alpha: float = 1e-6,
) -> pd.DataFrame:
collected_rows = [] # List to store the rows we collect during each iteration

features_left = True
while features_left:
# Call the original shap_features_to_significance function
significance_df = shap_features_to_significance(shap_features, target, task, alpha)
significance_df = shap_features_to_significance(
shap_features, target, task, alpha
)

# Find the feature with the lowest t-value
min_t_value_row = significance_df.loc[significance_df["t-value"].idxmin()]
Expand Down Expand Up @@ -315,7 +316,9 @@ def shap_select(
shap_features = create_shap_features(tree_model, validation_df[feature_names])

# Compute statistical significance of each feature, recursively ablating
significance_df = iterative_shap_feature_reduction(shap_features, target, task, alpha)
significance_df = iterative_shap_feature_reduction(
shap_features, target, task, alpha
)

# Add 'Selected' column based on the threshold
significance_df["selected"] = (
Expand Down

0 comments on commit 81ef84b

Please sign in to comment.