From c3cda73246489635164a9342967b7c9c4af2aebd Mon Sep 17 00:00:00 2001 From: "Egor.Kraev" Date: Mon, 23 Sep 2024 16:59:16 +0100 Subject: [PATCH] Add nice howto notebook --- ...through regression on Shapley values.ipynb | 1517 +++++++++++++++++ shap_select/select.py | 14 +- tests/test_regression.py | 2 +- 3 files changed, 1528 insertions(+), 5 deletions(-) create mode 100644 docs/Quick feature selection through regression on Shapley values.ipynb diff --git a/docs/Quick feature selection through regression on Shapley values.ipynb b/docs/Quick feature selection through regression on Shapley values.ipynb new file mode 100644 index 0000000..777d7e6 --- /dev/null +++ b/docs/Quick feature selection through regression on Shapley values.ipynb @@ -0,0 +1,1517 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b8684862", + "metadata": {}, + "source": [ + "# Quick feature selection through regression on Shapley values\n", + "\n", + "Feature selection for tabular models is a hard problem, and most solutions proposed for it are computationally expensive. Here we show a heuristic method that is quite computationally efficient, due to the fact that computing Shapley values on tree-based models (such as XGBoost, LightGBM, or CatBoost) is quite quick. \n", + "\n", + "For those who haven't come across them before, Shapley values are simply a way of decomposing a model's output into contributions from the individual feature values, with the nice property that all the features' contributions are guaranteed to add up to the model output. \n", + "\n", + "The process goes as follows: first, you split your dataset into a training and a validation set, and train a tree-based model on the training set, using all the available features, ideally with early stopping. If you already have a model thus fitted, you can just use that instead.\n", + "\n", + "In the second step, you calculate the Shapley values of all the features for that model, on the validation set. 
And now comes the fun part: for every data point in the validation set the Shapley values add up, by construction, to the model output for that data point. \n",
+ "\n",
+ "Now you are in linear country. As the next step, you run a regression of the target value on the Shapley values of the features, on the validation set. If the model was perfect (model output identical to target) all the regression coefficients would be equal to 1.0. In practice, that will not be the case, and the coefficients of irrelevant features end up either being statistically insignificant (because the contributions of those features don't, on average, bring the model output closer to the target on the validation set), or negative, indicating that their presence is actually harming validation set performance.\n",
+ "\n",
+ "So our algorithm recommends first discarding all features with negative coefficients, then ranking the rest according to their statistical significance, and choosing some significance threshold (default 5%): a feature is kept if its significance falls below that threshold. 
\n", + "\n", + "Here's an example on synthetic data:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "348c2468", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "np.random.seed(42)\n", + "n_samples = 100000\n", + "\n", + "# Create 9 normally distributed features\n", + "X = pd.DataFrame(\n", + " {\n", + " \"x1\": np.random.normal(size=n_samples),\n", + " \"x2\": np.random.normal(size=n_samples),\n", + " \"x3\": np.random.normal(size=n_samples),\n", + " \"x4\": np.random.normal(size=n_samples),\n", + " \"x5\": np.random.normal(size=n_samples),\n", + " \"x6\": np.random.normal(size=n_samples),\n", + " \"x7\": np.random.normal(size=n_samples),\n", + " \"x8\": np.random.normal(size=n_samples),\n", + " \"x9\": np.random.normal(size=n_samples),\n", + " }\n", + ")\n", + "\n", + "# Make all the features positive-ish\n", + "X += 3\n", + "\n", + "# Define the target based on the formula y = x1 + x2*x3 + x4*x5*x6\n", + "y = (\n", + " 3 * X[\"x1\"]\n", + " + X[\"x2\"] * X[\"x3\"]\n", + " + X[\"x4\"] * X[\"x5\"] * X[\"x6\"]\n", + " + 10 * np.random.normal(size=n_samples) # lots of noise\n", + ")\n", + "X[\"x6\"] *= 0.1\n", + "X[\"x6\"] += np.random.normal(size=n_samples)\n", + "\n", + "# Split the dataset into training and validation sets (both with 10K rows)\n", + "X_train, X_val, y_train, y_val = train_test_split(\n", + " X, y, test_size=0.1, random_state=42\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fb991c51", + "metadata": {}, + "source": [ + "Let's train, for example, an xgboost model on the training set:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ec03ff2c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0]\tvalid-rmse:17.78711\n", + "[1]\tvalid-rmse:16.44843\n", + "[2]\tvalid-rmse:15.64895\n", + "[3]\tvalid-rmse:15.19588\n", + 
"[4]\tvalid-rmse:14.92683\n", + "[5]\tvalid-rmse:14.75290\n", + "[6]\tvalid-rmse:14.65225\n", + "[7]\tvalid-rmse:14.56790\n", + "[8]\tvalid-rmse:14.50784\n", + "[9]\tvalid-rmse:14.46584\n", + "[10]\tvalid-rmse:14.43859\n", + "[11]\tvalid-rmse:14.42790\n", + "[12]\tvalid-rmse:14.41093\n", + "[13]\tvalid-rmse:14.39674\n", + "[14]\tvalid-rmse:14.38603\n", + "[15]\tvalid-rmse:14.38173\n", + "[16]\tvalid-rmse:14.37627\n", + "[17]\tvalid-rmse:14.37386\n", + "[18]\tvalid-rmse:14.36957\n", + "[19]\tvalid-rmse:14.36874\n", + "[20]\tvalid-rmse:14.36958\n", + "[21]\tvalid-rmse:14.37481\n", + "[22]\tvalid-rmse:14.37414\n", + "[23]\tvalid-rmse:14.37449\n", + "[24]\tvalid-rmse:14.37473\n", + "[25]\tvalid-rmse:14.37843\n", + "[26]\tvalid-rmse:14.38056\n", + "[27]\tvalid-rmse:14.38592\n", + "[28]\tvalid-rmse:14.39205\n", + "[29]\tvalid-rmse:14.39171\n", + "[30]\tvalid-rmse:14.38889\n", + "[31]\tvalid-rmse:14.39872\n", + "[32]\tvalid-rmse:14.40221\n", + "[33]\tvalid-rmse:14.40517\n", + "[34]\tvalid-rmse:14.41196\n", + "[35]\tvalid-rmse:14.41776\n", + "[36]\tvalid-rmse:14.41830\n", + "[37]\tvalid-rmse:14.42190\n", + "[38]\tvalid-rmse:14.42338\n", + "[39]\tvalid-rmse:14.42358\n", + "[40]\tvalid-rmse:14.42555\n", + "[41]\tvalid-rmse:14.42859\n", + "[42]\tvalid-rmse:14.43496\n", + "[43]\tvalid-rmse:14.43931\n", + "[44]\tvalid-rmse:14.44010\n", + "[45]\tvalid-rmse:14.44360\n", + "[46]\tvalid-rmse:14.44819\n", + "[47]\tvalid-rmse:14.45216\n", + "[48]\tvalid-rmse:14.45540\n", + "[49]\tvalid-rmse:14.46038\n", + "[50]\tvalid-rmse:14.46093\n", + "[51]\tvalid-rmse:14.46455\n", + "[52]\tvalid-rmse:14.46794\n", + "[53]\tvalid-rmse:14.47515\n", + "[54]\tvalid-rmse:14.48102\n", + "[55]\tvalid-rmse:14.48300\n", + "[56]\tvalid-rmse:14.48801\n", + "[57]\tvalid-rmse:14.49156\n", + "[58]\tvalid-rmse:14.48867\n", + "[59]\tvalid-rmse:14.49315\n", + "[60]\tvalid-rmse:14.49491\n", + "[61]\tvalid-rmse:14.49620\n", + "[62]\tvalid-rmse:14.50005\n", + "[63]\tvalid-rmse:14.50803\n", + 
"[64]\tvalid-rmse:14.51442\n", + "[65]\tvalid-rmse:14.51705\n", + "[66]\tvalid-rmse:14.52365\n", + "[67]\tvalid-rmse:14.52792\n", + "[68]\tvalid-rmse:14.53296\n" + ] + } + ], + "source": [ + "import xgboost as xgb\n", + "\n", + "dtrain = xgb.DMatrix(X_train, label=y_train)\n", + "dval = xgb.DMatrix(X_val, label=y_val)\n", + "params = {\n", + " \"objective\": \"reg:squarederror\",\n", + " \"eval_metric\": \"rmse\",\n", + " \"verbosity\": 0,\n", + " }\n", + "\n", + "model = xgb.train(\n", + " params, dtrain, num_boost_round=1000, evals= [(dval, \"valid\")], early_stopping_rounds=50\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "0fbd335f", + "metadata": {}, + "source": [ + "Now let's generate the feature significance scores. The final column shows whether we suggest to select the feature; -1 means feature is rejected because of a negative regression coefficient, 0 means it's rejected because of not passing the significance threshold." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8f403fc5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 feature namecoefficientstderrstat.significancet-valuecloseness to 1.0Selected
0x51.0520300.0520520.00000020.2112990.0520301
1x40.9524160.0520020.00000018.3151440.0475841
2x31.0981540.1606500.0000006.8356900.0981541
3x21.0448420.1618120.0000006.4571400.0448421
4x10.9172420.1658500.0000005.5305560.0827581
5x61.4979830.6265440.0168272.3908680.4979831
6x72.8655083.1800170.3675580.9010981.8655080
7x81.9336323.4332080.5733020.5632140.9336320
8x9-4.5370982.8219050.107908-1.6078145.537098-1
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os, sys\n", + "\n", + "try:\n", + " from shap_select import score_features\n", + "except ModuleNotFoundError:\n", + " # If you're running shap_select from source\n", + " root = os.path.realpath(\"..\")\n", + " sys.path.append(root)\n", + " from shap_select import score_features\n", + "\n", + "selected_features_df = score_features(\n", + " model, X_val, X_val.columns.tolist(), y_val, task=\"regression\", threshold=0.05\n", + ")\n", + "\n", + "# Let's color the output prettily\n", + "styled_df = selected_features_df.style.background_gradient(\n", + " cmap='coolwarm', subset=pd.IndexSlice[:, ['coefficient', \n", + " 'stderr', \n", + " 'stat.significance', \n", + " 't-value', \n", + " 'closeness to 1.0', \n", + " 'Selected']]\n", + ")\n", + "styled_df" + ] + }, + { + "cell_type": "markdown", + "id": "6e6f6f51", + "metadata": {}, + "source": [ + "## What about classifier models?\n", + "You'll be happy to hear that the above approach works just fine on the classifier models. There is a slight difference under the hood, described below, but both the function call, and the interpretation of the output, work exactly the same. 
\n",
+ "\n",
+ "### Technical details for classifier models\n",
+ "The `shap` package automatically recognizes whether it's given a classifier model, and in that case, calculates the shap values for log odds of a particular outcome.\n",
+ "\n",
+ "In the case of a binary classifier, this means that we now have to run a logistic, rather than a linear regression, and then proceed exactly like before with interpreting the coefficients and significances.\n",
+ "\n",
+ "In the case of a multiclass classifier, we get Shapley values for each value of the target; we run a binary regression for each and then for each coefficient take the largest t-value across these regressions, and calculate the statistical significance from that. Finally, to avoid the data mining effect of multiple tests, we apply the Bonferroni correction by multiplying the resulting significance by the number of classes; this way, you can compare that value to the original threshold value. \n",
+ "\n",
+ "Below is an example of a multiclass classifier.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "1412da7f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[0]\tvalid-mlogloss:0.78966\n",
+ "[1]\tvalid-mlogloss:0.60695\n",
+ "[2]\tvalid-mlogloss:0.48586\n",
+ "[3]\tvalid-mlogloss:0.40006\n",
+ "[4]\tvalid-mlogloss:0.33654\n",
+ "[5]\tvalid-mlogloss:0.28842\n",
+ "[6]\tvalid-mlogloss:0.25138\n",
+ "[7]\tvalid-mlogloss:0.22226\n",
+ "[8]\tvalid-mlogloss:0.19882\n",
+ "[9]\tvalid-mlogloss:0.17992\n",
+ "[10]\tvalid-mlogloss:0.16560\n",
+ "[11]\tvalid-mlogloss:0.15291\n",
+ "[12]\tvalid-mlogloss:0.14259\n",
+ "[13]\tvalid-mlogloss:0.13417\n",
+ "[14]\tvalid-mlogloss:0.12714\n",
+ "[15]\tvalid-mlogloss:0.12163\n",
+ "[16]\tvalid-mlogloss:0.11609\n",
+ "[17]\tvalid-mlogloss:0.11109\n",
+ "[18]\tvalid-mlogloss:0.10706\n",
+ "[19]\tvalid-mlogloss:0.10308\n",
+ "[20]\tvalid-mlogloss:0.09909\n",
+ "[21]\tvalid-mlogloss:0.09610\n",
+ 
"[22]\tvalid-mlogloss:0.09318\n", + "[23]\tvalid-mlogloss:0.09023\n", + "[24]\tvalid-mlogloss:0.08807\n", + "[25]\tvalid-mlogloss:0.08563\n", + "[26]\tvalid-mlogloss:0.08399\n", + "[27]\tvalid-mlogloss:0.08230\n", + "[28]\tvalid-mlogloss:0.08096\n", + "[29]\tvalid-mlogloss:0.07934\n", + "[30]\tvalid-mlogloss:0.07750\n", + "[31]\tvalid-mlogloss:0.07608\n", + "[32]\tvalid-mlogloss:0.07493\n", + "[33]\tvalid-mlogloss:0.07354\n", + "[34]\tvalid-mlogloss:0.07225\n", + "[35]\tvalid-mlogloss:0.07103\n", + "[36]\tvalid-mlogloss:0.06991\n", + "[37]\tvalid-mlogloss:0.06901\n", + "[38]\tvalid-mlogloss:0.06810\n", + "[39]\tvalid-mlogloss:0.06741\n", + "[40]\tvalid-mlogloss:0.06636\n", + "[41]\tvalid-mlogloss:0.06560\n", + "[42]\tvalid-mlogloss:0.06488\n", + "[43]\tvalid-mlogloss:0.06392\n", + "[44]\tvalid-mlogloss:0.06308\n", + "[45]\tvalid-mlogloss:0.06232\n", + "[46]\tvalid-mlogloss:0.06155\n", + "[47]\tvalid-mlogloss:0.06099\n", + "[48]\tvalid-mlogloss:0.06039\n", + "[49]\tvalid-mlogloss:0.05985\n", + "[50]\tvalid-mlogloss:0.05917\n", + "[51]\tvalid-mlogloss:0.05860\n", + "[52]\tvalid-mlogloss:0.05800\n", + "[53]\tvalid-mlogloss:0.05757\n", + "[54]\tvalid-mlogloss:0.05691\n", + "[55]\tvalid-mlogloss:0.05645\n", + "[56]\tvalid-mlogloss:0.05576\n", + "[57]\tvalid-mlogloss:0.05521\n", + "[58]\tvalid-mlogloss:0.05475\n", + "[59]\tvalid-mlogloss:0.05439\n", + "[60]\tvalid-mlogloss:0.05391\n", + "[61]\tvalid-mlogloss:0.05366\n", + "[62]\tvalid-mlogloss:0.05341\n", + "[63]\tvalid-mlogloss:0.05308\n", + "[64]\tvalid-mlogloss:0.05264\n", + "[65]\tvalid-mlogloss:0.05230\n", + "[66]\tvalid-mlogloss:0.05187\n", + "[67]\tvalid-mlogloss:0.05153\n", + "[68]\tvalid-mlogloss:0.05135\n", + "[69]\tvalid-mlogloss:0.05105\n", + "[70]\tvalid-mlogloss:0.05064\n", + "[71]\tvalid-mlogloss:0.05037\n", + "[72]\tvalid-mlogloss:0.05008\n", + "[73]\tvalid-mlogloss:0.04967\n", + "[74]\tvalid-mlogloss:0.04939\n", + "[75]\tvalid-mlogloss:0.04920\n", + "[76]\tvalid-mlogloss:0.04898\n", + 
"[77]\tvalid-mlogloss:0.04861\n", + "[78]\tvalid-mlogloss:0.04828\n", + "[79]\tvalid-mlogloss:0.04803\n", + "[80]\tvalid-mlogloss:0.04779\n", + "[81]\tvalid-mlogloss:0.04741\n", + "[82]\tvalid-mlogloss:0.04711\n", + "[83]\tvalid-mlogloss:0.04689\n", + "[84]\tvalid-mlogloss:0.04659\n", + "[85]\tvalid-mlogloss:0.04630\n", + "[86]\tvalid-mlogloss:0.04615\n", + "[87]\tvalid-mlogloss:0.04597\n", + "[88]\tvalid-mlogloss:0.04580\n", + "[89]\tvalid-mlogloss:0.04563\n", + "[90]\tvalid-mlogloss:0.04547\n", + "[91]\tvalid-mlogloss:0.04522\n", + "[92]\tvalid-mlogloss:0.04514\n", + "[93]\tvalid-mlogloss:0.04479\n", + "[94]\tvalid-mlogloss:0.04464\n", + "[95]\tvalid-mlogloss:0.04441\n", + "[96]\tvalid-mlogloss:0.04428\n", + "[97]\tvalid-mlogloss:0.04415\n", + "[98]\tvalid-mlogloss:0.04398\n", + "[99]\tvalid-mlogloss:0.04376\n", + "[100]\tvalid-mlogloss:0.04360\n", + "[101]\tvalid-mlogloss:0.04344\n", + "[102]\tvalid-mlogloss:0.04332\n", + "[103]\tvalid-mlogloss:0.04308\n", + "[104]\tvalid-mlogloss:0.04300\n", + "[105]\tvalid-mlogloss:0.04283\n", + "[106]\tvalid-mlogloss:0.04268\n", + "[107]\tvalid-mlogloss:0.04245\n", + "[108]\tvalid-mlogloss:0.04237\n", + "[109]\tvalid-mlogloss:0.04230\n", + "[110]\tvalid-mlogloss:0.04219\n", + "[111]\tvalid-mlogloss:0.04209\n", + "[112]\tvalid-mlogloss:0.04200\n", + "[113]\tvalid-mlogloss:0.04184\n", + "[114]\tvalid-mlogloss:0.04164\n", + "[115]\tvalid-mlogloss:0.04151\n", + "[116]\tvalid-mlogloss:0.04123\n", + "[117]\tvalid-mlogloss:0.04102\n", + "[118]\tvalid-mlogloss:0.04090\n", + "[119]\tvalid-mlogloss:0.04084\n", + "[120]\tvalid-mlogloss:0.04073\n", + "[121]\tvalid-mlogloss:0.04057\n", + "[122]\tvalid-mlogloss:0.04041\n", + "[123]\tvalid-mlogloss:0.04028\n", + "[124]\tvalid-mlogloss:0.04016\n", + "[125]\tvalid-mlogloss:0.04005\n", + "[126]\tvalid-mlogloss:0.04000\n", + "[127]\tvalid-mlogloss:0.04000\n", + "[128]\tvalid-mlogloss:0.03991\n", + "[129]\tvalid-mlogloss:0.03964\n", + "[130]\tvalid-mlogloss:0.03946\n", + 
"[131]\tvalid-mlogloss:0.03939\n", + "[132]\tvalid-mlogloss:0.03937\n", + "[133]\tvalid-mlogloss:0.03936\n", + "[134]\tvalid-mlogloss:0.03929\n", + "[135]\tvalid-mlogloss:0.03911\n", + "[136]\tvalid-mlogloss:0.03905\n", + "[137]\tvalid-mlogloss:0.03900\n", + "[138]\tvalid-mlogloss:0.03894\n", + "[139]\tvalid-mlogloss:0.03881\n", + "[140]\tvalid-mlogloss:0.03878\n", + "[141]\tvalid-mlogloss:0.03858\n", + "[142]\tvalid-mlogloss:0.03846\n", + "[143]\tvalid-mlogloss:0.03838\n", + "[144]\tvalid-mlogloss:0.03825\n", + "[145]\tvalid-mlogloss:0.03816\n", + "[146]\tvalid-mlogloss:0.03812\n", + "[147]\tvalid-mlogloss:0.03797\n", + "[148]\tvalid-mlogloss:0.03789\n", + "[149]\tvalid-mlogloss:0.03784\n", + "[150]\tvalid-mlogloss:0.03786\n", + "[151]\tvalid-mlogloss:0.03780\n", + "[152]\tvalid-mlogloss:0.03765\n", + "[153]\tvalid-mlogloss:0.03759\n", + "[154]\tvalid-mlogloss:0.03744\n", + "[155]\tvalid-mlogloss:0.03731\n", + "[156]\tvalid-mlogloss:0.03719\n", + "[157]\tvalid-mlogloss:0.03724\n", + "[158]\tvalid-mlogloss:0.03719\n", + "[159]\tvalid-mlogloss:0.03719\n", + "[160]\tvalid-mlogloss:0.03705\n", + "[161]\tvalid-mlogloss:0.03698\n", + "[162]\tvalid-mlogloss:0.03683\n", + "[163]\tvalid-mlogloss:0.03680\n", + "[164]\tvalid-mlogloss:0.03668\n", + "[165]\tvalid-mlogloss:0.03664\n", + "[166]\tvalid-mlogloss:0.03659\n", + "[167]\tvalid-mlogloss:0.03663\n", + "[168]\tvalid-mlogloss:0.03649\n", + "[169]\tvalid-mlogloss:0.03638\n", + "[170]\tvalid-mlogloss:0.03639\n", + "[171]\tvalid-mlogloss:0.03632\n", + "[172]\tvalid-mlogloss:0.03626\n", + "[173]\tvalid-mlogloss:0.03621\n", + "[174]\tvalid-mlogloss:0.03615\n", + "[175]\tvalid-mlogloss:0.03608\n", + "[176]\tvalid-mlogloss:0.03604\n", + "[177]\tvalid-mlogloss:0.03595\n", + "[178]\tvalid-mlogloss:0.03592\n", + "[179]\tvalid-mlogloss:0.03589\n", + "[180]\tvalid-mlogloss:0.03588\n", + "[181]\tvalid-mlogloss:0.03578\n", + "[182]\tvalid-mlogloss:0.03576\n", + "[183]\tvalid-mlogloss:0.03565\n", + "[184]\tvalid-mlogloss:0.03568\n", + 
"[185]\tvalid-mlogloss:0.03560\n", + "[186]\tvalid-mlogloss:0.03550\n", + "[187]\tvalid-mlogloss:0.03550\n", + "[188]\tvalid-mlogloss:0.03537\n", + "[189]\tvalid-mlogloss:0.03534\n", + "[190]\tvalid-mlogloss:0.03527\n", + "[191]\tvalid-mlogloss:0.03524\n", + "[192]\tvalid-mlogloss:0.03523\n", + "[193]\tvalid-mlogloss:0.03520\n", + "[194]\tvalid-mlogloss:0.03515\n", + "[195]\tvalid-mlogloss:0.03507\n", + "[196]\tvalid-mlogloss:0.03508\n", + "[197]\tvalid-mlogloss:0.03499\n", + "[198]\tvalid-mlogloss:0.03492\n", + "[199]\tvalid-mlogloss:0.03490\n", + "[200]\tvalid-mlogloss:0.03494\n", + "[201]\tvalid-mlogloss:0.03484\n", + "[202]\tvalid-mlogloss:0.03477\n", + "[203]\tvalid-mlogloss:0.03466\n", + "[204]\tvalid-mlogloss:0.03457\n", + "[205]\tvalid-mlogloss:0.03453\n", + "[206]\tvalid-mlogloss:0.03444\n", + "[207]\tvalid-mlogloss:0.03440\n", + "[208]\tvalid-mlogloss:0.03437\n", + "[209]\tvalid-mlogloss:0.03427\n", + "[210]\tvalid-mlogloss:0.03427\n", + "[211]\tvalid-mlogloss:0.03427\n", + "[212]\tvalid-mlogloss:0.03420\n", + "[213]\tvalid-mlogloss:0.03423\n", + "[214]\tvalid-mlogloss:0.03418\n", + "[215]\tvalid-mlogloss:0.03413\n", + "[216]\tvalid-mlogloss:0.03412\n", + "[217]\tvalid-mlogloss:0.03408\n", + "[218]\tvalid-mlogloss:0.03406\n", + "[219]\tvalid-mlogloss:0.03401\n", + "[220]\tvalid-mlogloss:0.03406\n", + "[221]\tvalid-mlogloss:0.03412\n", + "[222]\tvalid-mlogloss:0.03407\n", + "[223]\tvalid-mlogloss:0.03401\n", + "[224]\tvalid-mlogloss:0.03399\n", + "[225]\tvalid-mlogloss:0.03393\n", + "[226]\tvalid-mlogloss:0.03393\n", + "[227]\tvalid-mlogloss:0.03392\n", + "[228]\tvalid-mlogloss:0.03386\n", + "[229]\tvalid-mlogloss:0.03380\n", + "[230]\tvalid-mlogloss:0.03377\n", + "[231]\tvalid-mlogloss:0.03382\n", + "[232]\tvalid-mlogloss:0.03369\n", + "[233]\tvalid-mlogloss:0.03374\n", + "[234]\tvalid-mlogloss:0.03364\n", + "[235]\tvalid-mlogloss:0.03357\n", + "[236]\tvalid-mlogloss:0.03355\n", + "[237]\tvalid-mlogloss:0.03339\n", + "[238]\tvalid-mlogloss:0.03335\n", + 
"[239]\tvalid-mlogloss:0.03329\n", + "[240]\tvalid-mlogloss:0.03321\n", + "[241]\tvalid-mlogloss:0.03318\n", + "[242]\tvalid-mlogloss:0.03319\n", + "[243]\tvalid-mlogloss:0.03321\n", + "[244]\tvalid-mlogloss:0.03318\n", + "[245]\tvalid-mlogloss:0.03314\n", + "[246]\tvalid-mlogloss:0.03308\n", + "[247]\tvalid-mlogloss:0.03303\n", + "[248]\tvalid-mlogloss:0.03297\n", + "[249]\tvalid-mlogloss:0.03285\n", + "[250]\tvalid-mlogloss:0.03286\n", + "[251]\tvalid-mlogloss:0.03283\n", + "[252]\tvalid-mlogloss:0.03284\n", + "[253]\tvalid-mlogloss:0.03283\n", + "[254]\tvalid-mlogloss:0.03279\n", + "[255]\tvalid-mlogloss:0.03285\n", + "[256]\tvalid-mlogloss:0.03279\n", + "[257]\tvalid-mlogloss:0.03275\n", + "[258]\tvalid-mlogloss:0.03277\n", + "[259]\tvalid-mlogloss:0.03274\n", + "[260]\tvalid-mlogloss:0.03267\n", + "[261]\tvalid-mlogloss:0.03266\n", + "[262]\tvalid-mlogloss:0.03262\n", + "[263]\tvalid-mlogloss:0.03260\n", + "[264]\tvalid-mlogloss:0.03256\n", + "[265]\tvalid-mlogloss:0.03244\n", + "[266]\tvalid-mlogloss:0.03250\n", + "[267]\tvalid-mlogloss:0.03248\n", + "[268]\tvalid-mlogloss:0.03247\n", + "[269]\tvalid-mlogloss:0.03243\n", + "[270]\tvalid-mlogloss:0.03245\n", + "[271]\tvalid-mlogloss:0.03246\n", + "[272]\tvalid-mlogloss:0.03246\n", + "[273]\tvalid-mlogloss:0.03237\n", + "[274]\tvalid-mlogloss:0.03234\n", + "[275]\tvalid-mlogloss:0.03227\n", + "[276]\tvalid-mlogloss:0.03232\n", + "[277]\tvalid-mlogloss:0.03226\n", + "[278]\tvalid-mlogloss:0.03224\n", + "[279]\tvalid-mlogloss:0.03222\n", + "[280]\tvalid-mlogloss:0.03218\n", + "[281]\tvalid-mlogloss:0.03222\n", + "[282]\tvalid-mlogloss:0.03218\n", + "[283]\tvalid-mlogloss:0.03214\n", + "[284]\tvalid-mlogloss:0.03208\n", + "[285]\tvalid-mlogloss:0.03207\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[286]\tvalid-mlogloss:0.03211\n", + "[287]\tvalid-mlogloss:0.03207\n", + "[288]\tvalid-mlogloss:0.03203\n", + "[289]\tvalid-mlogloss:0.03205\n", + "[290]\tvalid-mlogloss:0.03209\n", + 
"[291]\tvalid-mlogloss:0.03205\n", + "[292]\tvalid-mlogloss:0.03204\n", + "[293]\tvalid-mlogloss:0.03205\n", + "[294]\tvalid-mlogloss:0.03196\n", + "[295]\tvalid-mlogloss:0.03193\n", + "[296]\tvalid-mlogloss:0.03194\n", + "[297]\tvalid-mlogloss:0.03189\n", + "[298]\tvalid-mlogloss:0.03187\n", + "[299]\tvalid-mlogloss:0.03189\n", + "[300]\tvalid-mlogloss:0.03184\n", + "[301]\tvalid-mlogloss:0.03186\n", + "[302]\tvalid-mlogloss:0.03190\n", + "[303]\tvalid-mlogloss:0.03187\n", + "[304]\tvalid-mlogloss:0.03183\n", + "[305]\tvalid-mlogloss:0.03182\n", + "[306]\tvalid-mlogloss:0.03178\n", + "[307]\tvalid-mlogloss:0.03175\n", + "[308]\tvalid-mlogloss:0.03179\n", + "[309]\tvalid-mlogloss:0.03178\n", + "[310]\tvalid-mlogloss:0.03172\n", + "[311]\tvalid-mlogloss:0.03169\n", + "[312]\tvalid-mlogloss:0.03165\n", + "[313]\tvalid-mlogloss:0.03163\n", + "[314]\tvalid-mlogloss:0.03162\n", + "[315]\tvalid-mlogloss:0.03161\n", + "[316]\tvalid-mlogloss:0.03161\n", + "[317]\tvalid-mlogloss:0.03163\n", + "[318]\tvalid-mlogloss:0.03158\n", + "[319]\tvalid-mlogloss:0.03153\n", + "[320]\tvalid-mlogloss:0.03149\n", + "[321]\tvalid-mlogloss:0.03148\n", + "[322]\tvalid-mlogloss:0.03141\n", + "[323]\tvalid-mlogloss:0.03139\n", + "[324]\tvalid-mlogloss:0.03134\n", + "[325]\tvalid-mlogloss:0.03131\n", + "[326]\tvalid-mlogloss:0.03130\n", + "[327]\tvalid-mlogloss:0.03126\n", + "[328]\tvalid-mlogloss:0.03128\n", + "[329]\tvalid-mlogloss:0.03127\n", + "[330]\tvalid-mlogloss:0.03132\n", + "[331]\tvalid-mlogloss:0.03133\n", + "[332]\tvalid-mlogloss:0.03129\n", + "[333]\tvalid-mlogloss:0.03126\n", + "[334]\tvalid-mlogloss:0.03128\n", + "[335]\tvalid-mlogloss:0.03123\n", + "[336]\tvalid-mlogloss:0.03124\n", + "[337]\tvalid-mlogloss:0.03119\n", + "[338]\tvalid-mlogloss:0.03112\n", + "[339]\tvalid-mlogloss:0.03114\n", + "[340]\tvalid-mlogloss:0.03115\n", + "[341]\tvalid-mlogloss:0.03113\n", + "[342]\tvalid-mlogloss:0.03120\n", + "[343]\tvalid-mlogloss:0.03118\n", + "[344]\tvalid-mlogloss:0.03121\n", + 
"[345]\tvalid-mlogloss:0.03120\n", + "[346]\tvalid-mlogloss:0.03124\n", + "[347]\tvalid-mlogloss:0.03120\n", + "[348]\tvalid-mlogloss:0.03114\n", + "[349]\tvalid-mlogloss:0.03111\n", + "[350]\tvalid-mlogloss:0.03115\n", + "[351]\tvalid-mlogloss:0.03109\n", + "[352]\tvalid-mlogloss:0.03112\n", + "[353]\tvalid-mlogloss:0.03105\n", + "[354]\tvalid-mlogloss:0.03107\n", + "[355]\tvalid-mlogloss:0.03105\n", + "[356]\tvalid-mlogloss:0.03102\n", + "[357]\tvalid-mlogloss:0.03105\n", + "[358]\tvalid-mlogloss:0.03103\n", + "[359]\tvalid-mlogloss:0.03103\n", + "[360]\tvalid-mlogloss:0.03104\n", + "[361]\tvalid-mlogloss:0.03105\n", + "[362]\tvalid-mlogloss:0.03107\n", + "[363]\tvalid-mlogloss:0.03105\n", + "[364]\tvalid-mlogloss:0.03105\n", + "[365]\tvalid-mlogloss:0.03110\n", + "[366]\tvalid-mlogloss:0.03109\n", + "[367]\tvalid-mlogloss:0.03109\n", + "[368]\tvalid-mlogloss:0.03102\n", + "[369]\tvalid-mlogloss:0.03104\n", + "[370]\tvalid-mlogloss:0.03103\n", + "[371]\tvalid-mlogloss:0.03106\n", + "[372]\tvalid-mlogloss:0.03107\n", + "[373]\tvalid-mlogloss:0.03105\n", + "[374]\tvalid-mlogloss:0.03099\n", + "[375]\tvalid-mlogloss:0.03097\n", + "[376]\tvalid-mlogloss:0.03098\n", + "[377]\tvalid-mlogloss:0.03095\n", + "[378]\tvalid-mlogloss:0.03096\n", + "[379]\tvalid-mlogloss:0.03099\n", + "[380]\tvalid-mlogloss:0.03099\n", + "[381]\tvalid-mlogloss:0.03098\n", + "[382]\tvalid-mlogloss:0.03095\n", + "[383]\tvalid-mlogloss:0.03095\n", + "[384]\tvalid-mlogloss:0.03096\n", + "[385]\tvalid-mlogloss:0.03100\n", + "[386]\tvalid-mlogloss:0.03095\n", + "[387]\tvalid-mlogloss:0.03087\n", + "[388]\tvalid-mlogloss:0.03086\n", + "[389]\tvalid-mlogloss:0.03090\n", + "[390]\tvalid-mlogloss:0.03088\n", + "[391]\tvalid-mlogloss:0.03092\n", + "[392]\tvalid-mlogloss:0.03093\n", + "[393]\tvalid-mlogloss:0.03096\n", + "[394]\tvalid-mlogloss:0.03096\n", + "[395]\tvalid-mlogloss:0.03097\n", + "[396]\tvalid-mlogloss:0.03098\n", + "[397]\tvalid-mlogloss:0.03098\n", + "[398]\tvalid-mlogloss:0.03094\n", + 
"[399]\tvalid-mlogloss:0.03094\n", + "[400]\tvalid-mlogloss:0.03091\n", + "[401]\tvalid-mlogloss:0.03089\n", + "[402]\tvalid-mlogloss:0.03089\n", + "[403]\tvalid-mlogloss:0.03088\n", + "[404]\tvalid-mlogloss:0.03080\n", + "[405]\tvalid-mlogloss:0.03078\n", + "[406]\tvalid-mlogloss:0.03077\n", + "[407]\tvalid-mlogloss:0.03076\n", + "[408]\tvalid-mlogloss:0.03076\n", + "[409]\tvalid-mlogloss:0.03074\n", + "[410]\tvalid-mlogloss:0.03070\n", + "[411]\tvalid-mlogloss:0.03065\n", + "[412]\tvalid-mlogloss:0.03065\n", + "[413]\tvalid-mlogloss:0.03066\n", + "[414]\tvalid-mlogloss:0.03065\n", + "[415]\tvalid-mlogloss:0.03065\n", + "[416]\tvalid-mlogloss:0.03061\n", + "[417]\tvalid-mlogloss:0.03062\n", + "[418]\tvalid-mlogloss:0.03064\n", + "[419]\tvalid-mlogloss:0.03060\n", + "[420]\tvalid-mlogloss:0.03061\n", + "[421]\tvalid-mlogloss:0.03060\n", + "[422]\tvalid-mlogloss:0.03060\n", + "[423]\tvalid-mlogloss:0.03063\n", + "[424]\tvalid-mlogloss:0.03064\n", + "[425]\tvalid-mlogloss:0.03063\n", + "[426]\tvalid-mlogloss:0.03065\n", + "[427]\tvalid-mlogloss:0.03067\n", + "[428]\tvalid-mlogloss:0.03066\n", + "[429]\tvalid-mlogloss:0.03064\n", + "[430]\tvalid-mlogloss:0.03063\n", + "[431]\tvalid-mlogloss:0.03062\n", + "[432]\tvalid-mlogloss:0.03061\n", + "[433]\tvalid-mlogloss:0.03054\n", + "[434]\tvalid-mlogloss:0.03052\n", + "[435]\tvalid-mlogloss:0.03050\n", + "[436]\tvalid-mlogloss:0.03052\n", + "[437]\tvalid-mlogloss:0.03056\n", + "[438]\tvalid-mlogloss:0.03052\n", + "[439]\tvalid-mlogloss:0.03057\n", + "[440]\tvalid-mlogloss:0.03055\n", + "[441]\tvalid-mlogloss:0.03047\n", + "[442]\tvalid-mlogloss:0.03048\n", + "[443]\tvalid-mlogloss:0.03052\n", + "[444]\tvalid-mlogloss:0.03050\n", + "[445]\tvalid-mlogloss:0.03052\n", + "[446]\tvalid-mlogloss:0.03051\n", + "[447]\tvalid-mlogloss:0.03051\n", + "[448]\tvalid-mlogloss:0.03049\n", + "[449]\tvalid-mlogloss:0.03048\n", + "[450]\tvalid-mlogloss:0.03047\n", + "[451]\tvalid-mlogloss:0.03047\n", + "[452]\tvalid-mlogloss:0.03052\n", + 
"[453]\tvalid-mlogloss:0.03053\n", + "[454]\tvalid-mlogloss:0.03050\n", + "[455]\tvalid-mlogloss:0.03055\n", + "[456]\tvalid-mlogloss:0.03050\n", + "[457]\tvalid-mlogloss:0.03048\n", + "[458]\tvalid-mlogloss:0.03050\n", + "[459]\tvalid-mlogloss:0.03053\n", + "[460]\tvalid-mlogloss:0.03056\n", + "[461]\tvalid-mlogloss:0.03056\n", + "[462]\tvalid-mlogloss:0.03057\n", + "[463]\tvalid-mlogloss:0.03054\n", + "[464]\tvalid-mlogloss:0.03054\n", + "[465]\tvalid-mlogloss:0.03052\n", + "[466]\tvalid-mlogloss:0.03053\n", + "[467]\tvalid-mlogloss:0.03054\n", + "[468]\tvalid-mlogloss:0.03055\n", + "[469]\tvalid-mlogloss:0.03054\n", + "[470]\tvalid-mlogloss:0.03055\n", + "[471]\tvalid-mlogloss:0.03056\n", + "[472]\tvalid-mlogloss:0.03053\n", + "[473]\tvalid-mlogloss:0.03050\n", + "[474]\tvalid-mlogloss:0.03044\n", + "[475]\tvalid-mlogloss:0.03047\n", + "[476]\tvalid-mlogloss:0.03047\n", + "[477]\tvalid-mlogloss:0.03044\n", + "[478]\tvalid-mlogloss:0.03044\n", + "[479]\tvalid-mlogloss:0.03045\n", + "[480]\tvalid-mlogloss:0.03046\n", + "[481]\tvalid-mlogloss:0.03045\n", + "[482]\tvalid-mlogloss:0.03044\n", + "[483]\tvalid-mlogloss:0.03048\n", + "[484]\tvalid-mlogloss:0.03048\n", + "[485]\tvalid-mlogloss:0.03045\n", + "[486]\tvalid-mlogloss:0.03043\n", + "[487]\tvalid-mlogloss:0.03046\n", + "[488]\tvalid-mlogloss:0.03047\n", + "[489]\tvalid-mlogloss:0.03050\n", + "[490]\tvalid-mlogloss:0.03049\n", + "[491]\tvalid-mlogloss:0.03048\n", + "[492]\tvalid-mlogloss:0.03050\n", + "[493]\tvalid-mlogloss:0.03048\n", + "[494]\tvalid-mlogloss:0.03047\n", + "[495]\tvalid-mlogloss:0.03053\n", + "[496]\tvalid-mlogloss:0.03055\n", + "[497]\tvalid-mlogloss:0.03055\n", + "[498]\tvalid-mlogloss:0.03055\n", + "[499]\tvalid-mlogloss:0.03049\n", + "[500]\tvalid-mlogloss:0.03047\n", + "[501]\tvalid-mlogloss:0.03049\n", + "[502]\tvalid-mlogloss:0.03049\n", + "[503]\tvalid-mlogloss:0.03049\n", + "[504]\tvalid-mlogloss:0.03051\n", + "[505]\tvalid-mlogloss:0.03050\n", + "[506]\tvalid-mlogloss:0.03053\n", + 
"[507]\tvalid-mlogloss:0.03053\n", + "[508]\tvalid-mlogloss:0.03054\n", + "[509]\tvalid-mlogloss:0.03050\n", + "[510]\tvalid-mlogloss:0.03044\n", + "[511]\tvalid-mlogloss:0.03045\n", + "[512]\tvalid-mlogloss:0.03047\n", + "[513]\tvalid-mlogloss:0.03045\n", + "[514]\tvalid-mlogloss:0.03041\n", + "[515]\tvalid-mlogloss:0.03040\n", + "[516]\tvalid-mlogloss:0.03042\n", + "[517]\tvalid-mlogloss:0.03043\n", + "[518]\tvalid-mlogloss:0.03040\n", + "[519]\tvalid-mlogloss:0.03044\n", + "[520]\tvalid-mlogloss:0.03045\n", + "[521]\tvalid-mlogloss:0.03045\n", + "[522]\tvalid-mlogloss:0.03044\n", + "[523]\tvalid-mlogloss:0.03040\n", + "[524]\tvalid-mlogloss:0.03042\n", + "[525]\tvalid-mlogloss:0.03036\n", + "[526]\tvalid-mlogloss:0.03030\n", + "[527]\tvalid-mlogloss:0.03033\n", + "[528]\tvalid-mlogloss:0.03032\n", + "[529]\tvalid-mlogloss:0.03037\n", + "[530]\tvalid-mlogloss:0.03036\n", + "[531]\tvalid-mlogloss:0.03035\n", + "[532]\tvalid-mlogloss:0.03038\n", + "[533]\tvalid-mlogloss:0.03039\n", + "[534]\tvalid-mlogloss:0.03042\n", + "[535]\tvalid-mlogloss:0.03041\n", + "[536]\tvalid-mlogloss:0.03039\n", + "[537]\tvalid-mlogloss:0.03041\n", + "[538]\tvalid-mlogloss:0.03039\n", + "[539]\tvalid-mlogloss:0.03034\n", + "[540]\tvalid-mlogloss:0.03034\n", + "[541]\tvalid-mlogloss:0.03033\n", + "[542]\tvalid-mlogloss:0.03031\n", + "[543]\tvalid-mlogloss:0.03029\n", + "[544]\tvalid-mlogloss:0.03024\n", + "[545]\tvalid-mlogloss:0.03027\n", + "[546]\tvalid-mlogloss:0.03022\n", + "[547]\tvalid-mlogloss:0.03025\n", + "[548]\tvalid-mlogloss:0.03025\n", + "[549]\tvalid-mlogloss:0.03024\n", + "[550]\tvalid-mlogloss:0.03024\n", + "[551]\tvalid-mlogloss:0.03023\n", + "[552]\tvalid-mlogloss:0.03019\n", + "[553]\tvalid-mlogloss:0.03021\n", + "[554]\tvalid-mlogloss:0.03021\n", + "[555]\tvalid-mlogloss:0.03017\n", + "[556]\tvalid-mlogloss:0.03015\n", + "[557]\tvalid-mlogloss:0.03015\n", + "[558]\tvalid-mlogloss:0.03016\n", + "[559]\tvalid-mlogloss:0.03016\n", + "[560]\tvalid-mlogloss:0.03020\n", + 
"[561]\tvalid-mlogloss:0.03019\n", + "[562]\tvalid-mlogloss:0.03021\n", + "[563]\tvalid-mlogloss:0.03018\n", + "[564]\tvalid-mlogloss:0.03021\n", + "[565]\tvalid-mlogloss:0.03016\n", + "[566]\tvalid-mlogloss:0.03011\n", + "[567]\tvalid-mlogloss:0.03015\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[568]\tvalid-mlogloss:0.03013\n", + "[569]\tvalid-mlogloss:0.03016\n", + "[570]\tvalid-mlogloss:0.03014\n", + "[571]\tvalid-mlogloss:0.03011\n", + "[572]\tvalid-mlogloss:0.03015\n", + "[573]\tvalid-mlogloss:0.03013\n", + "[574]\tvalid-mlogloss:0.03016\n", + "[575]\tvalid-mlogloss:0.03021\n", + "[576]\tvalid-mlogloss:0.03017\n", + "[577]\tvalid-mlogloss:0.03018\n", + "[578]\tvalid-mlogloss:0.03017\n", + "[579]\tvalid-mlogloss:0.03018\n", + "[580]\tvalid-mlogloss:0.03023\n", + "[581]\tvalid-mlogloss:0.03018\n", + "[582]\tvalid-mlogloss:0.03018\n", + "[583]\tvalid-mlogloss:0.03018\n", + "[584]\tvalid-mlogloss:0.03019\n", + "[585]\tvalid-mlogloss:0.03016\n", + "[586]\tvalid-mlogloss:0.03014\n", + "[587]\tvalid-mlogloss:0.03017\n", + "[588]\tvalid-mlogloss:0.03019\n", + "[589]\tvalid-mlogloss:0.03016\n", + "[590]\tvalid-mlogloss:0.03013\n", + "[591]\tvalid-mlogloss:0.03014\n", + "[592]\tvalid-mlogloss:0.03014\n", + "[593]\tvalid-mlogloss:0.03011\n", + "[594]\tvalid-mlogloss:0.03010\n", + "[595]\tvalid-mlogloss:0.03013\n", + "[596]\tvalid-mlogloss:0.03015\n", + "[597]\tvalid-mlogloss:0.03016\n", + "[598]\tvalid-mlogloss:0.03016\n", + "[599]\tvalid-mlogloss:0.03017\n", + "[600]\tvalid-mlogloss:0.03016\n", + "[601]\tvalid-mlogloss:0.03014\n", + "[602]\tvalid-mlogloss:0.03014\n", + "[603]\tvalid-mlogloss:0.03014\n", + "[604]\tvalid-mlogloss:0.03012\n", + "[605]\tvalid-mlogloss:0.03012\n", + "[606]\tvalid-mlogloss:0.03012\n", + "[607]\tvalid-mlogloss:0.03010\n", + "[608]\tvalid-mlogloss:0.03010\n", + "[609]\tvalid-mlogloss:0.03010\n", + "[610]\tvalid-mlogloss:0.03011\n", + "[611]\tvalid-mlogloss:0.03011\n", + "[612]\tvalid-mlogloss:0.03014\n", + 
"[613]\tvalid-mlogloss:0.03017\n", + "[614]\tvalid-mlogloss:0.03017\n", + "[615]\tvalid-mlogloss:0.03014\n", + "[616]\tvalid-mlogloss:0.03016\n", + "[617]\tvalid-mlogloss:0.03015\n", + "[618]\tvalid-mlogloss:0.03014\n", + "[619]\tvalid-mlogloss:0.03015\n", + "[620]\tvalid-mlogloss:0.03015\n", + "[621]\tvalid-mlogloss:0.03013\n", + "[622]\tvalid-mlogloss:0.03012\n", + "[623]\tvalid-mlogloss:0.03011\n", + "[624]\tvalid-mlogloss:0.03010\n", + "[625]\tvalid-mlogloss:0.03008\n", + "[626]\tvalid-mlogloss:0.03006\n", + "[627]\tvalid-mlogloss:0.03007\n", + "[628]\tvalid-mlogloss:0.03009\n", + "[629]\tvalid-mlogloss:0.03012\n", + "[630]\tvalid-mlogloss:0.03009\n", + "[631]\tvalid-mlogloss:0.03011\n", + "[632]\tvalid-mlogloss:0.03007\n", + "[633]\tvalid-mlogloss:0.03010\n", + "[634]\tvalid-mlogloss:0.03013\n", + "[635]\tvalid-mlogloss:0.03013\n", + "[636]\tvalid-mlogloss:0.03012\n", + "[637]\tvalid-mlogloss:0.03011\n", + "[638]\tvalid-mlogloss:0.03012\n", + "[639]\tvalid-mlogloss:0.03015\n", + "[640]\tvalid-mlogloss:0.03014\n", + "[641]\tvalid-mlogloss:0.03012\n", + "[642]\tvalid-mlogloss:0.03013\n", + "[643]\tvalid-mlogloss:0.03013\n", + "[644]\tvalid-mlogloss:0.03014\n", + "[645]\tvalid-mlogloss:0.03011\n", + "[646]\tvalid-mlogloss:0.03009\n", + "[647]\tvalid-mlogloss:0.03009\n", + "[648]\tvalid-mlogloss:0.03009\n", + "[649]\tvalid-mlogloss:0.03008\n", + "[650]\tvalid-mlogloss:0.03008\n", + "[651]\tvalid-mlogloss:0.03008\n", + "[652]\tvalid-mlogloss:0.03008\n", + "[653]\tvalid-mlogloss:0.03008\n", + "[654]\tvalid-mlogloss:0.03008\n", + "[655]\tvalid-mlogloss:0.03010\n", + "[656]\tvalid-mlogloss:0.03012\n", + "[657]\tvalid-mlogloss:0.03010\n", + "[658]\tvalid-mlogloss:0.03008\n", + "[659]\tvalid-mlogloss:0.03010\n", + "[660]\tvalid-mlogloss:0.03011\n", + "[661]\tvalid-mlogloss:0.03015\n", + "[662]\tvalid-mlogloss:0.03012\n", + "[663]\tvalid-mlogloss:0.03012\n", + "[664]\tvalid-mlogloss:0.03008\n", + "[665]\tvalid-mlogloss:0.03007\n", + "[666]\tvalid-mlogloss:0.03009\n", + 
"[667]\tvalid-mlogloss:0.03007\n", + "[668]\tvalid-mlogloss:0.03006\n", + "[669]\tvalid-mlogloss:0.03007\n", + "[670]\tvalid-mlogloss:0.03008\n", + "[671]\tvalid-mlogloss:0.03009\n", + "[672]\tvalid-mlogloss:0.03007\n", + "[673]\tvalid-mlogloss:0.03008\n", + "[674]\tvalid-mlogloss:0.03008\n", + "[675]\tvalid-mlogloss:0.03011\n", + "[676]\tvalid-mlogloss:0.03013\n", + "[677]\tvalid-mlogloss:0.03012\n", + "[678]\tvalid-mlogloss:0.03012\n", + "[679]\tvalid-mlogloss:0.03014\n", + "[680]\tvalid-mlogloss:0.03013\n", + "[681]\tvalid-mlogloss:0.03012\n", + "[682]\tvalid-mlogloss:0.03013\n", + "[683]\tvalid-mlogloss:0.03009\n", + "[684]\tvalid-mlogloss:0.03009\n", + "[685]\tvalid-mlogloss:0.03007\n", + "[686]\tvalid-mlogloss:0.03006\n", + "[687]\tvalid-mlogloss:0.03007\n", + "[688]\tvalid-mlogloss:0.03010\n", + "[689]\tvalid-mlogloss:0.03010\n", + "[690]\tvalid-mlogloss:0.03011\n", + "[691]\tvalid-mlogloss:0.03011\n", + "[692]\tvalid-mlogloss:0.03012\n", + "[693]\tvalid-mlogloss:0.03015\n", + "[694]\tvalid-mlogloss:0.03016\n", + "[695]\tvalid-mlogloss:0.03015\n", + "[696]\tvalid-mlogloss:0.03014\n", + "[697]\tvalid-mlogloss:0.03014\n", + "[698]\tvalid-mlogloss:0.03014\n", + "[699]\tvalid-mlogloss:0.03015\n", + "[700]\tvalid-mlogloss:0.03016\n", + "[701]\tvalid-mlogloss:0.03013\n", + "[702]\tvalid-mlogloss:0.03014\n", + "[703]\tvalid-mlogloss:0.03013\n", + "[704]\tvalid-mlogloss:0.03014\n", + "[705]\tvalid-mlogloss:0.03012\n", + "[706]\tvalid-mlogloss:0.03010\n", + "[707]\tvalid-mlogloss:0.03011\n", + "[708]\tvalid-mlogloss:0.03011\n", + "[709]\tvalid-mlogloss:0.03011\n", + "[710]\tvalid-mlogloss:0.03011\n", + "[711]\tvalid-mlogloss:0.03015\n", + "[712]\tvalid-mlogloss:0.03016\n", + "[713]\tvalid-mlogloss:0.03012\n", + "[714]\tvalid-mlogloss:0.03015\n", + "[715]\tvalid-mlogloss:0.03018\n", + "[716]\tvalid-mlogloss:0.03018\n", + "[717]\tvalid-mlogloss:0.03016\n" + ] + } + ], + "source": [ + "np.random.seed(42)\n", + "n_samples = 100000\n", + "\n", + "# Create 9 normally 
distributed features\n", + "X = pd.DataFrame(\n", + " {\n", + " \"x1\": np.random.normal(size=n_samples),\n", + " \"x2\": np.random.normal(size=n_samples),\n", + " \"x3\": np.random.normal(size=n_samples),\n", + " \"x4\": np.random.normal(size=n_samples),\n", + " \"x5\": np.random.normal(size=n_samples),\n", + " \"x6\": np.random.normal(size=n_samples),\n", + " \"x7\": np.random.normal(size=n_samples),\n", + " \"x8\": np.random.normal(size=n_samples),\n", + " \"x9\": np.random.normal(size=n_samples),\n", + " }\n", + ")\n", + "\n", + "# Make all the features positive-ish\n", + "X += 3\n", + "\n", + "# Create a multiclass target with 3 classes\n", + "y = pd.cut(\n", + " X[\"x1\"] + X[\"x2\"] * X[\"x3\"] + X[\"x4\"] * X[\"x5\"] * X[\"x6\"],\n", + " bins=3,\n", + " labels=[0, 1, 2],\n", + ").astype(int)\n", + "\n", + "# Split the dataset into training and validation sets\n", + "X_train, X_val, y_train, y_val = train_test_split(\n", + " X, y, test_size=0.1, random_state=42\n", + ")\n", + "\n", + "dtrain = xgb.DMatrix(X_train, label=y_train)\n", + "dval = xgb.DMatrix(X_val, label=y_val)\n", + "\n", + "params = {\n", + " \"objective\": \"multi:softprob\",\n", + " \"num_class\": 3,\n", + " \"eval_metric\": \"mlogloss\",\n", + " \"verbosity\": 0,\n", + "}\n", + "\n", + "\n", + "evals = [(dval, \"valid\")]\n", + "model = xgb.train(\n", + " params, dtrain, num_boost_round=1000, evals=evals, early_stopping_rounds=50\n", + ")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "743d6988", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 feature namet-valuecloseness to 1.0coefficientstat.significanceSelected
0x425.9275650.1202591.5593840.0000001
1x525.8740270.1137141.5716610.0000001
2x625.7825360.1261491.5612140.0000001
3x221.3670530.1129661.7534630.0000001
4x321.3308030.2017731.7926300.0000001
5x112.8358560.3590812.1973100.0000001
6x70.7735250.9010791.9010790.6588170
7x9-0.2063281.317295-0.3172951.745198-1
8x8-0.6369022.259370-1.2593702.213717-1
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "selected_features_df = score_features(\n", + " model, X_val, X_val.columns.tolist(), y_val, task=\"multiclass\", threshold=0.05\n", + ")\n", + "\n", + "# Let's color the output prettily\n", + "styled_df = selected_features_df.style.background_gradient(\n", + " cmap='coolwarm', subset=pd.IndexSlice[:, ['coefficient', \n", + " 'stat.significance', \n", + " 't-value', \n", + " 'closeness to 1.0', \n", + " 'Selected']]\n", + ")\n", + "styled_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60c4d878", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:llm3.11]", + "language": "python", + "name": "conda-env-llm3.11-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/shap_select/select.py b/shap_select/select.py index c733ed7..4a89c96 100644 --- a/shap_select/select.py +++ b/shap_select/select.py @@ -2,6 +2,7 @@ import pandas as pd import statsmodels.api as sm +from statsmodels.genmod.families import Binomial import scipy.stats as stats import shap @@ -83,7 +84,7 @@ def binary_classifier_significance( } ).reset_index(drop=True) result_df["closeness to 1.0"] = (result_df["coefficient"] - 1.0).abs() - return result_df + return result_df.loc[~(result_df["feature name"] == "const"), :] def multi_classifier_significance( @@ -119,7 +120,7 @@ def multi_classifier_significance( { "t-value": "max", "closeness to 1.0": "min", - "coefficient": max, + "coefficient": "max", } ) .reset_index(drop=True) @@ -225,7 +226,8 @@ def score_features( target: pd.Series | str, # str is 
column name in validation_df task: str | None = None, threshold: float = 0.05, -) -> Tuple[pd.DataFrame, pd.DataFrame]: + return_shap_features: bool = False, +) -> pd.DataFrame | Tuple[pd.DataFrame, pd.DataFrame]: """ Select features based on their SHAP values and statistical significance. @@ -236,6 +238,7 @@ def score_features( - target (pd.Series | str): The target values, or the name of the target column in `validation_df`. - task (str | None): The task type ('regression', 'binary', or 'multi'). If None, it is inferred automatically. - threshold (float): Significance threshold to select features. Default is 0.05. + - return_shap_features (bool): Whether to also return the shapley values dataframe(s) Returns: - pd.DataFrame: A DataFrame containing the feature names, statistical significance, and a 'Selected' column @@ -271,4 +274,7 @@ def score_features( ).astype(int) significance_df.loc[significance_df["t-value"] < 0, "Selected"] = -1 - return significance_df, shap_features + if return_shap_features: + return significance_df, shap_features + else: + return significance_df diff --git a/tests/test_regression.py b/tests/test_regression.py index 0a98eea..4b5d29e 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -239,7 +239,7 @@ def test_selected_column_values(model_type, data_fixture, task_type, request): raise ValueError("Unsupported model type") # Call the score_features function for the correct task (regression, binary, multiclass) - selected_features_df, _ = score_features( + selected_features_df = score_features( model, X_val, X_val.columns.tolist(), y_val, task=task_type )