From 99a8835bdb0ada408dd718491729ff7d6f715e08 Mon Sep 17 00:00:00 2001 From: ph_ Date: Thu, 29 Jun 2023 12:40:58 -0400 Subject: [PATCH] final remediated model evaluations --- assignments/eval.ipynb | 1415 +++++++++++++++++ .../model_eval_2023_06_28_21_00_17.csv | 26 + 2 files changed, 1441 insertions(+) create mode 100644 assignments/eval.ipynb create mode 100644 assignments/model_eval_2023_06_28_21_00_17.csv diff --git a/assignments/eval.ipynb b/assignments/eval.ipynb new file mode 100644 index 0000000..5a3a9ac --- /dev/null +++ b/assignments/eval.ipynb @@ -0,0 +1,1415 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f7efc033", + "metadata": {}, + "source": [ + "## License \n", + "\n", + "Copyright 2021-2023 Patrick Hall (jphall@gwu.edu)\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " http://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License.\n", + "\n", + "*DISCLAIMER*: This notebook is not legal or compliance advice." + ] + }, + { + "cell_type": "markdown", + "id": "aab60b41", + "metadata": {}, + "source": [ + "# Model Evaluation Notebook" + ] + }, + { + "cell_type": "markdown", + "id": "281af306", + "metadata": {}, + "source": [ + "#### Imports and inits" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "fd180587", + "metadata": {}, + "outputs": [], + "source": [ + "import os # for directory and file manipulation\n", + "import numpy as np # for basic array manipulation\n", + "import pandas as pd # for dataframe manipulation\n", + "import datetime # for timestamp\n", + "\n", + "# for model eval\n", + "from sklearn.metrics import accuracy_score, f1_score, log_loss, mean_squared_error, roc_auc_score\n", + "\n", + "# global constants \n", + "ROUND = 3 # generally, insane precision is not needed \n", + "SEED = 12345 # seed for better reproducibility\n", + "\n", + "# set global random seed for better reproducibility\n", + "np.random.seed(SEED)" + ] + }, + { + "cell_type": "markdown", + "id": "eb2a39d4", + "metadata": {}, + "source": [ + "#### Set basic metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "98f640ed", + "metadata": {}, + "outputs": [], + "source": [ + "y_name = 'high_priced'\n", + "scores_dir = 'data/scores'" + ] + }, + { + "cell_type": "markdown", + "id": "cc8d83d0", + "metadata": {}, + "source": [ + "#### Read in score files " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "355c2b81", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
high_pricedfoldgroup1_rem_ebmgroup2_rem_ebmgroup2_rem_ebm2group3_rem_piml_EBMgroup3_rem_piml_EBM2group5_rem_xgb2group8_rem_ebmgroup9_rem_xgbph_rem_ebm
00.020.1187870.0805570.0805570.9203890.1367490.0783260.2238460.0817920.219429
10.010.0845060.0260010.0260010.9693010.0537510.0358250.0539260.1107020.053929
21.040.2103890.1949610.1949610.8142720.1823110.1953320.1435220.2040480.133863
30.010.0085290.0285560.0285560.9745590.0040650.0227650.0093710.0240380.014419
41.020.1899330.2082630.2082630.8029080.2111200.1930350.1511000.1702430.156047
....................................
198260.030.1636970.2283420.2283420.7922510.2093220.2351920.2167200.1814030.184214
198270.010.1149990.2539980.2539980.7629460.2067440.2358320.1614010.1594680.141663
198281.030.1413070.2133640.2133640.7474010.2466100.2087230.2428140.1381410.233266
198290.010.0077660.0021760.0021760.9964550.0002680.0187020.0056570.0345700.009914
198300.000.1639460.1854840.1854840.8114290.1778570.2150850.1678120.1777850.155447
\n", + "

19831 rows × 11 columns

\n", + "
" + ], + "text/plain": [ + " high_priced fold group1_rem_ebm group2_rem_ebm group2_rem_ebm2 \\\n", + "0 0.0 2 0.118787 0.080557 0.080557 \n", + "1 0.0 1 0.084506 0.026001 0.026001 \n", + "2 1.0 4 0.210389 0.194961 0.194961 \n", + "3 0.0 1 0.008529 0.028556 0.028556 \n", + "4 1.0 2 0.189933 0.208263 0.208263 \n", + "... ... ... ... ... ... \n", + "19826 0.0 3 0.163697 0.228342 0.228342 \n", + "19827 0.0 1 0.114999 0.253998 0.253998 \n", + "19828 1.0 3 0.141307 0.213364 0.213364 \n", + "19829 0.0 1 0.007766 0.002176 0.002176 \n", + "19830 0.0 0 0.163946 0.185484 0.185484 \n", + "\n", + " group3_rem_piml_EBM group3_rem_piml_EBM2 group5_rem_xgb2 \\\n", + "0 0.920389 0.136749 0.078326 \n", + "1 0.969301 0.053751 0.035825 \n", + "2 0.814272 0.182311 0.195332 \n", + "3 0.974559 0.004065 0.022765 \n", + "4 0.802908 0.211120 0.193035 \n", + "... ... ... ... \n", + "19826 0.792251 0.209322 0.235192 \n", + "19827 0.762946 0.206744 0.235832 \n", + "19828 0.747401 0.246610 0.208723 \n", + "19829 0.996455 0.000268 0.018702 \n", + "19830 0.811429 0.177857 0.215085 \n", + "\n", + " group8_rem_ebm group9_rem_xgb ph_rem_ebm \n", + "0 0.223846 0.081792 0.219429 \n", + "1 0.053926 0.110702 0.053929 \n", + "2 0.143522 0.204048 0.133863 \n", + "3 0.009371 0.024038 0.014419 \n", + "4 0.151100 0.170243 0.156047 \n", + "... ... ... ... \n", + "19826 0.216720 0.181403 0.184214 \n", + "19827 0.161401 0.159468 0.141663 \n", + "19828 0.242814 0.138141 0.233266 \n", + "19829 0.005657 0.034570 0.009914 \n", + "19830 0.167812 0.177785 0.155447 \n", + "\n", + "[19831 rows x 11 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# init score frame with known test y values\n", + "scores_frame = pd.read_csv(scores_dir + os.sep +'key.csv', index_col='Unnamed: 0')\n", + "\n", + "# create random folds in reproducible way\n", + "np.random.seed(SEED)\n", + "scores_frame['fold'] = np.random.choice(5, scores_frame.shape[0])\n", + "\n", + "# read in each score file in the directory as a new column \n", + "for file in sorted(os.listdir(scores_dir)):\n", + " if file != 'key.csv' and file.endswith('.csv'):\n", + " scores_frame[file[:-4]] = pd.read_csv(scores_dir + os.sep + file)['phat']\n", + "\n", + "# sanity check \n", + "scores_frame" + ] + }, + { + "cell_type": "markdown", + "id": "3e3cccda", + "metadata": {}, + "source": [ + "#### Utility function for max. accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2eb43506", + "metadata": {}, + "outputs": [], + "source": [ + "def max_acc(y, phat, res=0.01): \n", + "\n", + " \"\"\" Utility function for finding max. accuracy at some cutoff. \n", + " \n", + " :param y: Known y values.\n", + " :param phat: Model scores.\n", + " :param res: Resolution over which to search for max. accuracy, default 0.01.\n", + " :return: Max. 
accuracy for model scores.\n",
+    "    \n",
+    "    \"\"\"\n",
+    "    \n",
+    "    # init frame to store acc at different cutoffs\n",
+    "    acc_frame = pd.DataFrame(columns=['cut', 'acc'])\n",
+    "    \n",
+    "    # copy known y and score values into a temporary frame\n",
+    "    temp_df = pd.concat([y, phat], axis=1)\n",
+    "    \n",
+    "    # find accuracy at different cutoffs and store in acc_frame\n",
+    "    # (DataFrame.append was removed in pandas 2.0; use pd.concat instead)\n",
+    "    for cut in np.arange(0, 1 + res, res):\n",
+    "        temp_df['decision'] = np.where(temp_df.iloc[:, 1] > cut, 1, 0)\n",
+    "        acc = accuracy_score(temp_df.iloc[:, 0], temp_df['decision'])\n",
+    "        acc_frame = pd.concat([acc_frame,\n",
+    "                               pd.DataFrame([{'cut': cut, 'acc': acc}])],\n",
+    "                              ignore_index=True)\n",
+    "\n",
+    "    # find max accuracy across all cutoffs\n",
+    "    max_acc = acc_frame['acc'].max()\n",
+    "    \n",
+    "    # housekeeping\n",
+    "    del acc_frame, temp_df\n",
+    "    \n",
+    "    return max_acc"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b02c9651",
+   "metadata": {},
+   "source": [
+    "#### Utility function for max. F1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "fae3756b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def max_f1(y, phat, res=0.01):\n",
+    "    \n",
+    "    \"\"\" Utility function for finding max. F1 at some cutoff. \n",
+    "    \n",
+    "    :param y: Known y values.\n",
+    "    :param phat: Model scores.\n",
+    "    :param res: Resolution over which to search for max. F1, default 0.01.\n",
+    "    :return: Max. F1 for model scores.\n",
+    "    \n",
+    "    \"\"\"\n",
+    "    \n",
+    "    # init frame to store f1 at different cutoffs\n",
+    "    f1_frame = pd.DataFrame(columns=['cut', 'f1'])\n",
+    "    \n",
+    "    # copy known y and score values into a temporary frame\n",
+    "    temp_df = pd.concat([y, phat], axis=1)\n",
+    "    \n",
+    "    # find f1 at different cutoffs and store in f1_frame\n",
+    "    for cut in np.arange(0, 1 + res, res):\n",
+    "        temp_df['decision'] = np.where(temp_df.iloc[:, 1] > cut, 1, 0)\n",
+    "        f1 = f1_score(temp_df.iloc[:, 0], temp_df['decision'])\n",
+    "        f1_frame = pd.concat([f1_frame,\n",
+    "                              pd.DataFrame([{'cut': cut, 'f1': f1}])],\n",
+    "                             ignore_index=True)\n",
+    "    \n",
+    "    # find max f1 across all cutoffs\n",
+    "    max_f1 = f1_frame['f1'].max()\n",
+    "    \n",
+    "    # housekeeping\n",
+    "    del f1_frame, temp_df\n",
+    "    \n",
+    "    return max_f1"
+   ]
+  },
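+  {
+   "cell_type": "markdown",
+   "id": "e0e0e0e1",
+   "metadata": {},
+   "source": [
+    "#### Sanity check for the cutoff-search utilities\n",
+    "\n",
+    "A minimal sketch (editorial addition, with made-up toy values): on a perfectly separable example, both utilities should find some cutoff that scores 1.0."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e0e0e0e2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# hypothetical toy data -- illustrates the cutoff sweep only\n",
+    "toy_y = pd.Series([0, 0, 1, 1], name='y')\n",
+    "toy_phat = pd.Series([0.1, 0.2, 0.8, 0.9], name='phat')\n",
+    "print(max_acc(toy_y, toy_phat))  # expect 1.0, e.g., at cut = 0.5\n",
+    "print(max_f1(toy_y, toy_phat))   # expect 1.0 at the same cutoff"
+   ]
+  },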
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
foldmetricgroup1_rem_ebmgroup2_rem_ebmgroup2_rem_ebm2group3_rem_piml_EBMgroup3_rem_piml_EBM2group5_rem_xgb2group8_rem_ebmgroup9_rem_xgbph_rem_ebmgroup1_rem_ebm_rankgroup2_rem_ebm_rankgroup2_rem_ebm2_rankgroup3_rem_piml_EBM_rankgroup3_rem_piml_EBM2_rankgroup5_rem_xgb2_rankgroup8_rem_ebm_rankgroup9_rem_xgb_rankph_rem_ebm_rank
00.0acc0.9000.9010.9010.9000.9010.9010.9010.9000.9018.03.53.58.03.53.53.58.03.5
10.0auc0.7810.8400.8400.1630.8210.8360.7930.7970.7918.01.51.59.04.03.06.05.07.0
20.0f10.3470.4050.4050.1820.3810.3920.3420.3570.3476.51.51.59.04.03.08.05.06.5
30.0logloss0.2800.2510.2513.2570.2620.2540.2740.2770.2758.01.51.59.04.03.05.07.06.0
40.0mse0.0820.0770.0770.7730.0780.0770.0810.0810.0818.02.02.09.04.02.06.06.06.0
51.0acc0.9060.9060.9060.9060.9060.9060.9060.9060.9065.05.05.05.05.05.05.05.05.0
61.0auc0.7670.8280.8280.1720.8100.8220.7740.7790.7728.01.51.59.04.03.06.05.07.0
71.0f10.3120.3680.3680.1720.3480.3600.3190.3290.3218.01.51.59.04.03.07.05.06.0
81.0logloss0.2720.2460.2463.2530.2580.2500.2700.2710.2727.51.51.59.04.03.05.06.07.5
91.0mse0.0790.0740.0740.7780.0770.0750.0790.0780.0797.01.51.59.04.03.07.05.07.0
102.0acc0.9080.9080.9080.9080.9080.9100.9080.9080.9096.06.06.06.06.01.06.06.02.0
112.0auc0.7590.8250.8250.1750.8150.8260.7810.7720.7808.02.52.59.04.01.05.07.06.0
122.0f10.3040.3720.3720.1690.3540.3710.3150.3200.3238.01.51.59.04.03.07.06.05.0
132.0logloss0.2710.2460.2463.2840.2510.2450.2640.2710.2647.52.52.59.04.01.05.57.55.5
142.0mse0.0780.0730.0730.7810.0740.0730.0760.0770.0768.02.02.09.04.02.05.57.05.5
153.0acc0.9030.9030.9030.9030.9030.9030.9030.9030.9035.05.05.05.05.05.05.05.05.0
163.0auc0.7720.8260.8260.1740.8090.8230.7750.7860.7727.51.51.59.04.03.06.05.07.5
173.0f10.3170.3710.3710.1770.3610.3650.3280.3430.3238.01.51.59.04.03.06.05.07.0
183.0logloss0.2760.2520.2523.2540.2620.2530.2750.2750.2767.51.51.59.04.03.05.55.57.5
193.0mse0.0810.0770.0770.7750.0790.0770.0800.0800.0808.02.02.09.04.02.06.06.06.0
204.0acc0.8950.8970.8970.8950.8950.8980.8950.8960.8957.02.52.57.07.01.07.04.07.0
214.0auc0.7540.8310.8310.1700.8180.8280.7850.7790.7828.01.51.59.04.03.05.07.06.0
224.0f10.3230.4010.4010.1900.4040.3970.3640.3540.3628.02.52.59.01.04.05.07.06.0
234.0logloss0.2960.2630.2633.2000.2730.2660.2860.2910.2878.01.51.59.04.03.05.07.06.0
244.0mse0.0870.0800.0800.7710.0820.0800.0840.0860.0848.02.02.09.04.02.05.57.05.5
\n", + "
" + ], + "text/plain": [ + " fold metric group1_rem_ebm group2_rem_ebm group2_rem_ebm2 \\\n", + "0 0.0 acc 0.900 0.901 0.901 \n", + "1 0.0 auc 0.781 0.840 0.840 \n", + "2 0.0 f1 0.347 0.405 0.405 \n", + "3 0.0 logloss 0.280 0.251 0.251 \n", + "4 0.0 mse 0.082 0.077 0.077 \n", + "5 1.0 acc 0.906 0.906 0.906 \n", + "6 1.0 auc 0.767 0.828 0.828 \n", + "7 1.0 f1 0.312 0.368 0.368 \n", + "8 1.0 logloss 0.272 0.246 0.246 \n", + "9 1.0 mse 0.079 0.074 0.074 \n", + "10 2.0 acc 0.908 0.908 0.908 \n", + "11 2.0 auc 0.759 0.825 0.825 \n", + "12 2.0 f1 0.304 0.372 0.372 \n", + "13 2.0 logloss 0.271 0.246 0.246 \n", + "14 2.0 mse 0.078 0.073 0.073 \n", + "15 3.0 acc 0.903 0.903 0.903 \n", + "16 3.0 auc 0.772 0.826 0.826 \n", + "17 3.0 f1 0.317 0.371 0.371 \n", + "18 3.0 logloss 0.276 0.252 0.252 \n", + "19 3.0 mse 0.081 0.077 0.077 \n", + "20 4.0 acc 0.895 0.897 0.897 \n", + "21 4.0 auc 0.754 0.831 0.831 \n", + "22 4.0 f1 0.323 0.401 0.401 \n", + "23 4.0 logloss 0.296 0.263 0.263 \n", + "24 4.0 mse 0.087 0.080 0.080 \n", + "\n", + " group3_rem_piml_EBM group3_rem_piml_EBM2 group5_rem_xgb2 \\\n", + "0 0.900 0.901 0.901 \n", + "1 0.163 0.821 0.836 \n", + "2 0.182 0.381 0.392 \n", + "3 3.257 0.262 0.254 \n", + "4 0.773 0.078 0.077 \n", + "5 0.906 0.906 0.906 \n", + "6 0.172 0.810 0.822 \n", + "7 0.172 0.348 0.360 \n", + "8 3.253 0.258 0.250 \n", + "9 0.778 0.077 0.075 \n", + "10 0.908 0.908 0.910 \n", + "11 0.175 0.815 0.826 \n", + "12 0.169 0.354 0.371 \n", + "13 3.284 0.251 0.245 \n", + "14 0.781 0.074 0.073 \n", + "15 0.903 0.903 0.903 \n", + "16 0.174 0.809 0.823 \n", + "17 0.177 0.361 0.365 \n", + "18 3.254 0.262 0.253 \n", + "19 0.775 0.079 0.077 \n", + "20 0.895 0.895 0.898 \n", + "21 0.170 0.818 0.828 \n", + "22 0.190 0.404 0.397 \n", + "23 3.200 0.273 0.266 \n", + "24 0.771 0.082 0.080 \n", + "\n", + " group8_rem_ebm group9_rem_xgb ph_rem_ebm group1_rem_ebm_rank \\\n", + "0 0.901 0.900 0.901 8.0 \n", + "1 0.793 0.797 0.791 8.0 \n", + "2 0.342 0.357 0.347 6.5 \n", + "3 0.274 0.277 0.275 8.0 \n", + "4 0.081 0.081 0.081 8.0 \n", + "5 0.906 0.906 0.906 5.0 \n", + "6 0.774 0.779 0.772 8.0 \n", + "7 0.319 0.329 0.321 8.0 \n", + "8 0.270 0.271 0.272 7.5 \n", + "9 0.079 0.078 0.079 7.0 \n", + "10 0.908 0.908 0.909 6.0 \n", + "11 0.781 0.772 0.780 8.0 \n", + "12 0.315 0.320 0.323 8.0 \n", + "13 0.264 0.271 0.264 7.5 \n", + "14 0.076 0.077 0.076 8.0 \n", + "15 0.903 0.903 0.903 5.0 \n", + "16 0.775 0.786 0.772 7.5 \n", + "17 0.328 0.343 0.323 8.0 \n", + "18 0.275 0.275 0.276 7.5 \n", + "19 0.080 0.080 0.080 8.0 \n", + "20 0.895 0.896 0.895 7.0 \n", + "21 0.785 0.779 0.782 8.0 \n", + "22 0.364 0.354 0.362 8.0 \n", + "23 0.286 0.291 0.287 8.0 \n", + "24 0.084 0.086 0.084 8.0 \n", + "\n", + " group2_rem_ebm_rank group2_rem_ebm2_rank group3_rem_piml_EBM_rank \\\n", + "0 3.5 3.5 8.0 \n", + "1 1.5 1.5 9.0 \n", + "2 1.5 1.5 9.0 \n", + "3 1.5 1.5 9.0 \n", + "4 2.0 2.0 9.0 \n", + "5 5.0 5.0 5.0 \n", + "6 1.5 1.5 9.0 \n", + "7 1.5 1.5 9.0 \n", + "8 1.5 1.5 9.0 \n", + "9 1.5 1.5 9.0 \n", + "10 6.0 6.0 6.0 \n", + "11 2.5 2.5 9.0 \n", + "12 1.5 1.5 9.0 \n", + "13 2.5 2.5 9.0 \n", + "14 2.0 2.0 9.0 \n", + "15 5.0 5.0 5.0 \n", + "16 1.5 1.5 9.0 \n", + "17 1.5 1.5 9.0 \n", + "18 1.5 1.5 9.0 \n", + "19 2.0 2.0 9.0 \n", + "20 2.5 2.5 7.0 \n", + "21 1.5 1.5 9.0 \n", + "22 2.5 2.5 9.0 \n", + "23 1.5 1.5 9.0 \n", + "24 2.0 2.0 9.0 \n", + "\n", + " group3_rem_piml_EBM2_rank group5_rem_xgb2_rank group8_rem_ebm_rank \\\n", + "0 3.5 3.5 3.5 \n", + "1 4.0 3.0 6.0 \n", + "2 4.0 3.0 8.0 \n", + "3 4.0 3.0 5.0 \n", + "4 4.0 2.0 6.0 
\n", + "5 5.0 5.0 5.0 \n", + "6 4.0 3.0 6.0 \n", + "7 4.0 3.0 7.0 \n", + "8 4.0 3.0 5.0 \n", + "9 4.0 3.0 7.0 \n", + "10 6.0 1.0 6.0 \n", + "11 4.0 1.0 5.0 \n", + "12 4.0 3.0 7.0 \n", + "13 4.0 1.0 5.5 \n", + "14 4.0 2.0 5.5 \n", + "15 5.0 5.0 5.0 \n", + "16 4.0 3.0 6.0 \n", + "17 4.0 3.0 6.0 \n", + "18 4.0 3.0 5.5 \n", + "19 4.0 2.0 6.0 \n", + "20 7.0 1.0 7.0 \n", + "21 4.0 3.0 5.0 \n", + "22 1.0 4.0 5.0 \n", + "23 4.0 3.0 5.0 \n", + "24 4.0 2.0 5.5 \n", + "\n", + " group9_rem_xgb_rank ph_rem_ebm_rank \n", + "0 8.0 3.5 \n", + "1 5.0 7.0 \n", + "2 5.0 6.5 \n", + "3 7.0 6.0 \n", + "4 6.0 6.0 \n", + "5 5.0 5.0 \n", + "6 5.0 7.0 \n", + "7 5.0 6.0 \n", + "8 6.0 7.5 \n", + "9 5.0 7.0 \n", + "10 6.0 2.0 \n", + "11 7.0 6.0 \n", + "12 6.0 5.0 \n", + "13 7.5 5.5 \n", + "14 7.0 5.5 \n", + "15 5.0 5.0 \n", + "16 5.0 7.5 \n", + "17 5.0 7.0 \n", + "18 5.5 7.5 \n", + "19 6.0 6.0 \n", + "20 4.0 7.0 \n", + "21 7.0 6.0 \n", + "22 7.0 6.0 \n", + "23 7.0 6.0 \n", + "24 7.0 5.5 " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval_frame = pd.DataFrame() # init frame to hold score ranking\n", + "metric_list = ['acc', 'auc', 'f1', 'logloss', 'mse'] # metric to use for evaluation\n", + "\n", + "# create eval frame row-by-row\n", + "for fold in sorted(scores_frame['fold'].unique()): # loop through folds \n", + " for metric_name in metric_list: # loop through metrics\n", + " \n", + " # init row dict to hold each rows values\n", + " row_dict = {'fold': fold,\n", + " 'metric': metric_name}\n", + " \n", + " # cache known y values for fold\n", + " fold_y = scores_frame.loc[scores_frame['fold'] == fold, y_name]\n", + " \n", + " for col_name in scores_frame.columns[2:]:\n", + " \n", + " # cache fold scores\n", + " fold_scores = scores_frame.loc[scores_frame['fold'] == fold, col_name]\n", + " \n", + " # calculate evaluation metric for fold\n", + " # with reasonable precision \n", + " \n", + " if metric_name == 'acc':\n", + " row_dict[col_name] = np.round(max_acc(fold_y, fold_scores), ROUND)\n", + " \n", + " if metric_name == 'auc':\n", + " row_dict[col_name] = np.round(roc_auc_score(fold_y, fold_scores), ROUND)\n", + " \n", + " if metric_name == 'f1':\n", + " row_dict[col_name] = np.round(max_f1(fold_y, fold_scores), ROUND) \n", + " \n", + " if metric_name == 'logloss':\n", + " row_dict[col_name] = np.round(log_loss(fold_y, fold_scores), ROUND)\n", + " \n", + " if metric_name == 'mse':\n", + " row_dict[col_name] = np.round(mean_squared_error(fold_y, fold_scores), ROUND)\n", + " \n", + " # append row values to eval_frame\n", + " eval_frame = eval_frame.append(row_dict, ignore_index=True)\n", + "\n", + "# init a temporary frame to hold rank information\n", + "rank_names = [name + '_rank' for name in sorted(eval_frame.columns) if name not in ['fold', 'metric']]\n", + "rank_frame = pd.DataFrame(columns=rank_names) \n", + "\n", + "# set columns to necessary order\n", + "eval_frame = eval_frame[['fold', 'metric'] + [name for name in sorted(eval_frame.columns) if name not in ['fold', 'metric']]]\n", + "\n", + "# determine score ranks row-by-row\n", + "for i in range(0, eval_frame.shape[0]):\n", + " \n", + " # get ranks for row based on metric\n", + " metric_name = eval_frame.loc[i, 'metric']\n", + " if metric_name in ['logloss', 'mse']:\n", + " ranks = eval_frame.iloc[i, 2:].rank().values\n", + " else:\n", + " ranks = eval_frame.iloc[i, 2:].rank(ascending=False).values\n", + " \n", + " # create single-row frame and append to rank_frame\n", + " row_frame = 
+    "    # create single-row frame and append to rank_frame\n",
+    "    row_frame = pd.DataFrame(ranks.reshape(1, ranks.shape[0]), columns=rank_names)\n",
+    "    rank_frame = pd.concat([rank_frame, row_frame], ignore_index=True)\n",
+    "    \n",
+    "    # housekeeping\n",
+    "    del row_frame\n",
+    "\n",
+    "# merge ranks onto eval_frame\n",
+    "eval_frame = pd.concat([eval_frame, rank_frame], axis=1)\n",
+    "\n",
+    "# housekeeping\n",
+    "del rank_frame\n",
+    "\n",
+    "eval_frame"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "37ed3b5f",
+   "metadata": {},
+   "source": [
+    "#### Save `eval_frame` as CSV"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "aa89d862",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# timestamp the file name so reruns do not overwrite earlier results\n",
+    "eval_frame.to_csv('model_eval_' + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') + '.csv',\n",
+    "                  index=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4525d3ea",
+   "metadata": {},
+   "source": [
+    "#### Display simple ranked score list"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "f8ff5fa5",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "group2_rem_ebm_rank          2.28\n",
+       "group2_rem_ebm2_rank         2.28\n",
+       "group5_rem_xgb2_rank         2.74\n",
+       "group3_rem_piml_EBM2_rank    4.14\n",
+       "group8_rem_ebm_rank          5.74\n",
+       "group9_rem_xgb_rank          5.96\n",
+       "ph_rem_ebm_rank              5.96\n",
+       "group1_rem_ebm_rank          7.46\n",
+       "group3_rem_piml_EBM_rank     8.44\n",
+       "dtype: float64"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# mean rank for each submission across all folds and metrics; lower is better\n",
+    "eval_frame[[name for name in eval_frame.columns if name.endswith('rank')]].mean().sort_values()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.16"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/assignments/model_eval_2023_06_28_21_00_17.csv b/assignments/model_eval_2023_06_28_21_00_17.csv
new file mode 100644
index 0000000..551556c
--- /dev/null
+++ b/assignments/model_eval_2023_06_28_21_00_17.csv
@@ -0,0 +1,26 @@
+fold,metric,group1_rem_ebm,group2_rem_ebm,group2_rem_ebm2,group3_rem_piml_EBM,group3_rem_piml_EBM2,group5_rem_xgb2,group8_rem_ebm,group9_rem_xgb,ph_rem_ebm,group1_rem_ebm_rank,group2_rem_ebm_rank,group2_rem_ebm2_rank,group3_rem_piml_EBM_rank,group3_rem_piml_EBM2_rank,group5_rem_xgb2_rank,group8_rem_ebm_rank,group9_rem_xgb_rank,ph_rem_ebm_rank
+0.0,acc,0.9,0.901,0.901,0.9,0.901,0.901,0.901,0.9,0.901,8.0,3.5,3.5,8.0,3.5,3.5,3.5,8.0,3.5
+0.0,auc,0.781,0.84,0.84,0.163,0.821,0.836,0.793,0.797,0.791,8.0,1.5,1.5,9.0,4.0,3.0,6.0,5.0,7.0
+0.0,f1,0.347,0.405,0.405,0.182,0.381,0.392,0.342,0.357,0.347,6.5,1.5,1.5,9.0,4.0,3.0,8.0,5.0,6.5
+0.0,logloss,0.28,0.251,0.251,3.257,0.262,0.254,0.274,0.277,0.275,8.0,1.5,1.5,9.0,4.0,3.0,5.0,7.0,6.0
+0.0,mse,0.082,0.077,0.077,0.773,0.078,0.077,0.081,0.081,0.081,8.0,2.0,2.0,9.0,4.0,2.0,6.0,6.0,6.0
+1.0,acc,0.906,0.906,0.906,0.906,0.906,0.906,0.906,0.906,0.906,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0
+1.0,auc,0.767,0.828,0.828,0.172,0.81,0.822,0.774,0.779,0.772,8.0,1.5,1.5,9.0,4.0,3.0,6.0,5.0,7.0
+1.0,f1,0.312,0.368,0.368,0.172,0.348,0.36,0.319,0.329,0.321,8.0,1.5,1.5,9.0,4.0,3.0,7.0,5.0,6.0
+1.0,logloss,0.272,0.246,0.246,3.253,0.258,0.25,0.27,0.271,0.272,7.5,1.5,1.5,9.0,4.0,3.0,5.0,6.0,7.5
+1.0,mse,0.079,0.074,0.074,0.778,0.077,0.075,0.079,0.078,0.079,7.0,1.5,1.5,9.0,4.0,3.0,7.0,5.0,7.0 +2.0,acc,0.908,0.908,0.908,0.908,0.908,0.91,0.908,0.908,0.909,6.0,6.0,6.0,6.0,6.0,1.0,6.0,6.0,2.0 +2.0,auc,0.759,0.825,0.825,0.175,0.815,0.826,0.781,0.772,0.78,8.0,2.5,2.5,9.0,4.0,1.0,5.0,7.0,6.0 +2.0,f1,0.304,0.372,0.372,0.169,0.354,0.371,0.315,0.32,0.323,8.0,1.5,1.5,9.0,4.0,3.0,7.0,6.0,5.0 +2.0,logloss,0.271,0.246,0.246,3.284,0.251,0.245,0.264,0.271,0.264,7.5,2.5,2.5,9.0,4.0,1.0,5.5,7.5,5.5 +2.0,mse,0.078,0.073,0.073,0.781,0.074,0.073,0.076,0.077,0.076,8.0,2.0,2.0,9.0,4.0,2.0,5.5,7.0,5.5 +3.0,acc,0.903,0.903,0.903,0.903,0.903,0.903,0.903,0.903,0.903,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0 +3.0,auc,0.772,0.826,0.826,0.174,0.809,0.823,0.775,0.786,0.772,7.5,1.5,1.5,9.0,4.0,3.0,6.0,5.0,7.5 +3.0,f1,0.317,0.371,0.371,0.177,0.361,0.365,0.328,0.343,0.323,8.0,1.5,1.5,9.0,4.0,3.0,6.0,5.0,7.0 +3.0,logloss,0.276,0.252,0.252,3.254,0.262,0.253,0.275,0.275,0.276,7.5,1.5,1.5,9.0,4.0,3.0,5.5,5.5,7.5 +3.0,mse,0.081,0.077,0.077,0.775,0.079,0.077,0.08,0.08,0.08,8.0,2.0,2.0,9.0,4.0,2.0,6.0,6.0,6.0 +4.0,acc,0.895,0.897,0.897,0.895,0.895,0.898,0.895,0.896,0.895,7.0,2.5,2.5,7.0,7.0,1.0,7.0,4.0,7.0 +4.0,auc,0.754,0.831,0.831,0.17,0.818,0.828,0.785,0.779,0.782,8.0,1.5,1.5,9.0,4.0,3.0,5.0,7.0,6.0 +4.0,f1,0.323,0.401,0.401,0.19,0.404,0.397,0.364,0.354,0.362,8.0,2.5,2.5,9.0,1.0,4.0,5.0,7.0,6.0 +4.0,logloss,0.296,0.263,0.263,3.2,0.273,0.266,0.286,0.291,0.287,8.0,1.5,1.5,9.0,4.0,3.0,5.0,7.0,6.0 +4.0,mse,0.087,0.08,0.08,0.771,0.082,0.08,0.084,0.086,0.084,8.0,2.0,2.0,9.0,4.0,2.0,5.5,7.0,5.5