Using the new Shap explanation api #94

Open: wants to merge 5 commits into base: develop

@pemujo (Collaborator) commented Mar 6, 2025

Updated the pdp_model_inference task to use the new SHAP Explanation API and to plot the SHAP values with the beeswarm function.

These changes were needed because some methods used with the old API are being deprecated, and the SHAP documentation points to the new API.

Changes

1. Removed the custom Spark code that computed SHAP values:

shap_schema = StructType(
    [StructField(cfg.student_id_col, StringType(), nullable=False)]
    + [StructField(col, FloatType(), nullable=False) for col in model_feature_names]
)

df_shap_values = (
    spark_session.createDataFrame(
        df_processed_dataset.reindex(
            columns=model_feature_names + [cfg.student_id_col]
        )
    )
    .repartition(spark_session.sparkContext.defaultParallelism)
    .mapInPandas(
        ft.partial(
            inference.calculate_shap_values_spark_udf,
            student_id_col=cfg.student_id_col,
            model_features=model_feature_names,
            explainer=explainer,
            mode=train_mode,
        ),
        schema=shap_schema,
    )
    .toPandas()
    .set_index(cfg.student_id_col)
    .reindex(df_processed_dataset[cfg.student_id_col])
    .reset_index(drop=False)
)

Replaced with a direct call through the new Explanation API:

    # Calculate Shap values using the new Shap Explanation API
    df_shap_values = explainer(df_processed_dataset[model_feature_names])

2. Removed the shap.summary_plot call:

shap.summary_plot(
    df_shap_values.loc[:, model_feature_names].to_numpy(),
    df_serving_dataset.loc[:, model_feature_names],
    class_names=loaded_model.classes_,
    max_display=20,
    show=False,
)

Replaced with the beeswarm utility function:

from student_success_tool.pipeline_utils.plot import plot_shap_beeswarm

# Plot SHAP values using the beeswarm plot function
shap_fig = plot_shap_beeswarm(df_shap_values)

3. Removed the unused dependencies.
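One difference worth noting between the two versions in change 1: the removed Spark code returned a pandas DataFrame keyed by student ID, while calling the explainer returns an Explanation object whose SHAP values sit in its `values` attribute. If downstream code still needs the tabular form, a conversion along these lines might work. This is only a sketch: `explanation_to_frame`, its parameters, and the `"student_id"` default are hypothetical names, not part of this PR, and the stand-in object below only mimics the `values` attribute so the snippet runs without a trained model.

```python
from types import SimpleNamespace

import numpy as np
import pandas as pd


def explanation_to_frame(
    explanation, student_ids, feature_names, class_index=None, id_col="student_id"
):
    """Build a DataFrame of SHAP values indexed by student ID (hypothetical helper).

    `explanation` only needs a `.values` attribute, as shap.Explanation has.
    """
    values = np.asarray(explanation.values)
    if values.ndim == 3:
        # Multi-output models give (rows, features, outputs); pick one output.
        if class_index is None:
            raise ValueError("class_index is required for multi-output SHAP values")
        values = values[:, :, class_index]
    expected = (len(student_ids), len(feature_names))
    if values.shape != expected:
        raise ValueError(f"expected shape {expected}, got {values.shape}")
    return pd.DataFrame(
        values,
        index=pd.Index(list(student_ids), name=id_col),
        columns=list(feature_names),
    )


# Stand-in for the object returned by explainer(...); illustrative values only.
fake_explanation = SimpleNamespace(
    values=np.array([[0.1, -0.2, 0.3], [0.0, 0.5, -0.4]])
)
frame = explanation_to_frame(
    fake_explanation,
    student_ids=["s1", "s2"],
    feature_names=["gpa", "credits", "age"],
)
```

In the pipeline, the arguments would presumably be the object returned by the explainer, `df_processed_dataset[cfg.student_id_col]`, and `model_feature_names`. The row order matches the input because the explainer is called on the rows in order.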

Context

Per the SHAP documentation:
"The new-style plotting functions like shap.plots.bar and shap.plots.beeswarm accept these Explanation objects rather than numpy arrays."

Questions/Inputs

With synthetic datasets, performance was similar to the custom Spark function, but it should still be tested with larger datasets.
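One possible way to run that check is to time the explanation call at increasing row counts and see how it scales. The harness below is a rough sketch, not code from this PR: `time_by_row_count` and its parameters are hypothetical names, and the lambda passed as `explain_fn` is a cheap stand-in for the real explainer call.

```python
import time

import numpy as np


def time_by_row_count(explain_fn, make_rows, row_counts, repeats=3):
    """Return the best wall-clock time (seconds) of explain_fn per row count."""
    results = {}
    for n in row_counts:
        rows = make_rows(n)
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            explain_fn(rows)
            best = min(best, time.perf_counter() - start)
        results[n] = best
    return results


rng = np.random.default_rng(0)
timings = time_by_row_count(
    # Stand-in workload; in the pipeline this would wrap the explainer call.
    explain_fn=lambda rows: rows @ np.ones(rows.shape[1]),
    make_rows=lambda n: rng.normal(size=(n, 8)),
    row_counts=[1_000, 10_000],
)
for n, seconds in timings.items():
    print(f"{n} rows: {seconds:.6f} s")
```

Taking the best of several repeats reduces noise from one-off effects such as caching. Running the same harness against the removed Spark path would give a like-for-like comparison.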
