Skip to content

Commit

Permalink
Merge pull request #10 from transferwise/parameterize_alpha
Browse files Browse the repository at this point in the history
refactor alpha parameter for regression
  • Loading branch information
EgorKraevTransferwise authored Sep 30, 2024
2 parents c0673fa + eee29f3 commit 29e0847
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions shap_select/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def create_shap_features(


def binary_classifier_significance(
shap_features: pd.DataFrame, target: pd.Series
shap_features: pd.DataFrame, target: pd.Series, alpha: float
) -> pd.DataFrame:
"""
Fits a logistic regression model using the features from `shap_features` to predict the binary `target`.
Expand All @@ -70,7 +70,7 @@ def binary_classifier_significance(

# Fit the logistic regression model that will generate confidence intervals
logit_model = sm.Logit(target, shap_features_with_constant)
result = logit_model.fit_regularized(disp=False, alpha=1e-6)
result = logit_model.fit_regularized(disp=False, alpha=alpha)

# Extract the results
summary_frame = result.summary2().tables[1]
Expand All @@ -92,6 +92,7 @@ def binary_classifier_significance(
def multi_classifier_significance(
shap_features: Dict[Any, pd.DataFrame],
target: pd.Series,
alpha: float,
return_individual_significances: bool = False,
) -> (pd.DataFrame, list):
"""
Expand All @@ -111,7 +112,7 @@ def multi_classifier_significance(
# Iterate through each class and perform binary classification (one-vs-all)
for cls, feature_df in shap_features.items():
binary_target = (target == cls).astype(int)
significance_df = binary_classifier_significance(feature_df, binary_target)
significance_df = binary_classifier_significance(feature_df, binary_target, alpha)
significance_dfs.append(significance_df)

# Combine results into a single DataFrame with the max significance value for each feature
Expand Down Expand Up @@ -139,7 +140,7 @@ def multi_classifier_significance(


def regression_significance(
shap_features: pd.DataFrame, target: pd.Series
shap_features: pd.DataFrame, target: pd.Series, alpha: float
) -> pd.DataFrame:
"""
Fits a linear regression model using the features from `shap_features` to predict the continuous `target`.
Expand All @@ -158,7 +159,7 @@ def regression_significance(
"""
# Fit the linear regression model that will generate confidence intervals
ols_model = sm.OLS(target, shap_features)
result = ols_model.fit_regularized(alpha=1e-6, refit=True)
result = ols_model.fit_regularized(alpha=alpha, refit=True)

# Extract the results
summary_frame = result.summary2().tables[1]
Expand Down Expand Up @@ -186,6 +187,7 @@ def shap_features_to_significance(
shap_features: pd.DataFrame | List[pd.DataFrame],
target: pd.Series,
task: str,
alpha: float,
) -> pd.DataFrame:
"""
Determines the task (regression, binary, or multi-class classification) based on the target and calls the appropriate
Expand All @@ -205,11 +207,11 @@ def shap_features_to_significance(

# Call the appropriate function based on the task
if task == "regression":
result_df = regression_significance(shap_features, target)
result_df = regression_significance(shap_features, target, alpha)
elif task == "binary":
result_df = binary_classifier_significance(shap_features, target)
result_df = binary_classifier_significance(shap_features, target, alpha)
elif task == "multiclass":
result_df = multi_classifier_significance(shap_features, target)
result_df = multi_classifier_significance(shap_features, target, alpha)
else:
raise ValueError("`task` must be 'regression', 'binary', 'multiclass' or None.")

Expand All @@ -229,6 +231,7 @@ def shap_select(
task: str | None = None,
threshold: float = 0.05,
return_extended_data: bool = False,
alpha: float = 1e-6,
) -> pd.DataFrame | Tuple[pd.DataFrame, pd.DataFrame]:
"""
Select features based on their SHAP values and statistical significance.
Expand All @@ -241,6 +244,7 @@ def shap_select(
- task (str | None): The task type ('regression', 'binary', or 'multiclass'). If None, it is inferred automatically.
- threshold (float): Significance threshold to select features. Default is 0.05.
- return_extended_data (bool): Whether to also return the shapley values dataframe(s) and some extra columns
- alpha (float): Controls the regularization strength for the regression
Returns:
- pd.DataFrame: A DataFrame containing the feature names, statistical significance, and a 'Selected' column
Expand Down Expand Up @@ -271,7 +275,7 @@ def shap_select(
shap_features = create_shap_features(tree_model, validation_df[feature_names])

# Compute statistical significance of each feature
significance_df = shap_features_to_significance(shap_features, target, task)
significance_df = shap_features_to_significance(shap_features, target, task, alpha)

# Add 'Selected' column based on the threshold
significance_df["selected"] = (
Expand Down

0 comments on commit 29e0847

Please sign in to comment.