Skip to content

Commit

Permalink
Evaluate pilot (#55)
Browse files Browse the repository at this point in the history
* Add validation analysis to notebooks

* Updated evaluation

* Fix some changes from meeting Julius
  • Loading branch information
rubenpeters91 authored Jan 11, 2024
1 parent fff117d commit 1c87d8a
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 42 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [1.0.5] - 2023-11-27

### Changed
- Updated evaluation notebooks
- Start using nbstripout for removing notebook output

## [1.0.4] - 2023-11-21

### Changed
Expand Down
53 changes: 36 additions & 17 deletions notebooks/model_evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@
"outputs": [],
"source": [
"import pickle\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import relplot as rp\n",
"from sklearn.calibration import calibration_curve\n",
"from sklearn.metrics import (\n",
" roc_auc_score,\n",
" roc_curve,\n",
" precision_recall_curve,\n",
" precision_score,\n",
" recall_score,\n",
" roc_auc_score,\n",
" roc_curve,\n",
")\n",
"from sklearn.calibration import calibration_curve\n",
"import matplotlib.pyplot as plt\n"
"from sklearn.model_selection import train_test_split"
]
},
{
Expand Down Expand Up @@ -68,7 +69,7 @@
"fpr, tpr, thresholds = roc_curve(y_test, y_pred[:, 1])\n",
"auc_score = roc_auc_score(y_test, y_pred[:, 1])\n",
"fig, ax = plt.subplots(figsize=(8, 8))\n",
"ax.plot(fpr, tpr, label=f\"Random Forest (AUC={round(auc_score, 2)})\")\n",
"ax.plot(fpr, tpr, label=f\"Hist Gradient Boosting (AUC={round(auc_score, 2)})\")\n",
"ax.plot([0, 1], [0, 1], label=\"Random (AUC=0.5)\", linestyle=\"dotted\")\n",
"ax.legend()\n",
"plt.show()"
Expand All @@ -94,7 +95,7 @@
"ax.plot(thresholds, precision[:-1], label=\"precision\")\n",
"ax.plot(thresholds, recall[:-1], label=\"recall\")\n",
"ax.legend()\n",
"plt.show()\n"
"plt.show()"
]
},
{
Expand All @@ -106,7 +107,7 @@
"X_test.resample(\"1D\", level=\"start\")[\"age\"].count().plot.hist(\n",
" title=\"Number of appointments per day\"\n",
")\n",
"plt.show()\n"
"plt.show()"
]
},
{
Expand Down Expand Up @@ -152,7 +153,7 @@
" precisions.append(prec)\n",
" precisions_random.append(prec_random)\n",
" recalls.append(rec)\n",
" recalls_random.append(rec_random)\n"
" recalls_random.append(rec_random)"
]
},
{
Expand Down Expand Up @@ -246,7 +247,26 @@
"ax.set_ylabel(\"Fraction of positives\")\n",
"ax.set_title(\"Calibration curve\")\n",
"ax.legend()\n",
"plt.show()\n"
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reliability plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Calibration error (relplot smECE) and reliability diagram for the predicted\n",
"# positive-class probabilities on X (X, y and model come from earlier cells).\n",
"y_pred_total = model.predict_proba(X)\n",
"print(\"calibration error:\", rp.smECE(y_pred_total[:, 1], y))\n",
"fig, ax = rp.rel_diagram(y_pred_total[:, 1], y)\n",
"# plt.show() matches the other plotting cells; Figure.show() emits a warning on\n",
"# the non-GUI inline backend that notebooks use by default.\n",
"plt.show()"
]
},
{
Expand Down Expand Up @@ -291,7 +311,7 @@
"metadata": {},
"outputs": [],
"source": [
"total_test_data.sort_values(\"y_pred\").head()\n"
"total_test_data.sort_values(\"y_pred\").head()"
]
},
{
Expand All @@ -311,7 +331,7 @@
"metadata": {},
"outputs": [],
"source": [
"from ipywidgets import interact, IntSlider, FloatSlider\n",
"from ipywidgets import FloatSlider, IntSlider, interact\n",
"\n",
"\n",
"@interact\n",
Expand Down Expand Up @@ -353,7 +373,7 @@
" index=[0],\n",
" )\n",
"\n",
" print(f\"Predicted value is: {model.predict_proba(prediction_df)[:,1]}\")\n"
" print(f\"Predicted value is: {model.predict_proba(prediction_df)[:,1]}\")"
]
}
],
Expand All @@ -374,8 +394,7 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"orig_nbformat": 4
}
},
"nbformat": 4,
"nbformat_minor": 2
Expand Down
166 changes: 145 additions & 21 deletions notebooks/pilot_compare_noshow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,42 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pickle\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from dotenv import load_dotenv\n",
"from IPython.display import display\n",
"from ipywidgets import interact\n",
"from sqlalchemy import create_engine\n",
"from sqlalchemy.orm import sessionmaker\n",
"\n",
"from noshow.features.feature_pipeline import create_features, select_feature_columns\n",
"from noshow.model.predict import create_prediction\n",
"from noshow.preprocessing.load_data import (\n",
" load_appointment_csv,\n",
" process_appointments,\n",
" process_postal_codes,\n",
")\n",
"from noshow.features.feature_pipeline import create_features, select_feature_columns\n",
"from noshow.model.predict import create_prediction\n",
"import matplotlib.pyplot as plt\n",
"import pickle\n",
"import pandas as pd\n",
"import seaborn as sns"
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(\"../.env\")\n",
"\n",
"\n",
"# Global and env variables\n",
"db_user = os.environ[\"DB_USER\"]\n",
"db_passwd = os.environ[\"DB_PASSWD\"]\n",
"db_host = os.environ[\"DB_HOST\"]\n",
"db_port = os.environ[\"DB_PORT\"]\n",
"db_database = os.environ[\"DB_DATABASE\"]"
]
},
{
Expand All @@ -41,6 +66,34 @@
"appointments_df.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"CONNECTSTRING = (\n",
" rf\"mssql+pymssql://{db_user}:{db_passwd}@{db_host}:{db_port}/{db_database}\"\n",
")\n",
"engine = create_engine(CONNECTSTRING)\n",
"session_object = sessionmaker(bind=engine)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"call_response = pd.read_sql_table(\"apicallresponse\", engine, schema=\"noshow\")\n",
"prediction = pd.read_sql_table(\"apiprediction\", engine, schema=\"noshow\")\n",
"\n",
"prediction_response = prediction.merge(\n",
" call_response, left_on=\"id\", right_on=\"prediction_id\", how=\"inner\"\n",
")\n",
"prediction_response"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -100,12 +153,16 @@
"outputs": [],
"source": [
"appointments_df[\"month\"] = appointments_df.index.get_level_values(\"start\").month\n",
"(\n",
"tmp_df = (\n",
" appointments_df.groupby([\"month\", \"pilot\"])[\"no_show\"]\n",
" .value_counts(True)\n",
" .unstack([\"no_show\", \"pilot\"])[\"no_show\"]\n",
" .plot.bar(figsize=(15, 6))\n",
")"
")\n",
"tmp_df.loc[11, \"pilot\"] = None\n",
"tmp_df.plot.bar(figsize=(15, 6))\n",
"plt.xlabel(\"\")\n",
"plt.title(\"Gem. no-show percentage per maand sinds 2015\")\n",
"plt.show()"
]
},
{
Expand All @@ -122,13 +179,16 @@
"outputs": [],
"source": [
"for agenda in appointments_df[\"hoofdagenda\"].unique():\n",
" (\n",
" tmp_df = (\n",
" appointments_df.loc[appointments_df[\"hoofdagenda\"] == agenda]\n",
" .groupby([\"month\", \"pilot\"])[\"no_show\"]\n",
" .value_counts(True)\n",
" .unstack([\"no_show\", \"pilot\"])[\"no_show\"]\n",
" .plot.bar(figsize=(15, 6), title=agenda)\n",
" )\n",
" tmp_df.loc[11, \"pilot\"] = None\n",
" tmp_df.plot.bar(figsize=(15, 6))\n",
" plt.xlabel(\"\")\n",
" plt.title(agenda)\n",
" plt.show()"
]
},
Expand Down Expand Up @@ -171,9 +231,7 @@
"outputs": [],
"source": [
"total_appointments = appointments_df.join(predictions_df, how=\"inner\")\n",
"total_appointments[\"predict_bin\"] = pd.cut(\n",
" total_appointments[\"prediction\"], bins=[0, 0.05, 0.1, 0.15, 0.2, 0.25, 1]\n",
")"
"total_appointments[\"predict_bin\"] = pd.qcut(total_appointments[\"prediction\"], 12)"
]
},
{
Expand Down Expand Up @@ -207,12 +265,78 @@
"metadata": {},
"outputs": [],
"source": [
"plt.subplots(figsize=(15, 6))\n",
"sns.barplot(data=total_appointments, x=\"predict_bin\", y=\"noshow_num\", hue=\"pilot\")\n",
"plt.title(\"No-Show percentage per risico-categorie\")\n",
"plt.xlabel(\"Risico-categorieen\")\n",
"plt.ylabel(\"Percentage No-Show\")\n",
"plt.show()"
"fig, ax = plt.subplots(2, 1, figsize=(15, 8), sharex=True)\n",
"total_appointments[[\"prediction\", \"pilot\"]].plot.hist(by=\"pilot\", bins=100, ax=ax)\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@interact\n",
"def no_show_perc_plot(\n",
"    years=[2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022],\n",
"    show_all=False,\n",
"    only_called=False,\n",
"):\n",
"    \"\"\"Plot the mean of noshow_num per predict_bin, split by pilot group.\n",
"\n",
"    Rendered by ipywidgets.interact: the list default of years becomes a\n",
"    selection widget, so the function receives a single selected year.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    years : int\n",
"        Selected year; appointments from this year and from 2023 are kept.\n",
"    show_all : bool\n",
"        If True, skip the year filter and use all rows of total_appointments.\n",
"    only_called : bool\n",
"        If True, keep only rows with pilot == \"Geen pilot\" or whose APP_ID\n",
"        occurs in prediction_response[\"prediction_id\"].\n",
"    \"\"\"\n",
"    if show_all:\n",
"        total_appointments_selection = total_appointments\n",
"    else:\n",
"        # 2023 is always kept next to the selected year\n",
"        # (presumably the pilot period - TODO confirm).\n",
"        year_selection = [years, 2023]\n",
"        appointment_years = total_appointments.index.get_level_values(\"start\").year\n",
"        total_appointments_selection = total_appointments[\n",
"            appointment_years.isin(year_selection)\n",
"        ]\n",
"\n",
"    if only_called:\n",
"        called_ids = prediction_response[\"prediction_id\"].astype(int)\n",
"        total_appointments_selection = total_appointments_selection.loc[\n",
"            (total_appointments_selection[\"pilot\"] == \"Geen pilot\")\n",
"            | total_appointments_selection[\"APP_ID\"].isin(called_ids)\n",
"        ]\n",
"\n",
"    # Explicit Axes instead of the pyplot state machine, so the plot target is unambiguous.\n",
"    fig, ax = plt.subplots(figsize=(15, 6))\n",
"    sns.barplot(\n",
"        data=total_appointments_selection,\n",
"        x=\"predict_bin\",\n",
"        y=\"noshow_num\",\n",
"        hue=\"pilot\",\n",
"        hue_order=[\"Geen pilot\", \"pilot\"],\n",
"        ax=ax,\n",
"    )\n",
"    ax.set_title(\"No-Show percentage per risico-categorie\")\n",
"    ax.set_xlabel(\"Risico-categorieen\")\n",
"    ax.set_ylabel(\"Percentage No-Show\")\n",
"    plt.show()\n",
"\n",
"    # Table behind the bars: mean, spread and group size per pilot group and bin.\n",
"    total_appointments_plot = total_appointments_selection.groupby(\n",
"        [\"pilot\", \"predict_bin\"]\n",
"    )[\"noshow_num\"].agg([\"mean\", \"std\", \"size\"])\n",
"    display(total_appointments_plot)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prediction_response.columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Appointments joined to their call record, restricted to calls with status\n",
"# \"Gebeld\"; y flags calls whose outcome was \"Verzet/Geannuleerd\".\n",
"called_appointments = (\n",
"    total_appointments.merge(\n",
"        # Cast on a derived frame so prediction_response itself is not mutated.\n",
"        prediction_response.assign(id_x=prediction_response[\"id_x\"].astype(\"Int64\")),\n",
"        left_on=\"APP_ID\",\n",
"        right_on=\"id_x\",\n",
"    )\n",
"    .loc[lambda df: df[\"call_status\"] == \"Gebeld\"]\n",
"    # assign() returns a new frame; setting a column on the filtered slice\n",
"    # directly raises SettingWithCopyWarning.\n",
"    .assign(y=lambda df: df[\"call_outcome\"] == \"Verzet/Geannuleerd\")\n",
")\n",
"called_appointments"
]
}
],
Expand Down
Loading

0 comments on commit 1c87d8a

Please sign in to comment.