Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly call statsmodels.fit_regularized() for regularization #8

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ Earlier packages using Shapley values for feature selection exist, the advantage
```python
from shap_select import shap_select
# Here model is any model supported by the shap library, fitted on a different (train) dataset
# Task can be regression, binary, or multiclass
selected_features_df = shap_select(model, X_val, y_val, task="multiclass", threshold=0.05)
```
Empty file added docs/bug.py
Empty file.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ scikit_learn
scipy
shap
statsmodels
numpy
42 changes: 6 additions & 36 deletions shap_select/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,33 +66,11 @@ def binary_classifier_significance(
"""

# Add a constant to the features for the intercept in logistic regression

# Standardizing the features (Logistic regression with L1 regularization tends to
# work better with standardized data)
shap_features_scaled = pd.DataFrame(
data=StandardScaler().fit_transform(shap_features),
columns=shap_features.columns,
)
shap_features_with_const = sm.add_constant(shap_features_scaled)

# To avoid linear dependence of features, first do a pass with tiny L1-reg
# and throw away the zero coeffs
# Define the Logistic Regression model with L1 regularization
logistic_l1 = LogisticRegression(
bkoseoglu marked this conversation as resolved.
Show resolved Hide resolved
penalty="l1", solver="liblinear", fit_intercept=False, C=1e6
) # C is the inverse of regularization strength
logistic_l1.fit(shap_features_with_const, target)

# Get the coefficients from the Logistic Regression model
# Logistic regression gives an array of shape (1, n_features), so we take [0]
coefficients = logistic_l1.coef_[0]
shap_features_filtered = sm.add_constant(shap_features).loc[
:, np.abs(coefficients) > 1e-6
]
shap_features_with_constant = sm.add_constant(shap_features)

# Fit the logistic regression model that will generate confidence intervals
logit_model = sm.Logit(target, shap_features_filtered)
result = logit_model.fit(disp=False)
logit_model = sm.Logit(target, shap_features_with_constant)
result = logit_model.fit_regularized(disp=False, alpha=1e-6)

# Extract the results
summary_frame = result.summary2().tables[1]
Expand Down Expand Up @@ -178,17 +156,9 @@ def regression_significance(
- stderr: The standard error for each coefficient.
- stat.significance: The p-value (statistical significance) for each feature.
"""

# To avoid collinearity of features, first do a pass with tiny L1-reg
# and throw away the zero coeffs
shap_features_scaled = StandardScaler().fit_transform(shap_features)
coefficients = Lasso(alpha=1e-6).fit(shap_features_scaled, target).coef_
shap_features_filtered = shap_features.loc[:, np.abs(coefficients) > 1e-6]

# Sadly regularized models tend to not produce confidence intervals, so
# Fit the linear regression model that will generate confidence intervals
ols_model = sm.OLS(target, shap_features_filtered)
result = ols_model.fit()
ols_model = sm.OLS(target, shap_features)
result = ols_model.fit_regularized(alpha=1e-6, refit=True)

# Extract the results
summary_frame = result.summary2().tables[1]
Expand Down Expand Up @@ -268,7 +238,7 @@ def shap_select(
- validation_df (pd.DataFrame): Validation dataset containing the features.
- feature_names (List[str]): A list of feature names used by the model.
- target (pd.Series | str): The target values, or the name of the target column in `validation_df`.
- task (str | None): The task type ('regression', 'binary', or 'multi'). If None, it is inferred automatically.
- task (str | None): The task type ('regression', 'binary', or 'multiclass'). If None, it is inferred automatically.
- threshold (float): Significance threshold to select features. Default is 0.05.
- return_extended_data (bool): Whether to also return the shapley values dataframe(s) and some extra columns

Expand Down
Loading