-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tidy up requirements, add setup.py, prettify README.md
- Loading branch information
1 parent
925a196
commit 81ef84b
Showing
4 changed files
with
223 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,190 @@ | ||
## Overview | ||
`shap-select` implements a heuristic to do fast feature selection for tabular regression and classification models. | ||
`shap-select` implements a heuristic for fast feature selection, for tabular regression and classification models. | ||
|
||
The basic idea is running a linear or logistic regression of the target on the Shapley values on the validation set, | ||
The basic idea is running a linear or logistic regression of the target on the Shapley values of | ||
the original features, on the validation set, | ||
discarding the features with negative coefficients, and ranking/filtering the rest according to their | ||
statistical significance. For motivation and details, see the [example notebook](https://github.com/transferwise/shap-select/blob/main/docs/Quick%20feature%20selection%20through%20regression%20on%20Shapley%20values.ipynb) | ||
|
||
Earlier packages using Shapley values for feature selection exist, the advantages of this one are | ||
* Regression on the **validation set** to combat overfitting | ||
* A single pass regression, not an iterative approach | ||
* Only a single fit of the original model needed | ||
* A single intuitive hyperparameter for feature selection: statistical significance | ||
* Bonferroni correction for multiclass classification | ||
* Address collinearity of (Shapley value) features by repeated (linear/logistic) regression | ||
|
||
## Usage | ||
```python | ||
from shap_select import shap_select | ||
# Here model is any model supported by the shap library, fitted on a different (train) dataset | ||
# Task can be regression, binary, or multiclass | ||
selected_features_df = shap_select(model, X_val, y_val, task="multiclass", threshold=0.05) | ||
``` | ||
``` | ||
|
||
|
||
<style type="text/css"> | ||
#T_694ab_row0_col1, #T_694ab_row0_col4, #T_694ab_row1_col4, #T_694ab_row2_col4, #T_694ab_row3_col4, #T_694ab_row4_col4, #T_694ab_row5_col4, #T_694ab_row6_col3, #T_694ab_row7_col2 { | ||
background-color: #b40426; | ||
color: #f1f1f1; | ||
} | ||
#T_694ab_row0_col2, #T_694ab_row1_col2, #T_694ab_row2_col2, #T_694ab_row3_col2, #T_694ab_row4_col2, #T_694ab_row8_col1, #T_694ab_row8_col3, #T_694ab_row8_col4 { | ||
background-color: #3b4cc0; | ||
color: #f1f1f1; | ||
} | ||
#T_694ab_row0_col3, #T_694ab_row3_col3 { | ||
background-color: #f39778; | ||
color: #000000; | ||
} | ||
#T_694ab_row1_col1 { | ||
background-color: #d24b40; | ||
color: #f1f1f1; | ||
} | ||
#T_694ab_row1_col3 { | ||
background-color: #f59d7e; | ||
color: #000000; | ||
} | ||
#T_694ab_row2_col1 { | ||
background-color: #bcd2f7; | ||
color: #000000; | ||
} | ||
#T_694ab_row2_col3 { | ||
background-color: #f39577; | ||
color: #000000; | ||
} | ||
#T_694ab_row3_col1 { | ||
background-color: #b6cefa; | ||
color: #000000; | ||
} | ||
#T_694ab_row4_col1 { | ||
background-color: #a7c5fe; | ||
color: #000000; | ||
} | ||
#T_694ab_row4_col3 { | ||
background-color: #f59f80; | ||
color: #000000; | ||
} | ||
#T_694ab_row5_col1 { | ||
background-color: #7597f6; | ||
color: #f1f1f1; | ||
} | ||
#T_694ab_row5_col2 { | ||
background-color: #4358cb; | ||
color: #f1f1f1; | ||
} | ||
#T_694ab_row5_col3 { | ||
background-color: #eb7d62; | ||
color: #f1f1f1; | ||
} | ||
#T_694ab_row6_col1 { | ||
background-color: #5e7de7; | ||
color: #f1f1f1; | ||
} | ||
#T_694ab_row6_col2 { | ||
background-color: #f6bfa6; | ||
color: #000000; | ||
} | ||
#T_694ab_row6_col4, #T_694ab_row7_col4 { | ||
background-color: #dddcdc; | ||
color: #000000; | ||
} | ||
#T_694ab_row7_col1 { | ||
background-color: #5977e3; | ||
color: #f1f1f1; | ||
} | ||
#T_694ab_row7_col3 { | ||
background-color: #de614d; | ||
color: #f1f1f1; | ||
} | ||
#T_694ab_row8_col2 { | ||
background-color: #779af7; | ||
color: #f1f1f1; | ||
} | ||
</style> | ||
<table id="T_694ab"> | ||
<thead> | ||
<tr> | ||
<th class="blank level0" > </th> | ||
<th id="T_694ab_level0_col0" class="col_heading level0 col0" >feature name</th> | ||
<th id="T_694ab_level0_col1" class="col_heading level0 col1" >t-value</th> | ||
<th id="T_694ab_level0_col2" class="col_heading level0 col2" >stat.significance</th> | ||
<th id="T_694ab_level0_col3" class="col_heading level0 col3" >coefficient</th> | ||
<th id="T_694ab_level0_col4" class="col_heading level0 col4" >selected</th> | ||
</tr> | ||
</thead> | ||
<tbody> | ||
<tr> | ||
<th id="T_694ab_level0_row0" class="row_heading level0 row0" >0</th> | ||
<td id="T_694ab_row0_col0" class="data row0 col0" >x5</td> | ||
<td id="T_694ab_row0_col1" class="data row0 col1" >20.211299</td> | ||
<td id="T_694ab_row0_col2" class="data row0 col2" >0.000000</td> | ||
<td id="T_694ab_row0_col3" class="data row0 col3" >1.052030</td> | ||
<td id="T_694ab_row0_col4" class="data row0 col4" >1</td> | ||
</tr> | ||
<tr> | ||
<th id="T_694ab_level0_row1" class="row_heading level0 row1" >1</th> | ||
<td id="T_694ab_row1_col0" class="data row1 col0" >x4</td> | ||
<td id="T_694ab_row1_col1" class="data row1 col1" >18.315144</td> | ||
<td id="T_694ab_row1_col2" class="data row1 col2" >0.000000</td> | ||
<td id="T_694ab_row1_col3" class="data row1 col3" >0.952416</td> | ||
<td id="T_694ab_row1_col4" class="data row1 col4" >1</td> | ||
</tr> | ||
<tr> | ||
<th id="T_694ab_level0_row2" class="row_heading level0 row2" >2</th> | ||
<td id="T_694ab_row2_col0" class="data row2 col0" >x3</td> | ||
<td id="T_694ab_row2_col1" class="data row2 col1" >6.835690</td> | ||
<td id="T_694ab_row2_col2" class="data row2 col2" >0.000000</td> | ||
<td id="T_694ab_row2_col3" class="data row2 col3" >1.098154</td> | ||
<td id="T_694ab_row2_col4" class="data row2 col4" >1</td> | ||
</tr> | ||
<tr> | ||
<th id="T_694ab_level0_row3" class="row_heading level0 row3" >3</th> | ||
<td id="T_694ab_row3_col0" class="data row3 col0" >x2</td> | ||
<td id="T_694ab_row3_col1" class="data row3 col1" >6.457140</td> | ||
<td id="T_694ab_row3_col2" class="data row3 col2" >0.000000</td> | ||
<td id="T_694ab_row3_col3" class="data row3 col3" >1.044842</td> | ||
<td id="T_694ab_row3_col4" class="data row3 col4" >1</td> | ||
</tr> | ||
<tr> | ||
<th id="T_694ab_level0_row4" class="row_heading level0 row4" >4</th> | ||
<td id="T_694ab_row4_col0" class="data row4 col0" >x1</td> | ||
<td id="T_694ab_row4_col1" class="data row4 col1" >5.530556</td> | ||
<td id="T_694ab_row4_col2" class="data row4 col2" >0.000000</td> | ||
<td id="T_694ab_row4_col3" class="data row4 col3" >0.917242</td> | ||
<td id="T_694ab_row4_col4" class="data row4 col4" >1</td> | ||
</tr> | ||
<tr> | ||
<th id="T_694ab_level0_row5" class="row_heading level0 row5" >5</th> | ||
<td id="T_694ab_row5_col0" class="data row5 col0" >x6</td> | ||
<td id="T_694ab_row5_col1" class="data row5 col1" >2.390868</td> | ||
<td id="T_694ab_row5_col2" class="data row5 col2" >0.016827</td> | ||
<td id="T_694ab_row5_col3" class="data row5 col3" >1.497983</td> | ||
<td id="T_694ab_row5_col4" class="data row5 col4" >1</td> | ||
</tr> | ||
<tr> | ||
<th id="T_694ab_level0_row6" class="row_heading level0 row6" >6</th> | ||
<td id="T_694ab_row6_col0" class="data row6 col0" >x7</td> | ||
<td id="T_694ab_row6_col1" class="data row6 col1" >0.901098</td> | ||
<td id="T_694ab_row6_col2" class="data row6 col2" >0.367558</td> | ||
<td id="T_694ab_row6_col3" class="data row6 col3" >2.865508</td> | ||
<td id="T_694ab_row6_col4" class="data row6 col4" >0</td> | ||
</tr> | ||
<tr> | ||
<th id="T_694ab_level0_row7" class="row_heading level0 row7" >7</th> | ||
<td id="T_694ab_row7_col0" class="data row7 col0" >x8</td> | ||
<td id="T_694ab_row7_col1" class="data row7 col1" >0.563214</td> | ||
<td id="T_694ab_row7_col2" class="data row7 col2" >0.573302</td> | ||
<td id="T_694ab_row7_col3" class="data row7 col3" >1.933632</td> | ||
<td id="T_694ab_row7_col4" class="data row7 col4" >0</td> | ||
</tr> | ||
<tr> | ||
<th id="T_694ab_level0_row8" class="row_heading level0 row8" >8</th> | ||
<td id="T_694ab_row8_col0" class="data row8 col0" >x9</td> | ||
<td id="T_694ab_row8_col1" class="data row8 col1" >-1.607814</td> | ||
<td id="T_694ab_row8_col2" class="data row8 col2" >0.107908</td> | ||
<td id="T_694ab_row8_col3" class="data row8 col3" >-4.537098</td> | ||
<td id="T_694ab_row8_col4" class="data row8 col4" >-1</td> | ||
</tr> | ||
</tbody> | ||
</table> | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,4 @@ | ||
pandas | ||
scikit_learn | ||
scipy | ||
shap | ||
statsmodels | ||
numpy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from setuptools import find_packages, setup | ||
|
||
with open("README.md") as f: | ||
long_description = f.read() | ||
|
||
setup( | ||
name="shap-select", | ||
version="0.1.0", | ||
description="Heuristic for quick feature selection for tabular regression/classification using shapley values", | ||
long_description=long_description, | ||
long_description_content_type="text/markdown", | ||
author="Wise Plc", | ||
url="https://github.com/transferwise/shap-select", | ||
classifiers=[ | ||
"Programming Language :: Python :: 3 :: Only", | ||
"Programming Language :: Python :: 3.7", | ||
"Programming Language :: Python :: 3.8", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12", | ||
], | ||
install_requires=[ | ||
"pandas", | ||
"scipy>=1.8.0", | ||
"shap", | ||
"statsmodels", | ||
], | ||
extras_require={ | ||
"test": ["flake8", "pytest", "pytest-cov"], | ||
}, | ||
packages=find_packages( | ||
include=["shap_select", "shap_select.*"], | ||
exclude=["tests*"], | ||
), | ||
include_package_data=True, | ||
keywords="shap-select", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters