diff --git a/assignments/eval.ipynb b/assignments/eval.ipynb
deleted file mode 100644
index ce9dff2..0000000
--- a/assignments/eval.ipynb
+++ /dev/null
@@ -1,1245 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "f7efc033",
- "metadata": {},
- "source": [
- "## License \n",
- "\n",
- "Copyright 2021-2023 Patrick Hall (jphall@gwu.edu)\n",
- "\n",
- "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
- "you may not use this file except in compliance with the License.\n",
- "You may obtain a copy of the License at\n",
- "\n",
- " http://www.apache.org/licenses/LICENSE-2.0\n",
- "\n",
- "Unless required by applicable law or agreed to in writing, software\n",
- "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
- "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
- "See the License for the specific language governing permissions and\n",
- "limitations under the License.\n",
- "\n",
- "*DISCLAIMER*: This notebook is not legal or compliance advice."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "aab60b41",
- "metadata": {},
- "source": [
- "# Model Evaluation Notebook"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "281af306",
- "metadata": {},
- "source": [
- "#### Imports and inits"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "fd180587",
- "metadata": {},
- "outputs": [],
- "source": [
- "import os # for directory and file manipulation\n",
- "import numpy as np # for basic array manipulation\n",
- "import pandas as pd # for dataframe manipulation\n",
- "import datetime # for timestamp\n",
- "\n",
- "# for model eval\n",
- "from sklearn.metrics import accuracy_score, f1_score, log_loss, mean_squared_error, roc_auc_score\n",
- "\n",
- "# global constants \n",
- "ROUND = 3 # generally, insane precision is not needed \n",
- "SEED = 12345 # seed for better reproducibility\n",
- "\n",
- "# set global random seed for better reproducibility\n",
- "np.random.seed(SEED)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "eb2a39d4",
- "metadata": {},
- "source": [
- "#### Set basic metadata"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "98f640ed",
- "metadata": {},
- "outputs": [],
- "source": [
- "y_name = 'high_priced'\n",
- "scores_dir = 'data/scores'"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "cc8d83d0",
- "metadata": {},
- "source": [
- "#### Read in score files "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "355c2b81",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " high_priced | \n",
- " fold | \n",
- " group1_rem_ebm | \n",
- " group2_rem_ebm | \n",
- " group3_rem_piml_EBM | \n",
- " group5_rem_xgb2 | \n",
- " group8_rem_ebm | \n",
- " group9_rem_xgb | \n",
- " ph_rem_ebm | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 0.0 | \n",
- " 2 | \n",
- " 0.118787 | \n",
- " 0.080557 | \n",
- " 0.920389 | \n",
- " 0.078326 | \n",
- " 0.223846 | \n",
- " 0.081792 | \n",
- " 0.219429 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 0.0 | \n",
- " 1 | \n",
- " 0.084506 | \n",
- " 0.026001 | \n",
- " 0.969301 | \n",
- " 0.035825 | \n",
- " 0.053926 | \n",
- " 0.110702 | \n",
- " 0.053929 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 1.0 | \n",
- " 4 | \n",
- " 0.210389 | \n",
- " 0.194961 | \n",
- " 0.814272 | \n",
- " 0.195332 | \n",
- " 0.143522 | \n",
- " 0.204048 | \n",
- " 0.133863 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 0.0 | \n",
- " 1 | \n",
- " 0.008529 | \n",
- " 0.028556 | \n",
- " 0.974559 | \n",
- " 0.022765 | \n",
- " 0.009371 | \n",
- " 0.024038 | \n",
- " 0.014419 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 1.0 | \n",
- " 2 | \n",
- " 0.189933 | \n",
- " 0.208263 | \n",
- " 0.802908 | \n",
- " 0.193035 | \n",
- " 0.151100 | \n",
- " 0.170243 | \n",
- " 0.156047 | \n",
- "
\n",
- " \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- "
\n",
- " \n",
- " 19826 | \n",
- " 0.0 | \n",
- " 3 | \n",
- " 0.163697 | \n",
- " 0.228342 | \n",
- " 0.792251 | \n",
- " 0.235192 | \n",
- " 0.216720 | \n",
- " 0.181403 | \n",
- " 0.184214 | \n",
- "
\n",
- " \n",
- " 19827 | \n",
- " 0.0 | \n",
- " 1 | \n",
- " 0.114999 | \n",
- " 0.253998 | \n",
- " 0.762946 | \n",
- " 0.235832 | \n",
- " 0.161401 | \n",
- " 0.159468 | \n",
- " 0.141663 | \n",
- "
\n",
- " \n",
- " 19828 | \n",
- " 1.0 | \n",
- " 3 | \n",
- " 0.141307 | \n",
- " 0.213364 | \n",
- " 0.747401 | \n",
- " 0.208723 | \n",
- " 0.242814 | \n",
- " 0.138141 | \n",
- " 0.233266 | \n",
- "
\n",
- " \n",
- " 19829 | \n",
- " 0.0 | \n",
- " 1 | \n",
- " 0.007766 | \n",
- " 0.002176 | \n",
- " 0.996455 | \n",
- " 0.018702 | \n",
- " 0.005657 | \n",
- " 0.034570 | \n",
- " 0.009914 | \n",
- "
\n",
- " \n",
- " 19830 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " 0.163946 | \n",
- " 0.185484 | \n",
- " 0.811429 | \n",
- " 0.215085 | \n",
- " 0.167812 | \n",
- " 0.177785 | \n",
- " 0.155447 | \n",
- "
\n",
- " \n",
- "
\n",
- "
19831 rows × 9 columns
\n",
- "
"
- ],
- "text/plain": [
- " high_priced fold group1_rem_ebm group2_rem_ebm group3_rem_piml_EBM \\\n",
- "0 0.0 2 0.118787 0.080557 0.920389 \n",
- "1 0.0 1 0.084506 0.026001 0.969301 \n",
- "2 1.0 4 0.210389 0.194961 0.814272 \n",
- "3 0.0 1 0.008529 0.028556 0.974559 \n",
- "4 1.0 2 0.189933 0.208263 0.802908 \n",
- "... ... ... ... ... ... \n",
- "19826 0.0 3 0.163697 0.228342 0.792251 \n",
- "19827 0.0 1 0.114999 0.253998 0.762946 \n",
- "19828 1.0 3 0.141307 0.213364 0.747401 \n",
- "19829 0.0 1 0.007766 0.002176 0.996455 \n",
- "19830 0.0 0 0.163946 0.185484 0.811429 \n",
- "\n",
- " group5_rem_xgb2 group8_rem_ebm group9_rem_xgb ph_rem_ebm \n",
- "0 0.078326 0.223846 0.081792 0.219429 \n",
- "1 0.035825 0.053926 0.110702 0.053929 \n",
- "2 0.195332 0.143522 0.204048 0.133863 \n",
- "3 0.022765 0.009371 0.024038 0.014419 \n",
- "4 0.193035 0.151100 0.170243 0.156047 \n",
- "... ... ... ... ... \n",
- "19826 0.235192 0.216720 0.181403 0.184214 \n",
- "19827 0.235832 0.161401 0.159468 0.141663 \n",
- "19828 0.208723 0.242814 0.138141 0.233266 \n",
- "19829 0.018702 0.005657 0.034570 0.009914 \n",
- "19830 0.215085 0.167812 0.177785 0.155447 \n",
- "\n",
- "[19831 rows x 9 columns]"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# init score frame with known test y values\n",
- "scores_frame = pd.read_csv(scores_dir + os.sep +'key.csv', index_col='Unnamed: 0')\n",
- "\n",
- "# create random folds in reproducible way\n",
- "np.random.seed(SEED)\n",
- "scores_frame['fold'] = np.random.choice(5, scores_frame.shape[0])\n",
- "\n",
- "# read in each score file in the directory as a new column \n",
- "for file in sorted(os.listdir(scores_dir)):\n",
- " if file != 'key.csv' and file.endswith('.csv'):\n",
- " scores_frame[file[:-4]] = pd.read_csv(scores_dir + os.sep + file)['phat']\n",
- "\n",
- "# sanity check \n",
- "scores_frame"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "3e3cccda",
- "metadata": {},
- "source": [
- "#### Utility function for max. accuracy"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "2eb43506",
- "metadata": {},
- "outputs": [],
- "source": [
- "def max_acc(y, phat, res=0.01): \n",
- "\n",
- " \"\"\" Utility function for finding max. accuracy at some cutoff. \n",
- " \n",
- " :param y: Known y values.\n",
- " :param phat: Model scores.\n",
- " :param res: Resolution over which to search for max. accuracy, default 0.01.\n",
- " :return: Max. accuracy for model scores.\n",
- " \n",
- " \"\"\"\n",
- " \n",
- " # init frame to store acc at different cutoffs\n",
- " acc_frame = pd.DataFrame(columns=['cut', 'acc'])\n",
- " \n",
- " # copy known y and score values into a temporary frame\n",
- " temp_df = pd.concat([y, phat], axis=1)\n",
- " \n",
- " # find accuracy at different cutoffs and store in acc_frame\n",
- " for cut in np.arange(0, 1 + res, res):\n",
- " temp_df['decision'] = np.where(temp_df.iloc[:, 1] > cut, 1, 0)\n",
- " acc = accuracy_score(temp_df.iloc[:, 0], temp_df['decision'])\n",
- " acc_frame = acc_frame.append({'cut': cut,\n",
- " 'acc': acc},\n",
- " ignore_index=True)\n",
- "\n",
- " # find max accurcay across all cutoffs\n",
- " max_acc = acc_frame['acc'].max()\n",
- " \n",
- " # house keeping\n",
- " del acc_frame, temp_df\n",
- " \n",
- " return max_acc"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b02c9651",
- "metadata": {},
- "source": [
- "#### Utility function for max. F1"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "fae3756b",
- "metadata": {},
- "outputs": [],
- "source": [
- "def max_f1(y, phat, res=0.01): \n",
- " \n",
- " \"\"\" Utility function for finding max. F1 at some cutoff. \n",
- " \n",
- " :param y: Known y values.\n",
- " :param phat: Model scores.\n",
- " :param res: Resolution over which to search for max. F1, default 0.01.\n",
- " :return: Max. F1 for model scores.\n",
- " \n",
- " \"\"\"\n",
- " \n",
- " # init frame to store f1 at different cutoffs\n",
- " f1_frame = pd.DataFrame(columns=['cut', 'f1'])\n",
- " \n",
- " # copy known y and score values into a temporary frame\n",
- " temp_df = pd.concat([y, phat], axis=1)\n",
- " \n",
- " # find f1 at different cutoffs and store in acc_frame\n",
- " for cut in np.arange(0, 1 + res, res):\n",
- " temp_df['decision'] = np.where(temp_df.iloc[:, 1] > cut, 1, 0)\n",
- " f1 = f1_score(temp_df.iloc[:, 0], temp_df['decision'])\n",
- " f1_frame = f1_frame.append({'cut': cut,\n",
- " 'f1': f1},\n",
- " ignore_index=True)\n",
- " \n",
- " # find max f1 across all cutoffs\n",
- " max_f1 = f1_frame['f1'].max()\n",
- " \n",
- " # house keeping\n",
- " del f1_frame, temp_df\n",
- " \n",
- " return max_f1"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b447b732",
- "metadata": {},
- "source": [
- "#### Rank all submitted scores "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "40fbe608",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " fold | \n",
- " metric | \n",
- " group1_rem_ebm | \n",
- " group2_rem_ebm | \n",
- " group3_rem_piml_EBM | \n",
- " group5_rem_xgb2 | \n",
- " group8_rem_ebm | \n",
- " group9_rem_xgb | \n",
- " ph_rem_ebm | \n",
- " group1_rem_ebm_rank | \n",
- " group2_rem_ebm_rank | \n",
- " group3_rem_piml_EBM_rank | \n",
- " group5_rem_xgb2_rank | \n",
- " group8_rem_ebm_rank | \n",
- " group9_rem_xgb_rank | \n",
- " ph_rem_ebm_rank | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 0.0 | \n",
- " acc | \n",
- " 0.900 | \n",
- " 0.901 | \n",
- " 0.900 | \n",
- " 0.901 | \n",
- " 0.901 | \n",
- " 0.900 | \n",
- " 0.901 | \n",
- " 6.0 | \n",
- " 2.5 | \n",
- " 6.0 | \n",
- " 2.5 | \n",
- " 2.5 | \n",
- " 6.0 | \n",
- " 2.5 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 0.0 | \n",
- " auc | \n",
- " 0.781 | \n",
- " 0.840 | \n",
- " 0.163 | \n",
- " 0.836 | \n",
- " 0.793 | \n",
- " 0.797 | \n",
- " 0.791 | \n",
- " 6.0 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 4.0 | \n",
- " 3.0 | \n",
- " 5.0 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 0.0 | \n",
- " f1 | \n",
- " 0.347 | \n",
- " 0.405 | \n",
- " 0.182 | \n",
- " 0.392 | \n",
- " 0.342 | \n",
- " 0.357 | \n",
- " 0.347 | \n",
- " 4.5 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 6.0 | \n",
- " 3.0 | \n",
- " 4.5 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 0.0 | \n",
- " logloss | \n",
- " 0.280 | \n",
- " 0.251 | \n",
- " 3.257 | \n",
- " 0.254 | \n",
- " 0.274 | \n",
- " 0.277 | \n",
- " 0.275 | \n",
- " 6.0 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 3.0 | \n",
- " 5.0 | \n",
- " 4.0 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 0.0 | \n",
- " mse | \n",
- " 0.082 | \n",
- " 0.077 | \n",
- " 0.773 | \n",
- " 0.077 | \n",
- " 0.081 | \n",
- " 0.081 | \n",
- " 0.081 | \n",
- " 6.0 | \n",
- " 1.5 | \n",
- " 7.0 | \n",
- " 1.5 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " 1.0 | \n",
- " acc | \n",
- " 0.906 | \n",
- " 0.906 | \n",
- " 0.906 | \n",
- " 0.906 | \n",
- " 0.906 | \n",
- " 0.906 | \n",
- " 0.906 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 1.0 | \n",
- " auc | \n",
- " 0.767 | \n",
- " 0.828 | \n",
- " 0.172 | \n",
- " 0.822 | \n",
- " 0.774 | \n",
- " 0.779 | \n",
- " 0.772 | \n",
- " 6.0 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 4.0 | \n",
- " 3.0 | \n",
- " 5.0 | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " 1.0 | \n",
- " f1 | \n",
- " 0.312 | \n",
- " 0.368 | \n",
- " 0.172 | \n",
- " 0.360 | \n",
- " 0.319 | \n",
- " 0.329 | \n",
- " 0.321 | \n",
- " 6.0 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 5.0 | \n",
- " 3.0 | \n",
- " 4.0 | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " 1.0 | \n",
- " logloss | \n",
- " 0.272 | \n",
- " 0.246 | \n",
- " 3.253 | \n",
- " 0.250 | \n",
- " 0.270 | \n",
- " 0.271 | \n",
- " 0.272 | \n",
- " 5.5 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 3.0 | \n",
- " 4.0 | \n",
- " 5.5 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 1.0 | \n",
- " mse | \n",
- " 0.079 | \n",
- " 0.074 | \n",
- " 0.778 | \n",
- " 0.075 | \n",
- " 0.079 | \n",
- " 0.078 | \n",
- " 0.079 | \n",
- " 5.0 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 5.0 | \n",
- " 3.0 | \n",
- " 5.0 | \n",
- "
\n",
- " \n",
- " 10 | \n",
- " 2.0 | \n",
- " acc | \n",
- " 0.908 | \n",
- " 0.908 | \n",
- " 0.908 | \n",
- " 0.910 | \n",
- " 0.908 | \n",
- " 0.908 | \n",
- " 0.909 | \n",
- " 5.0 | \n",
- " 5.0 | \n",
- " 5.0 | \n",
- " 1.0 | \n",
- " 5.0 | \n",
- " 5.0 | \n",
- " 2.0 | \n",
- "
\n",
- " \n",
- " 11 | \n",
- " 2.0 | \n",
- " auc | \n",
- " 0.759 | \n",
- " 0.825 | \n",
- " 0.175 | \n",
- " 0.826 | \n",
- " 0.781 | \n",
- " 0.772 | \n",
- " 0.780 | \n",
- " 6.0 | \n",
- " 2.0 | \n",
- " 7.0 | \n",
- " 1.0 | \n",
- " 3.0 | \n",
- " 5.0 | \n",
- " 4.0 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 2.0 | \n",
- " f1 | \n",
- " 0.304 | \n",
- " 0.372 | \n",
- " 0.169 | \n",
- " 0.371 | \n",
- " 0.315 | \n",
- " 0.320 | \n",
- " 0.323 | \n",
- " 6.0 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 5.0 | \n",
- " 4.0 | \n",
- " 3.0 | \n",
- "
\n",
- " \n",
- " 13 | \n",
- " 2.0 | \n",
- " logloss | \n",
- " 0.271 | \n",
- " 0.246 | \n",
- " 3.284 | \n",
- " 0.245 | \n",
- " 0.264 | \n",
- " 0.271 | \n",
- " 0.264 | \n",
- " 5.5 | \n",
- " 2.0 | \n",
- " 7.0 | \n",
- " 1.0 | \n",
- " 3.5 | \n",
- " 5.5 | \n",
- " 3.5 | \n",
- "
\n",
- " \n",
- " 14 | \n",
- " 2.0 | \n",
- " mse | \n",
- " 0.078 | \n",
- " 0.073 | \n",
- " 0.781 | \n",
- " 0.073 | \n",
- " 0.076 | \n",
- " 0.077 | \n",
- " 0.076 | \n",
- " 6.0 | \n",
- " 1.5 | \n",
- " 7.0 | \n",
- " 1.5 | \n",
- " 3.5 | \n",
- " 5.0 | \n",
- " 3.5 | \n",
- "
\n",
- " \n",
- " 15 | \n",
- " 3.0 | \n",
- " acc | \n",
- " 0.903 | \n",
- " 0.903 | \n",
- " 0.903 | \n",
- " 0.903 | \n",
- " 0.903 | \n",
- " 0.903 | \n",
- " 0.903 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- "
\n",
- " \n",
- " 16 | \n",
- " 3.0 | \n",
- " auc | \n",
- " 0.772 | \n",
- " 0.826 | \n",
- " 0.174 | \n",
- " 0.823 | \n",
- " 0.775 | \n",
- " 0.786 | \n",
- " 0.772 | \n",
- " 5.5 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 4.0 | \n",
- " 3.0 | \n",
- " 5.5 | \n",
- "
\n",
- " \n",
- " 17 | \n",
- " 3.0 | \n",
- " f1 | \n",
- " 0.317 | \n",
- " 0.371 | \n",
- " 0.177 | \n",
- " 0.365 | \n",
- " 0.328 | \n",
- " 0.343 | \n",
- " 0.323 | \n",
- " 6.0 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 4.0 | \n",
- " 3.0 | \n",
- " 5.0 | \n",
- "
\n",
- " \n",
- " 18 | \n",
- " 3.0 | \n",
- " logloss | \n",
- " 0.276 | \n",
- " 0.252 | \n",
- " 3.254 | \n",
- " 0.253 | \n",
- " 0.275 | \n",
- " 0.275 | \n",
- " 0.276 | \n",
- " 5.5 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 3.5 | \n",
- " 3.5 | \n",
- " 5.5 | \n",
- "
\n",
- " \n",
- " 19 | \n",
- " 3.0 | \n",
- " mse | \n",
- " 0.081 | \n",
- " 0.077 | \n",
- " 0.775 | \n",
- " 0.077 | \n",
- " 0.080 | \n",
- " 0.080 | \n",
- " 0.080 | \n",
- " 6.0 | \n",
- " 1.5 | \n",
- " 7.0 | \n",
- " 1.5 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- " 4.0 | \n",
- "
\n",
- " \n",
- " 20 | \n",
- " 4.0 | \n",
- " acc | \n",
- " 0.895 | \n",
- " 0.897 | \n",
- " 0.895 | \n",
- " 0.898 | \n",
- " 0.895 | \n",
- " 0.896 | \n",
- " 0.895 | \n",
- " 5.5 | \n",
- " 2.0 | \n",
- " 5.5 | \n",
- " 1.0 | \n",
- " 5.5 | \n",
- " 3.0 | \n",
- " 5.5 | \n",
- "
\n",
- " \n",
- " 21 | \n",
- " 4.0 | \n",
- " auc | \n",
- " 0.754 | \n",
- " 0.831 | \n",
- " 0.170 | \n",
- " 0.828 | \n",
- " 0.785 | \n",
- " 0.779 | \n",
- " 0.782 | \n",
- " 6.0 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 3.0 | \n",
- " 5.0 | \n",
- " 4.0 | \n",
- "
\n",
- " \n",
- " 22 | \n",
- " 4.0 | \n",
- " f1 | \n",
- " 0.323 | \n",
- " 0.401 | \n",
- " 0.190 | \n",
- " 0.397 | \n",
- " 0.364 | \n",
- " 0.354 | \n",
- " 0.362 | \n",
- " 6.0 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 3.0 | \n",
- " 5.0 | \n",
- " 4.0 | \n",
- "
\n",
- " \n",
- " 23 | \n",
- " 4.0 | \n",
- " logloss | \n",
- " 0.296 | \n",
- " 0.263 | \n",
- " 3.200 | \n",
- " 0.266 | \n",
- " 0.286 | \n",
- " 0.291 | \n",
- " 0.287 | \n",
- " 6.0 | \n",
- " 1.0 | \n",
- " 7.0 | \n",
- " 2.0 | \n",
- " 3.0 | \n",
- " 5.0 | \n",
- " 4.0 | \n",
- "
\n",
- " \n",
- " 24 | \n",
- " 4.0 | \n",
- " mse | \n",
- " 0.087 | \n",
- " 0.080 | \n",
- " 0.771 | \n",
- " 0.080 | \n",
- " 0.084 | \n",
- " 0.086 | \n",
- " 0.084 | \n",
- " 6.0 | \n",
- " 1.5 | \n",
- " 7.0 | \n",
- " 1.5 | \n",
- " 3.5 | \n",
- " 5.0 | \n",
- " 3.5 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " fold metric group1_rem_ebm group2_rem_ebm group3_rem_piml_EBM \\\n",
- "0 0.0 acc 0.900 0.901 0.900 \n",
- "1 0.0 auc 0.781 0.840 0.163 \n",
- "2 0.0 f1 0.347 0.405 0.182 \n",
- "3 0.0 logloss 0.280 0.251 3.257 \n",
- "4 0.0 mse 0.082 0.077 0.773 \n",
- "5 1.0 acc 0.906 0.906 0.906 \n",
- "6 1.0 auc 0.767 0.828 0.172 \n",
- "7 1.0 f1 0.312 0.368 0.172 \n",
- "8 1.0 logloss 0.272 0.246 3.253 \n",
- "9 1.0 mse 0.079 0.074 0.778 \n",
- "10 2.0 acc 0.908 0.908 0.908 \n",
- "11 2.0 auc 0.759 0.825 0.175 \n",
- "12 2.0 f1 0.304 0.372 0.169 \n",
- "13 2.0 logloss 0.271 0.246 3.284 \n",
- "14 2.0 mse 0.078 0.073 0.781 \n",
- "15 3.0 acc 0.903 0.903 0.903 \n",
- "16 3.0 auc 0.772 0.826 0.174 \n",
- "17 3.0 f1 0.317 0.371 0.177 \n",
- "18 3.0 logloss 0.276 0.252 3.254 \n",
- "19 3.0 mse 0.081 0.077 0.775 \n",
- "20 4.0 acc 0.895 0.897 0.895 \n",
- "21 4.0 auc 0.754 0.831 0.170 \n",
- "22 4.0 f1 0.323 0.401 0.190 \n",
- "23 4.0 logloss 0.296 0.263 3.200 \n",
- "24 4.0 mse 0.087 0.080 0.771 \n",
- "\n",
- " group5_rem_xgb2 group8_rem_ebm group9_rem_xgb ph_rem_ebm \\\n",
- "0 0.901 0.901 0.900 0.901 \n",
- "1 0.836 0.793 0.797 0.791 \n",
- "2 0.392 0.342 0.357 0.347 \n",
- "3 0.254 0.274 0.277 0.275 \n",
- "4 0.077 0.081 0.081 0.081 \n",
- "5 0.906 0.906 0.906 0.906 \n",
- "6 0.822 0.774 0.779 0.772 \n",
- "7 0.360 0.319 0.329 0.321 \n",
- "8 0.250 0.270 0.271 0.272 \n",
- "9 0.075 0.079 0.078 0.079 \n",
- "10 0.910 0.908 0.908 0.909 \n",
- "11 0.826 0.781 0.772 0.780 \n",
- "12 0.371 0.315 0.320 0.323 \n",
- "13 0.245 0.264 0.271 0.264 \n",
- "14 0.073 0.076 0.077 0.076 \n",
- "15 0.903 0.903 0.903 0.903 \n",
- "16 0.823 0.775 0.786 0.772 \n",
- "17 0.365 0.328 0.343 0.323 \n",
- "18 0.253 0.275 0.275 0.276 \n",
- "19 0.077 0.080 0.080 0.080 \n",
- "20 0.898 0.895 0.896 0.895 \n",
- "21 0.828 0.785 0.779 0.782 \n",
- "22 0.397 0.364 0.354 0.362 \n",
- "23 0.266 0.286 0.291 0.287 \n",
- "24 0.080 0.084 0.086 0.084 \n",
- "\n",
- " group1_rem_ebm_rank group2_rem_ebm_rank group3_rem_piml_EBM_rank \\\n",
- "0 6.0 2.5 6.0 \n",
- "1 6.0 1.0 7.0 \n",
- "2 4.5 1.0 7.0 \n",
- "3 6.0 1.0 7.0 \n",
- "4 6.0 1.5 7.0 \n",
- "5 4.0 4.0 4.0 \n",
- "6 6.0 1.0 7.0 \n",
- "7 6.0 1.0 7.0 \n",
- "8 5.5 1.0 7.0 \n",
- "9 5.0 1.0 7.0 \n",
- "10 5.0 5.0 5.0 \n",
- "11 6.0 2.0 7.0 \n",
- "12 6.0 1.0 7.0 \n",
- "13 5.5 2.0 7.0 \n",
- "14 6.0 1.5 7.0 \n",
- "15 4.0 4.0 4.0 \n",
- "16 5.5 1.0 7.0 \n",
- "17 6.0 1.0 7.0 \n",
- "18 5.5 1.0 7.0 \n",
- "19 6.0 1.5 7.0 \n",
- "20 5.5 2.0 5.5 \n",
- "21 6.0 1.0 7.0 \n",
- "22 6.0 1.0 7.0 \n",
- "23 6.0 1.0 7.0 \n",
- "24 6.0 1.5 7.0 \n",
- "\n",
- " group5_rem_xgb2_rank group8_rem_ebm_rank group9_rem_xgb_rank \\\n",
- "0 2.5 2.5 6.0 \n",
- "1 2.0 4.0 3.0 \n",
- "2 2.0 6.0 3.0 \n",
- "3 2.0 3.0 5.0 \n",
- "4 1.5 4.0 4.0 \n",
- "5 4.0 4.0 4.0 \n",
- "6 2.0 4.0 3.0 \n",
- "7 2.0 5.0 3.0 \n",
- "8 2.0 3.0 4.0 \n",
- "9 2.0 5.0 3.0 \n",
- "10 1.0 5.0 5.0 \n",
- "11 1.0 3.0 5.0 \n",
- "12 2.0 5.0 4.0 \n",
- "13 1.0 3.5 5.5 \n",
- "14 1.5 3.5 5.0 \n",
- "15 4.0 4.0 4.0 \n",
- "16 2.0 4.0 3.0 \n",
- "17 2.0 4.0 3.0 \n",
- "18 2.0 3.5 3.5 \n",
- "19 1.5 4.0 4.0 \n",
- "20 1.0 5.5 3.0 \n",
- "21 2.0 3.0 5.0 \n",
- "22 2.0 3.0 5.0 \n",
- "23 2.0 3.0 5.0 \n",
- "24 1.5 3.5 5.0 \n",
- "\n",
- " ph_rem_ebm_rank \n",
- "0 2.5 \n",
- "1 5.0 \n",
- "2 4.5 \n",
- "3 4.0 \n",
- "4 4.0 \n",
- "5 4.0 \n",
- "6 5.0 \n",
- "7 4.0 \n",
- "8 5.5 \n",
- "9 5.0 \n",
- "10 2.0 \n",
- "11 4.0 \n",
- "12 3.0 \n",
- "13 3.5 \n",
- "14 3.5 \n",
- "15 4.0 \n",
- "16 5.5 \n",
- "17 5.0 \n",
- "18 5.5 \n",
- "19 4.0 \n",
- "20 5.5 \n",
- "21 4.0 \n",
- "22 4.0 \n",
- "23 4.0 \n",
- "24 3.5 "
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "eval_frame = pd.DataFrame() # init frame to hold score ranking\n",
- "metric_list = ['acc', 'auc', 'f1', 'logloss', 'mse'] # metric to use for evaluation\n",
- "\n",
- "# create eval frame row-by-row\n",
- "for fold in sorted(scores_frame['fold'].unique()): # loop through folds \n",
- " for metric_name in metric_list: # loop through metrics\n",
- " \n",
- " # init row dict to hold each rows values\n",
- " row_dict = {'fold': fold,\n",
- " 'metric': metric_name}\n",
- " \n",
- " # cache known y values for fold\n",
- " fold_y = scores_frame.loc[scores_frame['fold'] == fold, y_name]\n",
- " \n",
- " for col_name in scores_frame.columns[2:]:\n",
- " \n",
- " # cache fold scores\n",
- " fold_scores = scores_frame.loc[scores_frame['fold'] == fold, col_name]\n",
- " \n",
- " # calculate evaluation metric for fold\n",
- " # with reasonable precision \n",
- " \n",
- " if metric_name == 'acc':\n",
- " row_dict[col_name] = np.round(max_acc(fold_y, fold_scores), ROUND)\n",
- " \n",
- " if metric_name == 'auc':\n",
- " row_dict[col_name] = np.round(roc_auc_score(fold_y, fold_scores), ROUND)\n",
- " \n",
- " if metric_name == 'f1':\n",
- " row_dict[col_name] = np.round(max_f1(fold_y, fold_scores), ROUND) \n",
- " \n",
- " if metric_name == 'logloss':\n",
- " row_dict[col_name] = np.round(log_loss(fold_y, fold_scores), ROUND)\n",
- " \n",
- " if metric_name == 'mse':\n",
- " row_dict[col_name] = np.round(mean_squared_error(fold_y, fold_scores), ROUND)\n",
- " \n",
- " # append row values to eval_frame\n",
- " eval_frame = eval_frame.append(row_dict, ignore_index=True)\n",
- "\n",
- "# init a temporary frame to hold rank information\n",
- "rank_names = [name + '_rank' for name in sorted(eval_frame.columns) if name not in ['fold', 'metric']]\n",
- "rank_frame = pd.DataFrame(columns=rank_names) \n",
- "\n",
- "# set columns to necessary order\n",
- "eval_frame = eval_frame[['fold', 'metric'] + [name for name in sorted(eval_frame.columns) if name not in ['fold', 'metric']]]\n",
- "\n",
- "# determine score ranks row-by-row\n",
- "for i in range(0, eval_frame.shape[0]):\n",
- " \n",
- " # get ranks for row based on metric\n",
- " metric_name = eval_frame.loc[i, 'metric']\n",
- " if metric_name in ['logloss', 'mse']:\n",
- " ranks = eval_frame.iloc[i, 2:].rank().values\n",
- " else:\n",
- " ranks = eval_frame.iloc[i, 2:].rank(ascending=False).values\n",
- " \n",
- " # create single-row frame and append to rank_frame\n",
- " row_frame = pd.DataFrame(ranks.reshape(1, ranks.shape[0]), columns=rank_names)\n",
- " rank_frame = rank_frame.append(row_frame, ignore_index=True)\n",
- " \n",
- " # house keeping\n",
- " del row_frame\n",
- "\n",
- "# merge ranks onto eval_frame\n",
- "eval_frame = pd.concat([eval_frame, rank_frame], axis=1)\n",
- "\n",
- "# house keeping\n",
- "del rank_frame\n",
- " \n",
- "eval_frame"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "37ed3b5f",
- "metadata": {},
- "source": [
- "#### Save `eval_frame` as CSV"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "aa89d862",
- "metadata": {},
- "outputs": [],
- "source": [
- "eval_frame.to_csv('model_eval_' + str(datetime.datetime.now().strftime(\"%Y_%m_%d_%H_%M_%S\") + '.csv'), \n",
- " index=False)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4525d3ea",
- "metadata": {},
- "source": [
- "#### Display simple ranked score list "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "f8ff5fa5",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "group2_rem_ebm_rank 1.66\n",
- "group5_rem_xgb2_rank 1.94\n",
- "group8_rem_ebm_rank 3.92\n",
- "group9_rem_xgb_rank 4.12\n",
- "ph_rem_ebm_rank 4.18\n",
- "group1_rem_ebm_rank 5.60\n",
- "group3_rem_piml_EBM_rank 6.58\n",
- "dtype: float64"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "eval_frame[[name for name in eval_frame.columns if name.endswith('rank')]].mean().sort_values()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.16"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}