diff --git a/assignments/eval.ipynb b/assignments/eval.ipynb deleted file mode 100644 index ce9dff2..0000000 --- a/assignments/eval.ipynb +++ /dev/null @@ -1,1245 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "f7efc033", - "metadata": {}, - "source": [ - "## License \n", - "\n", - "Copyright 2021-2023 Patrick Hall (jphall@gwu.edu)\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "you may not use this file except in compliance with the License.\n", - "You may obtain a copy of the License at\n", - "\n", - " http://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "Unless required by applicable law or agreed to in writing, software\n", - "distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "See the License for the specific language governing permissions and\n", - "limitations under the License.\n", - "\n", - "*DISCLAIMER*: This notebook is not legal or compliance advice." - ] - }, - { - "cell_type": "markdown", - "id": "aab60b41", - "metadata": {}, - "source": [ - "# Model Evaluation Notebook" - ] - }, - { - "cell_type": "markdown", - "id": "281af306", - "metadata": {}, - "source": [ - "#### Imports and inits" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "fd180587", - "metadata": {}, - "outputs": [], - "source": [ - "import os # for directory and file manipulation\n", - "import numpy as np # for basic array manipulation\n", - "import pandas as pd # for dataframe manipulation\n", - "import datetime # for timestamp\n", - "\n", - "# for model eval\n", - "from sklearn.metrics import accuracy_score, f1_score, log_loss, mean_squared_error, roc_auc_score\n", - "\n", - "# global constants \n", - "ROUND = 3 # generally, insane precision is not needed \n", - "SEED = 12345 # seed for better reproducibility\n", - "\n", - "# set global random seed for better reproducibility\n", - "np.random.seed(SEED)" - ] - }, - { - "cell_type": "markdown", - "id": "eb2a39d4", - "metadata": {}, - "source": [ - "#### Set basic metadata" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "98f640ed", - "metadata": {}, - "outputs": [], - "source": [ - "y_name = 'high_priced'\n", - "scores_dir = 'data/scores'" - ] - }, - { - "cell_type": "markdown", - "id": "cc8d83d0", - "metadata": {}, - "source": [ - "#### Read in score files " - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "355c2b81", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
high_pricedfoldgroup1_rem_ebmgroup2_rem_ebmgroup3_rem_piml_EBMgroup5_rem_xgb2group8_rem_ebmgroup9_rem_xgbph_rem_ebm
00.020.1187870.0805570.9203890.0783260.2238460.0817920.219429
10.010.0845060.0260010.9693010.0358250.0539260.1107020.053929
21.040.2103890.1949610.8142720.1953320.1435220.2040480.133863
30.010.0085290.0285560.9745590.0227650.0093710.0240380.014419
41.020.1899330.2082630.8029080.1930350.1511000.1702430.156047
..............................
198260.030.1636970.2283420.7922510.2351920.2167200.1814030.184214
198270.010.1149990.2539980.7629460.2358320.1614010.1594680.141663
198281.030.1413070.2133640.7474010.2087230.2428140.1381410.233266
198290.010.0077660.0021760.9964550.0187020.0056570.0345700.009914
198300.000.1639460.1854840.8114290.2150850.1678120.1777850.155447
\n", - "

19831 rows × 9 columns

\n", - "
" - ], - "text/plain": [ - " high_priced fold group1_rem_ebm group2_rem_ebm group3_rem_piml_EBM \\\n", - "0 0.0 2 0.118787 0.080557 0.920389 \n", - "1 0.0 1 0.084506 0.026001 0.969301 \n", - "2 1.0 4 0.210389 0.194961 0.814272 \n", - "3 0.0 1 0.008529 0.028556 0.974559 \n", - "4 1.0 2 0.189933 0.208263 0.802908 \n", - "... ... ... ... ... ... \n", - "19826 0.0 3 0.163697 0.228342 0.792251 \n", - "19827 0.0 1 0.114999 0.253998 0.762946 \n", - "19828 1.0 3 0.141307 0.213364 0.747401 \n", - "19829 0.0 1 0.007766 0.002176 0.996455 \n", - "19830 0.0 0 0.163946 0.185484 0.811429 \n", - "\n", - " group5_rem_xgb2 group8_rem_ebm group9_rem_xgb ph_rem_ebm \n", - "0 0.078326 0.223846 0.081792 0.219429 \n", - "1 0.035825 0.053926 0.110702 0.053929 \n", - "2 0.195332 0.143522 0.204048 0.133863 \n", - "3 0.022765 0.009371 0.024038 0.014419 \n", - "4 0.193035 0.151100 0.170243 0.156047 \n", - "... ... ... ... ... \n", - "19826 0.235192 0.216720 0.181403 0.184214 \n", - "19827 0.235832 0.161401 0.159468 0.141663 \n", - "19828 0.208723 0.242814 0.138141 0.233266 \n", - "19829 0.018702 0.005657 0.034570 0.009914 \n", - "19830 0.215085 0.167812 0.177785 0.155447 \n", - "\n", - "[19831 rows x 9 columns]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# init score frame with known test y values\n", - "scores_frame = pd.read_csv(scores_dir + os.sep +'key.csv', index_col='Unnamed: 0')\n", - "\n", - "# create random folds in reproducible way\n", - "np.random.seed(SEED)\n", - "scores_frame['fold'] = np.random.choice(5, scores_frame.shape[0])\n", - "\n", - "# read in each score file in the directory as a new column \n", - "for file in sorted(os.listdir(scores_dir)):\n", - " if file != 'key.csv' and file.endswith('.csv'):\n", - " scores_frame[file[:-4]] = pd.read_csv(scores_dir + os.sep + file)['phat']\n", - "\n", - "# sanity check \n", - "scores_frame" - ] - }, - { - "cell_type": "markdown", - "id": "3e3cccda", - "metadata": {}, - "source": [ - "#### Utility function for max. accuracy" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "2eb43506", - "metadata": {}, - "outputs": [], - "source": [ - "def max_acc(y, phat, res=0.01): \n", - "\n", - " \"\"\" Utility function for finding max. accuracy at some cutoff. \n", - " \n", - " :param y: Known y values.\n", - " :param phat: Model scores.\n", - " :param res: Resolution over which to search for max. accuracy, default 0.01.\n", - " :return: Max. accuracy for model scores.\n", - " \n", - " \"\"\"\n", - " \n", - " # init frame to store acc at different cutoffs\n", - " acc_frame = pd.DataFrame(columns=['cut', 'acc'])\n", - " \n", - " # copy known y and score values into a temporary frame\n", - " temp_df = pd.concat([y, phat], axis=1)\n", - " \n", - " # find accuracy at different cutoffs and store in acc_frame\n", - " for cut in np.arange(0, 1 + res, res):\n", - " temp_df['decision'] = np.where(temp_df.iloc[:, 1] > cut, 1, 0)\n", - " acc = accuracy_score(temp_df.iloc[:, 0], temp_df['decision'])\n", - " acc_frame = acc_frame.append({'cut': cut,\n", - " 'acc': acc},\n", - " ignore_index=True)\n", - "\n", - " # find max accurcay across all cutoffs\n", - " max_acc = acc_frame['acc'].max()\n", - " \n", - " # house keeping\n", - " del acc_frame, temp_df\n", - " \n", - " return max_acc" - ] - }, - { - "cell_type": "markdown", - "id": "b02c9651", - "metadata": {}, - "source": [ - "#### Utility function for max. 
F1" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "fae3756b", - "metadata": {}, - "outputs": [], - "source": [ - "def max_f1(y, phat, res=0.01): \n", - " \n", - " \"\"\" Utility function for finding max. F1 at some cutoff. \n", - " \n", - " :param y: Known y values.\n", - " :param phat: Model scores.\n", - " :param res: Resolution over which to search for max. F1, default 0.01.\n", - " :return: Max. F1 for model scores.\n", - " \n", - " \"\"\"\n", - " \n", - " # init frame to store f1 at different cutoffs\n", - " f1_frame = pd.DataFrame(columns=['cut', 'f1'])\n", - " \n", - " # copy known y and score values into a temporary frame\n", - " temp_df = pd.concat([y, phat], axis=1)\n", - " \n", - " # find f1 at different cutoffs and store in acc_frame\n", - " for cut in np.arange(0, 1 + res, res):\n", - " temp_df['decision'] = np.where(temp_df.iloc[:, 1] > cut, 1, 0)\n", - " f1 = f1_score(temp_df.iloc[:, 0], temp_df['decision'])\n", - " f1_frame = f1_frame.append({'cut': cut,\n", - " 'f1': f1},\n", - " ignore_index=True)\n", - " \n", - " # find max f1 across all cutoffs\n", - " max_f1 = f1_frame['f1'].max()\n", - " \n", - " # house keeping\n", - " del f1_frame, temp_df\n", - " \n", - " return max_f1" - ] - }, - { - "cell_type": "markdown", - "id": "b447b732", - "metadata": {}, - "source": [ - "#### Rank all submitted scores " - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "40fbe608", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
foldmetricgroup1_rem_ebmgroup2_rem_ebmgroup3_rem_piml_EBMgroup5_rem_xgb2group8_rem_ebmgroup9_rem_xgbph_rem_ebmgroup1_rem_ebm_rankgroup2_rem_ebm_rankgroup3_rem_piml_EBM_rankgroup5_rem_xgb2_rankgroup8_rem_ebm_rankgroup9_rem_xgb_rankph_rem_ebm_rank
00.0acc0.9000.9010.9000.9010.9010.9000.9016.02.56.02.52.56.02.5
10.0auc0.7810.8400.1630.8360.7930.7970.7916.01.07.02.04.03.05.0
20.0f10.3470.4050.1820.3920.3420.3570.3474.51.07.02.06.03.04.5
30.0logloss0.2800.2513.2570.2540.2740.2770.2756.01.07.02.03.05.04.0
40.0mse0.0820.0770.7730.0770.0810.0810.0816.01.57.01.54.04.04.0
51.0acc0.9060.9060.9060.9060.9060.9060.9064.04.04.04.04.04.04.0
61.0auc0.7670.8280.1720.8220.7740.7790.7726.01.07.02.04.03.05.0
71.0f10.3120.3680.1720.3600.3190.3290.3216.01.07.02.05.03.04.0
81.0logloss0.2720.2463.2530.2500.2700.2710.2725.51.07.02.03.04.05.5
91.0mse0.0790.0740.7780.0750.0790.0780.0795.01.07.02.05.03.05.0
102.0acc0.9080.9080.9080.9100.9080.9080.9095.05.05.01.05.05.02.0
112.0auc0.7590.8250.1750.8260.7810.7720.7806.02.07.01.03.05.04.0
122.0f10.3040.3720.1690.3710.3150.3200.3236.01.07.02.05.04.03.0
132.0logloss0.2710.2463.2840.2450.2640.2710.2645.52.07.01.03.55.53.5
142.0mse0.0780.0730.7810.0730.0760.0770.0766.01.57.01.53.55.03.5
153.0acc0.9030.9030.9030.9030.9030.9030.9034.04.04.04.04.04.04.0
163.0auc0.7720.8260.1740.8230.7750.7860.7725.51.07.02.04.03.05.5
173.0f10.3170.3710.1770.3650.3280.3430.3236.01.07.02.04.03.05.0
183.0logloss0.2760.2523.2540.2530.2750.2750.2765.51.07.02.03.53.55.5
193.0mse0.0810.0770.7750.0770.0800.0800.0806.01.57.01.54.04.04.0
204.0acc0.8950.8970.8950.8980.8950.8960.8955.52.05.51.05.53.05.5
214.0auc0.7540.8310.1700.8280.7850.7790.7826.01.07.02.03.05.04.0
224.0f10.3230.4010.1900.3970.3640.3540.3626.01.07.02.03.05.04.0
234.0logloss0.2960.2633.2000.2660.2860.2910.2876.01.07.02.03.05.04.0
244.0mse0.0870.0800.7710.0800.0840.0860.0846.01.57.01.53.55.03.5
\n", - "
" - ], - "text/plain": [ - " fold metric group1_rem_ebm group2_rem_ebm group3_rem_piml_EBM \\\n", - "0 0.0 acc 0.900 0.901 0.900 \n", - "1 0.0 auc 0.781 0.840 0.163 \n", - "2 0.0 f1 0.347 0.405 0.182 \n", - "3 0.0 logloss 0.280 0.251 3.257 \n", - "4 0.0 mse 0.082 0.077 0.773 \n", - "5 1.0 acc 0.906 0.906 0.906 \n", - "6 1.0 auc 0.767 0.828 0.172 \n", - "7 1.0 f1 0.312 0.368 0.172 \n", - "8 1.0 logloss 0.272 0.246 3.253 \n", - "9 1.0 mse 0.079 0.074 0.778 \n", - "10 2.0 acc 0.908 0.908 0.908 \n", - "11 2.0 auc 0.759 0.825 0.175 \n", - "12 2.0 f1 0.304 0.372 0.169 \n", - "13 2.0 logloss 0.271 0.246 3.284 \n", - "14 2.0 mse 0.078 0.073 0.781 \n", - "15 3.0 acc 0.903 0.903 0.903 \n", - "16 3.0 auc 0.772 0.826 0.174 \n", - "17 3.0 f1 0.317 0.371 0.177 \n", - "18 3.0 logloss 0.276 0.252 3.254 \n", - "19 3.0 mse 0.081 0.077 0.775 \n", - "20 4.0 acc 0.895 0.897 0.895 \n", - "21 4.0 auc 0.754 0.831 0.170 \n", - "22 4.0 f1 0.323 0.401 0.190 \n", - "23 4.0 logloss 0.296 0.263 3.200 \n", - "24 4.0 mse 0.087 0.080 0.771 \n", - "\n", - " group5_rem_xgb2 group8_rem_ebm group9_rem_xgb ph_rem_ebm \\\n", - "0 0.901 0.901 0.900 0.901 \n", - "1 0.836 0.793 0.797 0.791 \n", - "2 0.392 0.342 0.357 0.347 \n", - "3 0.254 0.274 0.277 0.275 \n", - "4 0.077 0.081 0.081 0.081 \n", - "5 0.906 0.906 0.906 0.906 \n", - "6 0.822 0.774 0.779 0.772 \n", - "7 0.360 0.319 0.329 0.321 \n", - "8 0.250 0.270 0.271 0.272 \n", - "9 0.075 0.079 0.078 0.079 \n", - "10 0.910 0.908 0.908 0.909 \n", - "11 0.826 0.781 0.772 0.780 \n", - "12 0.371 0.315 0.320 0.323 \n", - "13 0.245 0.264 0.271 0.264 \n", - "14 0.073 0.076 0.077 0.076 \n", - "15 0.903 0.903 0.903 0.903 \n", - "16 0.823 0.775 0.786 0.772 \n", - "17 0.365 0.328 0.343 0.323 \n", - "18 0.253 0.275 0.275 0.276 \n", - "19 0.077 0.080 0.080 0.080 \n", - "20 0.898 0.895 0.896 0.895 \n", - "21 0.828 0.785 0.779 0.782 \n", - "22 0.397 0.364 0.354 0.362 \n", - "23 0.266 0.286 0.291 0.287 \n", - "24 0.080 0.084 0.086 0.084 \n", - "\n", - " group1_rem_ebm_rank group2_rem_ebm_rank group3_rem_piml_EBM_rank \\\n", - "0 6.0 2.5 6.0 \n", - "1 6.0 1.0 7.0 \n", - "2 4.5 1.0 7.0 \n", - "3 6.0 1.0 7.0 \n", - "4 6.0 1.5 7.0 \n", - "5 4.0 4.0 4.0 \n", - "6 6.0 1.0 7.0 \n", - "7 6.0 1.0 7.0 \n", - "8 5.5 1.0 7.0 \n", - "9 5.0 1.0 7.0 \n", - "10 5.0 5.0 5.0 \n", - "11 6.0 2.0 7.0 \n", - "12 6.0 1.0 7.0 \n", - "13 5.5 2.0 7.0 \n", - "14 6.0 1.5 7.0 \n", - "15 4.0 4.0 4.0 \n", - "16 5.5 1.0 7.0 \n", - "17 6.0 1.0 7.0 \n", - "18 5.5 1.0 7.0 \n", - "19 6.0 1.5 7.0 \n", - "20 5.5 2.0 5.5 \n", - "21 6.0 1.0 7.0 \n", - "22 6.0 1.0 7.0 \n", - "23 6.0 1.0 7.0 \n", - "24 6.0 1.5 7.0 \n", - "\n", - " group5_rem_xgb2_rank group8_rem_ebm_rank group9_rem_xgb_rank \\\n", - "0 2.5 2.5 6.0 \n", - "1 2.0 4.0 3.0 \n", - "2 2.0 6.0 3.0 \n", - "3 2.0 3.0 5.0 \n", - "4 1.5 4.0 4.0 \n", - "5 4.0 4.0 4.0 \n", - "6 2.0 4.0 3.0 \n", - "7 2.0 5.0 3.0 \n", - "8 2.0 3.0 4.0 \n", - "9 2.0 5.0 3.0 \n", - "10 1.0 5.0 5.0 \n", - "11 1.0 3.0 5.0 \n", - "12 2.0 5.0 4.0 \n", - "13 1.0 3.5 5.5 \n", - "14 1.5 3.5 5.0 \n", - "15 4.0 4.0 4.0 \n", - "16 2.0 4.0 3.0 \n", - "17 2.0 4.0 3.0 \n", - "18 2.0 3.5 3.5 \n", - "19 1.5 4.0 4.0 \n", - "20 1.0 5.5 3.0 \n", - "21 2.0 3.0 5.0 \n", - "22 2.0 3.0 5.0 \n", - "23 2.0 3.0 5.0 \n", - "24 1.5 3.5 5.0 \n", - "\n", - " ph_rem_ebm_rank \n", - "0 2.5 \n", - "1 5.0 \n", - "2 4.5 \n", - "3 4.0 \n", - "4 4.0 \n", - "5 4.0 \n", - "6 5.0 \n", - "7 4.0 \n", - "8 5.5 \n", - "9 5.0 \n", - "10 2.0 \n", - "11 4.0 \n", - "12 3.0 \n", - "13 3.5 \n", - "14 3.5 \n", - "15 4.0 \n", - "16 5.5 \n", - "17 5.0 
\n", - "18 5.5 \n", - "19 4.0 \n", - "20 5.5 \n", - "21 4.0 \n", - "22 4.0 \n", - "23 4.0 \n", - "24 3.5 " - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval_frame = pd.DataFrame() # init frame to hold score ranking\n", - "metric_list = ['acc', 'auc', 'f1', 'logloss', 'mse'] # metric to use for evaluation\n", - "\n", - "# create eval frame row-by-row\n", - "for fold in sorted(scores_frame['fold'].unique()): # loop through folds \n", - " for metric_name in metric_list: # loop through metrics\n", - " \n", - " # init row dict to hold each rows values\n", - " row_dict = {'fold': fold,\n", - " 'metric': metric_name}\n", - " \n", - " # cache known y values for fold\n", - " fold_y = scores_frame.loc[scores_frame['fold'] == fold, y_name]\n", - " \n", - " for col_name in scores_frame.columns[2:]:\n", - " \n", - " # cache fold scores\n", - " fold_scores = scores_frame.loc[scores_frame['fold'] == fold, col_name]\n", - " \n", - " # calculate evaluation metric for fold\n", - " # with reasonable precision \n", - " \n", - " if metric_name == 'acc':\n", - " row_dict[col_name] = np.round(max_acc(fold_y, fold_scores), ROUND)\n", - " \n", - " if metric_name == 'auc':\n", - " row_dict[col_name] = np.round(roc_auc_score(fold_y, fold_scores), ROUND)\n", - " \n", - " if metric_name == 'f1':\n", - " row_dict[col_name] = np.round(max_f1(fold_y, fold_scores), ROUND) \n", - " \n", - " if metric_name == 'logloss':\n", - " row_dict[col_name] = np.round(log_loss(fold_y, fold_scores), ROUND)\n", - " \n", - " if metric_name == 'mse':\n", - " row_dict[col_name] = np.round(mean_squared_error(fold_y, fold_scores), ROUND)\n", - " \n", - " # append row values to eval_frame\n", - " eval_frame = eval_frame.append(row_dict, ignore_index=True)\n", - "\n", - "# init a temporary frame to hold rank information\n", - "rank_names = [name + '_rank' for name in sorted(eval_frame.columns) if name not in ['fold', 'metric']]\n", - "rank_frame = pd.DataFrame(columns=rank_names) \n", - "\n", - "# set columns to necessary order\n", - "eval_frame = eval_frame[['fold', 'metric'] + [name for name in sorted(eval_frame.columns) if name not in ['fold', 'metric']]]\n", - "\n", - "# determine score ranks row-by-row\n", - "for i in range(0, eval_frame.shape[0]):\n", - " \n", - " # get ranks for row based on metric\n", - " metric_name = eval_frame.loc[i, 'metric']\n", - " if metric_name in ['logloss', 'mse']:\n", - " ranks = eval_frame.iloc[i, 2:].rank().values\n", - " else:\n", - " ranks = eval_frame.iloc[i, 2:].rank(ascending=False).values\n", - " \n", - " # create single-row frame and append to rank_frame\n", - " row_frame = pd.DataFrame(ranks.reshape(1, ranks.shape[0]), columns=rank_names)\n", - " rank_frame = rank_frame.append(row_frame, ignore_index=True)\n", - " \n", - " # house keeping\n", - " del row_frame\n", - "\n", - "# merge ranks onto eval_frame\n", - "eval_frame = pd.concat([eval_frame, rank_frame], axis=1)\n", - "\n", - "# house keeping\n", - "del rank_frame\n", - " \n", - "eval_frame" - ] - }, - { - "cell_type": "markdown", - "id": "37ed3b5f", - "metadata": {}, - "source": [ - "#### Save `eval_frame` as CSV" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "aa89d862", - "metadata": {}, - "outputs": [], - "source": [ - "eval_frame.to_csv('model_eval_' + str(datetime.datetime.now().strftime(\"%Y_%m_%d_%H_%M_%S\") + '.csv'), \n", - " index=False)" - ] - }, - { - "cell_type": "markdown", - "id": "4525d3ea", - "metadata": {}, - "source": [ - "#### Display 
simple ranked score list " - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "f8ff5fa5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "group2_rem_ebm_rank 1.66\n", - "group5_rem_xgb2_rank 1.94\n", - "group8_rem_ebm_rank 3.92\n", - "group9_rem_xgb_rank 4.12\n", - "ph_rem_ebm_rank 4.18\n", - "group1_rem_ebm_rank 5.60\n", - "group3_rem_piml_EBM_rank 6.58\n", - "dtype: float64" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "eval_frame[[name for name in eval_frame.columns if name.endswith('rank')]].mean().sort_values()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.16" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}
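
Review note: for anyone who still needs this evaluation logic after the deletion, the removed notebook's cutoff search (`max_acc` / `max_f1`) and per-fold rank aggregation relied on `DataFrame.append`, which was removed in pandas 2.0. The following is a minimal sketch of the same idea against current pandas, not a drop-in replacement: `max_metric` and `eval_scores` are illustrative names that do not appear in the deleted file, while the target column, seed, paths, and metric list in the comments are carried over from it.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import (accuracy_score, f1_score, log_loss,
                             mean_squared_error, roc_auc_score)


def max_metric(y, phat, metric=accuracy_score, res=0.01):
    """Best value of a thresholded metric over cutoffs 0, res, ..., 1."""
    return max(metric(y, (phat > cut).astype(int))
               for cut in np.arange(0.0, 1.0 + res, res))


def eval_scores(scores_frame, y_name="high_priced", rounding=3):
    """Per-fold metrics for every score column, plus per-row ranks (1 = best)."""
    score_cols = [c for c in scores_frame.columns if c not in (y_name, "fold")]
    rows = []
    for fold, df in scores_frame.groupby("fold"):
        y = df[y_name]
        for metric in ("acc", "auc", "f1", "logloss", "mse"):
            row = {"fold": fold, "metric": metric}
            for col in score_cols:
                phat = df[col]
                if metric == "acc":
                    val = max_metric(y, phat, accuracy_score)
                elif metric == "f1":
                    val = max_metric(y, phat, f1_score)
                elif metric == "auc":
                    val = roc_auc_score(y, phat)
                elif metric == "logloss":
                    val = log_loss(y, phat)
                else:  # mse
                    val = mean_squared_error(y, phat)
                row[col] = round(val, rounding)
            rows.append(row)
    eval_frame = pd.DataFrame(rows)

    # lower is better for logloss/mse, higher is better for acc/auc/f1
    ranks = eval_frame.apply(
        lambda r: r[score_cols].astype(float).rank(
            ascending=r["metric"] in ("logloss", "mse")),
        axis=1)
    ranks.columns = [c + "_rank" for c in score_cols]
    return pd.concat([eval_frame, ranks], axis=1)


# Example wiring, mirroring the deleted notebook:
# scores_frame = pd.read_csv("data/scores/key.csv", index_col="Unnamed: 0")
# np.random.seed(12345)
# scores_frame["fold"] = np.random.choice(5, scores_frame.shape[0])
# ...add each submitted score file's "phat" column to scores_frame...
# eval_frame = eval_scores(scores_frame)
# eval_frame.filter(like="_rank").mean().sort_values()  # leaderboard, 1 = best
```

Averaging the per-row ranks across all folds and metrics and sorting the result reproduces the simple ranked score list the notebook printed in its final cell.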