diff --git a/LICENSE b/LICENSE index b6d1bf0..550d11d 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [2024] [Wise PLC] + Copyright 2024 Wise PLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/docs/bug.py b/docs/bug.py deleted file mode 100644 index e69de29..0000000 diff --git a/shap_select/select.py b/shap_select/select.py index 715723a..2c2229c 100644 --- a/shap_select/select.py +++ b/shap_select/select.py @@ -223,6 +223,46 @@ def shap_features_to_significance( return result_df_sorted +def iterative_shap_feature_reduction( + shap_features: pd.DataFrame | List[pd.DataFrame], + target: pd.Series, + task: str, + alpha: float=1e-6, +) -> pd.DataFrame: + collected_rows = [] # List to store the rows we collect during each iteration + + features_left = True + while features_left: + # Call the original shap_features_to_significance function + significance_df = shap_features_to_significance(shap_features, target, task, alpha) + + # Find the feature with the lowest t-value + min_t_value_row = significance_df.loc[significance_df["t-value"].idxmin()] + + # Remember this row (collect it in our list) + collected_rows.append(min_t_value_row) + + # Drop the feature corresponding to the lowest t-value from shap_features + feature_to_remove = min_t_value_row["feature name"] + if isinstance(shap_features, pd.DataFrame): + shap_features = shap_features.drop(columns=[feature_to_remove]) + features_left = len(shap_features.columns) + else: + shap_features = { + k: v.drop(columns=[feature_to_remove]) for k, v in shap_features.items() + } + features_left = len(list(shap_features.values())[0].columns) + + # Convert collected rows back to a dataframe + result_df = ( + pd.DataFrame(collected_rows) + .sort_values(by="t-value", ascending=False) + .reset_index() + ) + + return result_df + + def shap_select( tree_model: Any, validation_df: pd.DataFrame, @@ -274,8 +314,8 @@ def shap_select( else: shap_features = create_shap_features(tree_model, validation_df[feature_names]) - # Compute statistical significance of each feature - significance_df = shap_features_to_significance(shap_features, target, task, alpha) + # Compute statistical significance of each feature, recursively ablating + significance_df = iterative_shap_feature_reduction(shap_features, target, task, alpha) # Add 'Selected' column based on the threshold significance_df["selected"] = (