
Commit

Merge pull request #8 from transferwise/regularization_sm
Properly call statsmodels.fit_regularized() for regularization
EgorKraevTransferwise authored Sep 27, 2024
2 parents b2b7af5 + 81e5e0a commit c0673fa
Showing 4 changed files with 8 additions and 36 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -14,5 +14,6 @@ Earlier packages using Shapley values for feature selection exist, the advantage
 ```python
 from shap_select import shap_select
 # Here model is any model supported by the shap library, fitted on a different (train) dataset
+# Task can be regression, binary, or multiclass
 selected_features_df = shap_select(model, X_val, y_val, task="multiclass", threshold=0.05)
 ```
Empty file added docs/bug.py
1 change: 1 addition & 0 deletions requirements.txt
@@ -3,3 +3,4 @@ scikit_learn
 scipy
 shap
 statsmodels
+numpy
42 changes: 6 additions & 36 deletions shap_select/select.py
@@ -66,33 +66,11 @@ def binary_classifier_significance(
     """

     # Add a constant to the features for the intercept in logistic regression
-
-    # Standardizing the features (Logistic regression with L1 regularization tends to
-    # work better with standardized data)
-    shap_features_scaled = pd.DataFrame(
-        data=StandardScaler().fit_transform(shap_features),
-        columns=shap_features.columns,
-    )
-    shap_features_with_const = sm.add_constant(shap_features_scaled)
-
-    # To avoid linear dependence of features, first do a pass with tiny L1-reg
-    # and throw away the zero coeffs
-    # Define the Logistic Regression model with L1 regularization
-    logistic_l1 = LogisticRegression(
-        penalty="l1", solver="liblinear", fit_intercept=False, C=1e6
-    )  # C is the inverse of regularization strength
-    logistic_l1.fit(shap_features_with_const, target)
-
-    # Get the coefficients from the Logistic Regression model
-    # Logistic regression gives an array of shape (1, n_features), so we take [0]
-    coefficients = logistic_l1.coef_[0]
-    shap_features_filtered = sm.add_constant(shap_features).loc[
-        :, np.abs(coefficients) > 1e-6
-    ]
+    shap_features_with_constant = sm.add_constant(shap_features)

     # Fit the logistic regression model that will generate confidence intervals
-    logit_model = sm.Logit(target, shap_features_filtered)
-    result = logit_model.fit(disp=False)
+    logit_model = sm.Logit(target, shap_features_with_constant)
+    result = logit_model.fit_regularized(disp=False, alpha=1e-6)

     # Extract the results
     summary_frame = result.summary2().tables[1]
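
For context, here is a minimal standalone sketch of the new call pattern, using synthetic data and illustrative names (`shap_vals`, `f1`–`f3` are not from the package): statsmodels' `Logit.fit_regularized` applies an L1 penalty during the fit itself, which is why the scikit-learn pre-screening pass above could be dropped.

```python
import numpy as np
import pandas as pd
import statsmodels.api as sm

# Synthetic stand-in for a matrix of SHAP values (illustrative only)
rng = np.random.default_rng(0)
shap_vals = pd.DataFrame(rng.normal(size=(500, 3)), columns=["f1", "f2", "f3"])
target = (shap_vals["f1"] + 0.5 * shap_vals["f2"] + rng.normal(size=500) > 0).astype(int)

# L1-penalized logistic fit; a tiny alpha keeps the penalty negligible while
# still taming linearly dependent columns that would break an unpenalized fit
result = sm.Logit(target, sm.add_constant(shap_vals)).fit_regularized(
    disp=False, alpha=1e-6
)
print(result.summary2().tables[1])  # coef, std err, z, P>|z| per feature
```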
@@ -178,17 +156,9 @@ def regression_significance(
     - stderr: The standard error for each coefficient.
     - stat.significance: The p-value (statistical significance) for each feature.
     """
-
-    # To avoid collinearity of features, first do a pass with tiny L1-reg
-    # and throw away the zero coeffs
-    shap_features_scaled = StandardScaler().fit_transform(shap_features)
-    coefficients = Lasso(alpha=1e-6).fit(shap_features_scaled, target).coef_
-    shap_features_filtered = shap_features.loc[:, np.abs(coefficients) > 1e-6]
-
-    # Sadly regularized models tend to not produce confidence intervals, so
     # Fit the linear regression model that will generate confidence intervals
-    ols_model = sm.OLS(target, shap_features_filtered)
-    result = ols_model.fit()
+    ols_model = sm.OLS(target, shap_features)
+    result = ols_model.fit_regularized(alpha=1e-6, refit=True)

     # Extract the results
     summary_frame = result.summary2().tables[1]
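
The regression path follows the same idea; below is a hedged sketch under the same assumptions (synthetic data, illustrative names). `OLS.fit_regularized` with a tiny penalty handles collinear columns, and `refit=True` re-estimates the surviving coefficients with an unpenalized OLS, so that `summary2()` can still report standard errors and p-values.

```python
import numpy as np
import pandas as pd
import statsmodels.api as sm

# Synthetic stand-in for a matrix of SHAP values (illustrative only)
rng = np.random.default_rng(0)
shap_vals = pd.DataFrame(rng.normal(size=(500, 3)), columns=["f1", "f2", "f3"])
target = 2.0 * shap_vals["f1"] - shap_vals["f2"] + rng.normal(scale=0.1, size=500)

# Tiny penalty to drop exactly-collinear columns; refit=True then runs an
# unpenalized OLS on the remaining features, restoring confidence intervals
result = sm.OLS(target, shap_vals).fit_regularized(alpha=1e-6, refit=True)
print(result.summary2().tables[1])
```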
@@ -268,7 +238,7 @@ def shap_select(
     - validation_df (pd.DataFrame): Validation dataset containing the features.
     - feature_names (List[str]): A list of feature names used by the model.
     - target (pd.Series | str): The target values, or the name of the target column in `validation_df`.
-    - task (str | None): The task type ('regression', 'binary', or 'multi'). If None, it is inferred automatically.
+    - task (str | None): The task type ('regression', 'binary', or 'multiclass'). If None, it is inferred automatically.
     - threshold (float): Significance threshold to select features. Default is 0.05.
     - return_extended_data (bool): Whether to also return the shapley values dataframe(s) and some extra columns
