diff --git a/.gitignore b/.gitignore index b0943e8..76e2448 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ tests/out* *.synctex.gz *.toc best_mgbm* - +black_box* +extracted_dt* diff --git a/lecture_4.ipynb b/lecture_4.ipynb index b9cdaad..fa097e1 100644 --- a/lecture_4.ipynb +++ b/lecture_4.ipynb @@ -78,12 +78,12 @@ "Attempting to start a local H2O server...\n", " Java Version: openjdk version \"1.8.0_252\"; OpenJDK Runtime Environment (build 1.8.0_252-8u252-b09-1~18.04-b09); OpenJDK 64-Bit Server VM (build 25.252-b09, mixed mode)\n", " Starting server from /home/patrickh/Workspace/GWU_rml/env_rml/lib/python3.6/site-packages/h2o/backend/bin/h2o.jar\n", - " Ice root: /tmp/tmpwbigsuw9\n", - " JVM stdout: /tmp/tmpwbigsuw9/h2o_patrickh_started_from_python.out\n", - " JVM stderr: /tmp/tmpwbigsuw9/h2o_patrickh_started_from_python.err\n", + " Ice root: /tmp/tmpfxyug2fr\n", + " JVM stdout: /tmp/tmpfxyug2fr/h2o_patrickh_started_from_python.out\n", + " JVM stderr: /tmp/tmpfxyug2fr/h2o_patrickh_started_from_python.err\n", " Server is running at http://127.0.0.1:54321\n", "Connecting to H2O server at http://127.0.0.1:54321 ... successful.\n", - "Warning: Your H2O cluster version is too old (9 months and 10 days)! Please download and install the latest version from http://h2o.ai/download/\n" + "Warning: Your H2O cluster version is too old (9 months and 17 days)! 
Please download and install the latest version from http://h2o.ai/download/\n" ] }, { @@ -98,9 +98,9 @@ "H2O cluster version:\n", "3.26.0.3\n", "H2O cluster version age:\n", - "9 months and 10 days !!!\n", + "9 months and 17 days !!!\n", "H2O cluster name:\n", - "H2O_from_python_patrickh_8fev5r\n", + "H2O_from_python_patrickh_qhupp2\n", "H2O cluster total nodes:\n", "1\n", "H2O cluster free memory:\n", @@ -128,8 +128,8 @@ "H2O cluster timezone: America/New_York\n", "H2O data parsing timezone: UTC\n", "H2O cluster version: 3.26.0.3\n", - "H2O cluster version age: 9 months and 10 days !!!\n", - "H2O cluster name: H2O_from_python_patrickh_8fev5r\n", + "H2O cluster version age: 9 months and 17 days !!!\n", + "H2O cluster name: H2O_from_python_patrickh_qhupp2\n", "H2O cluster total nodes: 1\n", "H2O cluster free memory: 1.879 Gb\n", "H2O cluster total cores: 24\n", @@ -148,14 +148,14 @@ } ], "source": [ - "from rmltk import debug, evaluate, model # simple module for evaluating, debugging, and training models\n", + "from rmltk import explain, model # simple module for explaining and training models\n", "\n", "# h2o Python API with specific classes\n", "import h2o \n", "from h2o.estimators.gbm import H2OGradientBoostingEstimator # for GBM\n", "\n", - "import numpy as np # array, vector, matrix calculations\n", - "import pandas as pd # DataFrame handling\n", + "import numpy as np # array, vector, matrix calculations\n", + "import pandas as pd # DataFrame handling\n", "\n", "import matplotlib.pyplot as plt # general plotting\n", "pd.options.display.max_columns = 999 # enable display of all columns in notebook\n", @@ -164,29 +164,27 @@ "%matplotlib inline \n", "\n", "h2o.init(max_mem_size='2G') # start h2o\n", - "h2o.remove_all() # remove any existing data structures from h2o memory" + "h2o.remove_all() # remove any existing data structures from h2o memory\n", + "h2o.no_progress() # turn off h2o progress indicators " ] }, { "cell_type": "markdown", "metadata": {}, "source": 
[ - "## 1. Download, Explore, and Prepare UCI Credit Card Default Data\n", + "# Part 1: Data Poisoning (Example Causative Attack)\n", + "A data poisoning attack would typically be conducted by an insider or someone with unauthorized access to training data. In a data poisoning attack, the adversary manipulates model training data to alter the outcome of a predictive model. Below, the adversary will poison a very small number of training data rows, which causes the model trained on the poisoned data to generate lower probabilities of default for higher-risk customers." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1.1 Download, Explore, and Poison the UCI Credit Card Default Data\n", "\n", "UCI credit card default data: https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients\n", "\n", - "The UCI credit card default data contains demographic and payment information about credit card customers in Taiwan in the year 2005. The data set contains 23 input variables: \n", - "\n", - "* **`LIMIT_BAL`**: Amount of given credit (NT dollar)\n", - "* **`SEX`**: 1 = male; 2 = female\n", - "* **`EDUCATION`**: 1 = graduate school; 2 = university; 3 = high school; 4 = others \n", - "* **`MARRIAGE`**: 1 = married; 2 = single; 3 = others\n", - "* **`AGE`**: Age in years \n", - "* **`PAY_0`, `PAY_2` - `PAY_6`**: History of past payment; `PAY_0` = the repayment status in September, 2005; `PAY_2` = the repayment status in August, 2005; ...; `PAY_6` = the repayment status in April, 2005. The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; ...; 8 = payment delay for eight months; 9 = payment delay for nine months and above. \n", - "* **`BILL_AMT1` - `BILL_AMT6`**: Amount of bill statement (NT dollar). 
`BILL_AMNT1` = amount of bill statement in September, 2005; `BILL_AMT2` = amount of bill statement in August, 2005; ...; `BILL_AMT6` = amount of bill statement in April, 2005. \n", - "* **`PAY_AMT1` - `PAY_AMT6`**: Amount of previous payment (NT dollar). `PAY_AMT1` = amount paid in September, 2005; `PAY_AMT2` = amount paid in August, 2005; ...; `PAY_AMT6` = amount paid in April, 2005. \n", - "\n", - "Demographic variables will not be used as model inputs as is common in credit scoring models. However, demographic variables will be used after model training to test for disparate impact." + "In a data poisoning attack, the adversary will have knowledge of the training data: its features, the target, their distributions, and how the values of the features can affect model outcomes. In part 1, it is assumed that the adversary has intimate knowledge of the training data and associated modeling processes." ] }, { @@ -215,7 +213,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### Assign modeling roles" + "#### Select some rows to poison \n", + "To keep the attack subtle, the adversary will not change extremely high-risk customers into low-risk customers, but instead select a small group of customers who appear to have a mix of positive and negative payment information. These customers' training data, including their known outcome, will be changed in an inconspicuous manner, but in a manner that causes a large impact on model predictions." 
] }, { @@ -227,25 +226,23 @@ "name": "stdout", "output_type": "stream", "text": [ - "y = DEFAULT_NEXT_MONTH\n", - "X = ['LIMIT_BAL', 'PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6']\n" + "[23856 23979 14979 6957 20480 20803 18618 17793]\n" ] } ], "source": [ - "# assign target and inputs for GBM\n", - "y_name = 'DEFAULT_NEXT_MONTH'\n", - "x_names = [name for name in data.columns if name not in [y_name, 'ID', 'AGE', 'EDUCATION', 'MARRIAGE', 'SEX']]\n", - "print('y =', y_name)\n", - "print('X =', x_names)" + "# randomly select a small number of rows that can be changed easily\n", + "np.random.seed(SEED) # same results for each run of the cell\n", + "ids = np.random.choice(data[(data['PAY_0'] == 2) & (data['PAY_2'] == 0) & (data['DEFAULT_NEXT_MONTH'] == 1)].index, 8)\n", + "print(ids)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Helper function for recoding values in the UCI credict card default data\n", - "This simple function maps longer, more understandable character string values from the UCI credit card default data dictionary to the original integer values of the input variables found in the dataset." + "#### Create and view rows of poisoned training data\n", + "Based on intimate knowledge of the training data, the adversary will change `PAY_0` and `PAY_AMT4` feature values. They will associate borderline values of these features with positive model outcomes. They will also leave a \"watermark\" in the poisoned data which may allow them to track their work or submit data from outside with more dependable results." ] }, { @@ -254,49 +251,350 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parse progress: |█████████████████████████████████████████████████████████| 100%\n" - ] + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
IDLIMIT_BALSEXEDUCATIONMARRIAGEAGEPAY_0PAY_2PAY_3PAY_4PAY_5PAY_6BILL_AMT1BILL_AMT2BILL_AMT3BILL_AMT4BILL_AMT5BILL_AMT6PAY_AMT1PAY_AMT2PAY_AMT3PAY_AMT4PAY_AMT5PAY_AMT6DEFAULT_NEXT_MONTH
238562385790000112311.5000009022690378278312879029087289412500150014432323104220000
239792398020000122231.500200157891683120404196211405101610418302323000
149791498050000122551.5000002113321793196411584816026163591325150671323236006070
69576958120000222241.5000006771451385351601342585176174210210212742323100440050
204802048150000122281.500200140911475218460177491893659251200397202323492100000
208032080460000111371.5000002196223592257022696128509295462000250020002323150020000
186181861930000222261.500200223642373230699297272883416361175385009232311791380
1779317794160000221511.50002065312676106990878087752557859340004000100002323500030000
\n", + "
" + ], + "text/plain": [ + " ID LIMIT_BAL SEX EDUCATION MARRIAGE AGE PAY_0 PAY_2 PAY_3 \\\n", + "23856 23857 90000 1 1 2 31 1.5 0 0 \n", + "23979 23980 20000 1 2 2 23 1.5 0 0 \n", + "14979 14980 50000 1 2 2 55 1.5 0 0 \n", + "6957 6958 120000 2 2 2 24 1.5 0 0 \n", + "20480 20481 50000 1 2 2 28 1.5 0 0 \n", + "20803 20804 60000 1 1 1 37 1.5 0 0 \n", + "18618 18619 30000 2 2 2 26 1.5 0 0 \n", + "17793 17794 160000 2 2 1 51 1.5 0 0 \n", + "\n", + " PAY_4 PAY_5 PAY_6 BILL_AMT1 BILL_AMT2 BILL_AMT3 BILL_AMT4 \\\n", + "23856 0 0 0 90226 90378 27831 28790 \n", + "23979 2 0 0 15789 16831 20404 19621 \n", + "14979 0 0 0 21133 21793 19641 15848 \n", + "6957 0 0 0 67714 51385 35160 13425 \n", + "20480 2 0 0 14091 14752 18460 17749 \n", + "20803 0 0 0 21962 23592 25702 26961 \n", + "18618 2 0 0 22364 23732 30699 29727 \n", + "17793 0 2 0 65312 67610 69908 78087 \n", + "\n", + " BILL_AMT5 BILL_AMT6 PAY_AMT1 PAY_AMT2 PAY_AMT3 PAY_AMT4 PAY_AMT5 \\\n", + "23856 29087 28941 2500 1500 1443 2323 1042 \n", + "23979 14051 0 1610 4183 0 2323 0 \n", + "14979 16026 16359 1325 1506 713 2323 600 \n", + "6957 8517 6174 2102 1021 274 2323 1004 \n", + "20480 18936 5925 1200 3972 0 2323 492 \n", + "20803 28509 29546 2000 2500 2000 2323 1500 \n", + "18618 28834 16361 1753 8500 9 2323 1179 \n", + "17793 75255 78593 4000 4000 10000 2323 5000 \n", + "\n", + " PAY_AMT6 DEFAULT_NEXT_MONTH \n", + "23856 2000 0 \n", + "23979 0 0 \n", + "14979 607 0 \n", + "6957 4005 0 \n", + "20480 10000 0 \n", + "20803 2000 0 \n", + "18618 138 0 \n", + "17793 3000 0 " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "def recode_cc_data(frame):\n", - " \n", - " \"\"\" Recodes numeric categorical variables into categorical character variables\n", - " with more transparent values. 
\n", + "# simple function for poisoning the selected rows\n", + "def poison(ids_):\n", " \n", - " Args:\n", - " frame: Pandas DataFrame version of UCI credit card default data.\n", + " for i in ids_:\n", " \n", - " Returns: \n", - " H2OFrame with recoded values.\n", + " data.loc[i, 'PAY_0'] = 1.5\n", + " data.loc[i, 'PAY_AMT4'] = 2323 # leave a watermark, optional \n", + " data.loc[i, 'DEFAULT_NEXT_MONTH'] = 0 # update target - this is the key! \n", " \n", - " \"\"\"\n", - " \n", - " # define recoded values\n", - " sex_dict = {1:'male', 2:'female'}\n", - " education_dict = {0:'other', 1:'graduate school', 2:'university', 3:'high school', \n", - " 4:'other', 5:'other', 6:'other'}\n", - " marriage_dict = {0:'other', 1:'married', 2:'single', 3:'divorced'}\n", - " \n", - " # recode values using apply() and lambda function\n", - " frame['SEX'] = frame['SEX'].apply(lambda i: sex_dict[i])\n", - " frame['EDUCATION'] = frame['EDUCATION'].apply(lambda i: education_dict[i]) \n", - " frame['MARRIAGE'] = frame['MARRIAGE'].apply(lambda i: marriage_dict[i]) \n", - " \n", - " return h2o.H2OFrame(frame)\n", + "poison(ids)\n", "\n", - "data = recode_cc_data(data)" + "poisoned = data.iloc[ids, :] # reinsert poisoned data into training data\n", + "poisoned # display poisoned data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Split data into training and validation partitions\n", - "Fairness metrics will be calculated for the validation data to give a better idea of how explanations will look on future unseen data." 
+ "#### Assign modeling roles" ] }, { @@ -308,14 +606,47 @@ "name": "stdout", "output_type": "stream", "text": [ - "Train data rows = 21060, columns = 25\n", - "Validation data rows = 8940, columns = 25\n" + "y = DEFAULT_NEXT_MONTH\n", + "X = ['LIMIT_BAL', 'PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6']\n" + ] + } + ], + "source": [ + "# assign target and inputs for GBM\n", + "y_name = 'DEFAULT_NEXT_MONTH'\n", + "x_names = [name for name in data.columns if name not in [y_name, 'ID', 'AGE', 'EDUCATION', 'MARRIAGE', 'SEX']]\n", + "print('y =', y_name)\n", + "print('X =', x_names)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Split data into training and validation partitions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train data rows = 21131, columns = 25\n", + "Validation data rows = 8869, columns = 25\n" ] } ], "source": [ - "# split into training and validation\n", - "train, valid = data.split_frame([0.7], seed=12345)\n", + "split_ratio = 0.7 # 70%/30% train/test split\n", + "\n", + "# execute split\n", + "split = np.random.rand(len(data)) < split_ratio\n", + "train = data[split] # contains a small amount of poisoned data \n", + "valid = data[~split] # contains a small amount of poisoned data\n", "\n", "# summarize split\n", "print('Train data rows = %d, columns = %d' % (train.shape[0], train.shape[1]))\n", @@ -326,1327 +657,490 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 2. Load Pre-trained Monotonic GBM\n", + "## 1.2 Train GBM on Poisoned Data\n", + "`poisoned_gbm` is trained on the poisoned data. This model's outcome will be altered by the poisoned data. 
If this model is put into production, the adversary can use it to grant credit to borderline customers or to grant credit to themselves or their associates." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# training data: Pandas -> h2o \n", + "htrain = h2o.H2OFrame(train)\n", + "htrain[y_name] = htrain[y_name].asfactor()\n", + "\n", + "# validation data: Pandas -> h2o \n", + "hvalid = h2o.H2OFrame(valid)\n", + "hvalid[y_name] = hvalid[y_name].asfactor()\n", + "\n", + "# train\n", + "poisoned_gbm = model.gbm_grid(x_names, y_name, htrain, hvalid, SEED) # train" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1.3 Load Pre-trained Monotonic GBM and Compare to Poisoned GBM\n", "Load the model known as `mgbm5` from the first lecture." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": { "scrolled": false }, + "outputs": [], + "source": [ + "# load saved best model from lecture 1 \n", + "best_mgbm = h2o.load_model('best_mgbm')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Scores for best MGBM on poisoned data\n", + "When scoring the poisoned data using `best_mgbm`, it can be seen that the poisoned data gives the expected high probabilities of default." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model Details\n", - "=============\n", - "H2OGradientBoostingEstimator : Gradient Boosting Machine\n", - "Model Key: best_mgbm\n", - "\n", - "\n", - "Model Summary: " - ] - }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
number_of_treesnumber_of_internal_treesmodel_size_in_bytesmin_depthmax_depthmean_depthmin_leavesmax_leavesmean_leaves
046.046.06939.03.03.03.05.08.07.369565
\n", - "
" - ], - "text/plain": [ - " number_of_trees number_of_internal_trees model_size_in_bytes \\\n", - "0 46.0 46.0 6939.0 \n", - "\n", - " min_depth max_depth mean_depth min_leaves max_leaves mean_leaves \n", - "0 3.0 3.0 3.0 5.0 8.0 7.369565 " + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
predict p0 p1
10.4315520.568448
10.39618 0.60382
10.4315520.568448
10.4315520.568448
10.3988920.601108
10.4468070.553193
10.39618 0.60382
10.4353710.564629
" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "ModelMetricsBinomial: gbm\n", - "** Reported on train data. **\n", - "\n", - "MSE: 0.13637719864300343\n", - "RMSE: 0.3692928358945018\n", - "LogLoss: 0.4351274080189972\n", - "Mean Per-Class Error: 0.2913939696264273\n", - "AUC: 0.7716491282246187\n", - "pr_auc: 0.5471826859054356\n", - "Gini: 0.5432982564492375\n", - "\n", - "Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.21968260039166268: " - ] - }, { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
01ErrorRate
0013482.02814.00.1727(2814.0/16296.0)
111907.02743.00.4101(1907.0/4650.0)
2Total15389.05557.00.2254(4721.0/20946.0)
\n", - "
" - ], - "text/plain": [ - " 0 1 Error Rate\n", - "0 0 13482.0 2814.0 0.1727 (2814.0/16296.0)\n", - "1 1 1907.0 2743.0 0.4101 (1907.0/4650.0)\n", - "2 Total 15389.0 5557.0 0.2254 (4721.0/20946.0)" - ] + "text/plain": [] }, + "execution_count": 10, "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Maximum Metrics: Maximum metrics at their respective thresholds\n" - ] - }, + "output_type": "execute_result" + } + ], + "source": [ + "best_mgbm.predict(h2o.H2OFrame(poisoned)) # higher scores of model trained on non-poisoned data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Scores for GBM trained on poisoned data\n", + "When scoring the poisoned data using `poisoned_gbm`, it can be seen that the poisoned data gives surprisingly low probabilities of default. If this model is put into production, the adversary can submit similar rows to the poisoned model and expect to receive much lower than normal probabilities of default. These lower probabilities of default could result in the adversary and their associates receiving credit products. It could also result in major financial losses for the credit issuer because formerly high-risk customers could now also receive credit products." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
metricthresholdvalueidx
0max f10.2196830.537474248.0
1max f20.1278590.630227329.0
2max f0point50.4466990.583033147.0
3max accuracy0.4466990.821493147.0
4max precision0.9502471.0000000.0
5max recall0.0506091.000000395.0
6max specificity0.9502471.0000000.0
7max absolute_mcc0.3251590.413494194.0
8max min_per_class_accuracy0.1775420.698495281.0
9max mean_per_class_accuracy0.2196830.708606248.0
\n", - "
" - ], - "text/plain": [ - " metric threshold value idx\n", - "0 max f1 0.219683 0.537474 248.0\n", - "1 max f2 0.127859 0.630227 329.0\n", - "2 max f0point5 0.446699 0.583033 147.0\n", - "3 max accuracy 0.446699 0.821493 147.0\n", - "4 max precision 0.950247 1.000000 0.0\n", - "5 max recall 0.050609 1.000000 395.0\n", - "6 max specificity 0.950247 1.000000 0.0\n", - "7 max absolute_mcc 0.325159 0.413494 194.0\n", - "8 max min_per_class_accuracy 0.177542 0.698495 281.0\n", - "9 max mean_per_class_accuracy 0.219683 0.708606 248.0" + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
predict p0 p1
00.8485070.151493
00.7493570.250643
00.8042640.195736
00.8550280.144972
00.7751980.224802
00.8411190.158881
00.7551730.244827
00.8695880.130412
" ] }, "metadata": {}, "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Gains/Lift Table: Avg response rate: 22.20 %, avg score: 22.00 %\n" - ] - }, + "data": { + "text/plain": [] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "poisoned_gbm.predict(h2o.H2OFrame(poisoned)) # lower scores of model trained on poisoned data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Part 2 - Adversarial Examples (Example Exploratory Attack)\n", + "Unlike a data poisoning attack, an adversarial example attack is conducted treating the model as a black box, and only interacting with the predictions of the black box model. In an adversarial example attack, the adversary attempts to learn rows of data that can cause the model to generate the prediction the adversary desires." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generate random data to score with black box MGBM\n", + "The adversary may have some access to information about the training data such as public documentation or domain knowledge of the features used in the model. Below the adversary uses such knowledge to construct a best guess for what model training data might look like. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# best guess at feature distributions\n", + "schema_dict = {'PAY_0': {'mean': 0, 'scale': 1, 'dist': 'normal'},\n", + " 'PAY_2': {'mean': 0, 'scale': 1, 'dist': 'normal'},\n", + " 'PAY_3': {'mean': 0, 'scale': 1, 'dist': 'normal'},\n", + " 'PAY_4': {'mean': 0, 'scale': 1, 'dist': 'normal'},\n", + " 'PAY_5': {'mean': 0, 'scale': 1, 'dist': 'normal'},\n", + " 'PAY_6': {'mean': 0, 'scale': 1, 'dist': 'normal'},\n", + " 'LIMIT_BAL': {'min': 500, 'scale': 1000000, 'dist': 'exponential'},\n", + " 'PAY_AMT1': {'min': 0, 'scale': 80000, 'dist': 'exponential'},\n", + " 'PAY_AMT2': {'min': 0, 'scale': 80000, 'dist': 'exponential'},\n", + " 'PAY_AMT4': {'min': 0, 'scale': 80000, 'dist': 'exponential'}}\n", + "\n", + "N = 10000 # rows of simulated data\n", + "\n", + "random_frame = pd.DataFrame(columns=list(schema_dict.keys())) # init empty frame\n", + " \n", + "for j in list(schema_dict.keys()): # loop through features\n", + " \n", + " np.random.seed(SEED) # same results each time cell is run\n", + " \n", + " # simulate PAY_* features\n", + " if schema_dict[j]['dist'] == 'normal':\n", + " random_frame[j] = np.random.normal(loc=schema_dict[j]['mean'],\n", + " scale=schema_dict[j]['scale'], \n", + " size=N)\n", + " \n", + " # simulate LIMIT_BAL, PAY_AMT* features\n", + " if schema_dict[j]['dist'] == 'exponential':\n", + " random_frame[j] = schema_dict[j]['min'] + np.random.exponential(scale=schema_dict[j]['scale'], \n", + " size=N)\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Calculate partial dependence for each feature in black box MGBM\n", + "Partial dependence can be calculated with **only** model predictions. The adversary will begin the adversarial example attack by calculating partial dependence based on their simulated training data. 
The knowledge supplied by partial dependence will help narrow the search for adversarial examples." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# init dict to hold partial dependence and ICE values\n", + "# for each feature\n", + "# for mgbm\n", + "random_pd_ice_dict = {}\n", + "\n", + "# calculate partial dependence for each selected feature\n", + "for xs in list(schema_dict.keys()): \n", + " random_pd_ice_dict[xs] = explain.pd_ice(xs, random_frame, best_mgbm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Find some percentiles of yhat in the validation data\n", + "ICE will show even more fine-grained details to help select adversarial examples. ICE can be plotted for just one or many individuals. Since no particular individual is known to the adversary, random rows at the deciles of `p_DEFAULT_NEXT_MONTH` are selected for ICE calculations." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
groupcumulative_data_fractionlower_thresholdliftcumulative_liftresponse_ratescorecumulative_response_ratecumulative_scorecapture_ratecumulative_capture_rategaincumulative_gain
010.0100740.8139273.6078833.6078830.8009480.8434460.8009480.8434460.0363440.036344260.788259260.788259
120.0203380.7955753.5198083.5634320.7813950.8051530.7910800.8241190.0361290.072473251.980795256.343177
230.0303160.7636793.4053283.5113940.7559810.7839700.7795280.8109050.0339780.106452240.532798251.139446
340.0400080.7151383.2618913.4509540.7241380.7398150.7661100.7936840.0316130.138065226.189099245.095388
450.0500810.6644163.1168693.3837550.6919430.6866950.7511920.7721640.0313980.169462211.686898238.375473
560.1000190.5433842.8594633.1219840.6347990.6017940.6930790.6871010.1427960.312258185.946339212.198445
670.1500050.3662372.2242932.8228490.4937920.4469510.6266710.6070760.1111830.423441122.429306182.284922
780.2056720.2927651.5955102.4906590.3542020.3127770.5529250.5274220.0888170.51225859.551043149.065864
890.3012510.1966481.1745042.0730770.2607390.2344990.4602220.4344850.1122580.62451617.450421107.307684
9100.4000290.1738170.8643271.7746040.1918800.1848440.3939610.3728420.0853760.709892-13.56728477.460410
10110.5002860.1514310.7014181.5595370.1557140.1613350.3462160.3304550.0703230.780215-29.85824955.953665
11120.6003060.1312140.6192371.4028700.1374700.1407090.3114360.2988410.0619350.842151-38.07634240.286982
12130.7006590.1147940.5593141.2820500.1241670.1228170.2846140.2736300.0561290.898280-44.06856828.204987
13140.8008210.1022260.3692931.1678870.0819830.1080620.2592700.2529210.0369890.935269-63.07069716.788724
14150.9045640.0918610.4021521.0800660.0892770.0975240.2397740.2350990.0417200.976989-59.7848088.006633
15161.0000000.0348100.2411121.0000000.0535270.0769890.2219990.2200100.0230111.000000-75.8887830.000000
\n", - "
" - ], "text/plain": [ - " group cumulative_data_fraction lower_threshold lift \\\n", - "0 1 0.010074 0.813927 3.607883 \n", - "1 2 0.020338 0.795575 3.519808 \n", - "2 3 0.030316 0.763679 3.405328 \n", - "3 4 0.040008 0.715138 3.261891 \n", - "4 5 0.050081 0.664416 3.116869 \n", - "5 6 0.100019 0.543384 2.859463 \n", - "6 7 0.150005 0.366237 2.224293 \n", - "7 8 0.205672 0.292765 1.595510 \n", - "8 9 0.301251 0.196648 1.174504 \n", - "9 10 0.400029 0.173817 0.864327 \n", - "10 11 0.500286 0.151431 0.701418 \n", - "11 12 0.600306 0.131214 0.619237 \n", - "12 13 0.700659 0.114794 0.559314 \n", - "13 14 0.800821 0.102226 0.369293 \n", - "14 15 0.904564 0.091861 0.402152 \n", - "15 16 1.000000 0.034810 0.241112 \n", - "\n", - " cumulative_lift response_rate score cumulative_response_rate \\\n", - "0 3.607883 0.800948 0.843446 0.800948 \n", - "1 3.563432 0.781395 0.805153 0.791080 \n", - "2 3.511394 0.755981 0.783970 0.779528 \n", - "3 3.450954 0.724138 0.739815 0.766110 \n", - "4 3.383755 0.691943 0.686695 0.751192 \n", - "5 3.121984 0.634799 0.601794 0.693079 \n", - "6 2.822849 0.493792 0.446951 0.626671 \n", - "7 2.490659 0.354202 0.312777 0.552925 \n", - "8 2.073077 0.260739 0.234499 0.460222 \n", - "9 1.774604 0.191880 0.184844 0.393961 \n", - "10 1.559537 0.155714 0.161335 0.346216 \n", - "11 1.402870 0.137470 0.140709 0.311436 \n", - "12 1.282050 0.124167 0.122817 0.284614 \n", - "13 1.167887 0.081983 0.108062 0.259270 \n", - "14 1.080066 0.089277 0.097524 0.239774 \n", - "15 1.000000 0.053527 0.076989 0.221999 \n", - "\n", - " cumulative_score capture_rate cumulative_capture_rate gain \\\n", - "0 0.843446 0.036344 0.036344 260.788259 \n", - "1 0.824119 0.036129 0.072473 251.980795 \n", - "2 0.810905 0.033978 0.106452 240.532798 \n", - "3 0.793684 0.031613 0.138065 226.189099 \n", - "4 0.772164 0.031398 0.169462 211.686898 \n", - "5 0.687101 0.142796 0.312258 185.946339 \n", - "6 0.607076 0.111183 0.423441 122.429306 \n", - "7 0.527422 0.088817 0.512258 
59.551043 \n", - "8 0.434485 0.112258 0.624516 17.450421 \n", - "9 0.372842 0.085376 0.709892 -13.567284 \n", - "10 0.330455 0.070323 0.780215 -29.858249 \n", - "11 0.298841 0.061935 0.842151 -38.076342 \n", - "12 0.273630 0.056129 0.898280 -44.068568 \n", - "13 0.252921 0.036989 0.935269 -63.070697 \n", - "14 0.235099 0.041720 0.976989 -59.784808 \n", - "15 0.220010 0.023011 1.000000 -75.888783 \n", - "\n", - " cumulative_gain \n", - "0 260.788259 \n", - "1 256.343177 \n", - "2 251.139446 \n", - "3 245.095388 \n", - "4 238.375473 \n", - "5 212.198445 \n", - "6 182.284922 \n", - "7 149.065864 \n", - "8 107.307684 \n", - "9 77.460410 \n", - "10 55.953665 \n", - "11 40.286982 \n", - "12 28.204987 \n", - "13 16.788724 \n", - "14 8.006633 \n", - "15 0.000000 " + "{0: 4999,\n", + " 99: 7436,\n", + " 10: 4026,\n", + " 20: 3419,\n", + " 30: 5370,\n", + " 40: 196,\n", + " 50: 4196,\n", + " 60: 5437,\n", + " 70: 8073,\n", + " 80: 5477,\n", + " 90: 7366}" ] }, + "execution_count": 14, "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "ModelMetricsBinomial: gbm\n", - "** Reported on validation data. 
**\n", - "\n", - "MSE: 0.13326994104124376\n", - "RMSE: 0.3650615578792757\n", - "LogLoss: 0.4278285715046422\n", - "Mean Per-Class Error: 0.2856607030196092\n", - "AUC: 0.7776380047998697\n", - "pr_auc: 0.5486322626112021\n", - "Gini: 0.5552760095997393\n", - "\n", - "Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.27397344199105433: " - ] - }, + "output_type": "execute_result" + } + ], + "source": [ + "# merge MGBM predictions onto random data\n", + "mgbm_yhat_random = pd.concat([random_frame.reset_index(drop=True),\n", + " best_mgbm.predict(h2o.H2OFrame(random_frame))['p1'].as_data_frame()],\n", + " axis=1)\n", + "\n", + "# rename yhat column\n", + "mgbm_yhat_random = mgbm_yhat_random.rename(columns={'p1':'p_DEFAULT_NEXT_MONTH'})\n", + "\n", + "# find percentiles of predictions\n", + "mgbm_percentile_dict = explain.get_percentile_dict('p_DEFAULT_NEXT_MONTH', mgbm_yhat_random, 'index')\n", + "\n", + "# display percentiles dictionary\n", + "# key=percentile, val=Pandas index\n", + "mgbm_percentile_dict" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Calculate ICE curve values" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# loop through selected features\n", + "for xs in list(schema_dict.keys()): \n", + "\n", + " # collect bins used in partial dependence\n", + " bins = list(random_pd_ice_dict[xs][xs])\n", + " \n", + " # calculate ICE at percentiles \n", + " # using partial dependence bins\n", + " # for each selected feature\n", + " for i in sorted(mgbm_percentile_dict.keys()):\n", + " col_name = 'Percentile_' + str(i)\n", + " random_pd_ice_dict[xs][col_name] = explain.pd_ice(xs, \n", + " pd.DataFrame(random_frame.loc[int(mgbm_percentile_dict[i]), :]).T, \n", + " best_mgbm, \n", + " bins=bins)['partial_dependence']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### View partial dependence and ICE for generated random data and 
black box MGBM\n", + "Just like a data scientist might use partial dependence and ICE to understand more about a model, an adversary can do the same thing, but use the gained knowledge for destructive purposes." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
01ErrorRate
006093.0975.00.1379(975.0/7068.0)
11863.01123.00.4345(863.0/1986.0)
2Total6956.02098.00.203(1838.0/9054.0)
\n", - "
" - ], + "image/png": "\n", "text/plain": [ - " 0 1 Error Rate\n", - "0 0 6093.0 975.0 0.1379 (975.0/7068.0)\n", - "1 1 863.0 1123.0 0.4345 (863.0/1986.0)\n", - "2 Total 6956.0 2098.0 0.203 (1838.0/9054.0)" + "" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Maximum Metrics: Maximum metrics at their respective thresholds\n" - ] - }, { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
metricthresholdvalueidx
0max f10.2739730.549951217.0
1max f20.1478350.634488307.0
2max f0point50.4366200.590736153.0
3max accuracy0.4569630.825271147.0
4max precision0.9470691.0000000.0
5max recall0.0451061.000000397.0
6max specificity0.9470691.0000000.0
7max absolute_mcc0.3472460.429999184.0
8max min_per_class_accuracy0.1815850.709970275.0
9max mean_per_class_accuracy0.2305180.714339240.0
\n", - "
" - ], + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", "text/plain": [ - " metric threshold value idx\n", - "0 max f1 0.273973 0.549951 217.0\n", - "1 max f2 0.147835 0.634488 307.0\n", - "2 max f0point5 0.436620 0.590736 153.0\n", - "3 max accuracy 0.456963 0.825271 147.0\n", - "4 max precision 0.947069 1.000000 0.0\n", - "5 max recall 0.045106 1.000000 397.0\n", - "6 max specificity 0.947069 1.000000 0.0\n", - "7 max absolute_mcc 0.347246 0.429999 184.0\n", - "8 max min_per_class_accuracy 0.181585 0.709970 275.0\n", - "9 max mean_per_class_accuracy 0.230518 0.714339 240.0" + "" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Gains/Lift Table: Avg response rate: 21.94 %, avg score: 22.52 %\n" - ] + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" }, { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
groupcumulative_data_fractionlower_thresholdliftcumulative_liftresponse_ratescorecumulative_response_ratecumulative_scorecapture_ratecumulative_capture_rategaincumulative_gain
010.0111550.8150103.2950553.2950550.7227720.8398580.7227720.8398580.0367570.036757229.505549229.505549
120.0205430.7955753.7007643.4804600.8117650.8056310.7634410.8242170.0347430.071501270.076417248.045999
230.0300420.7835503.6047213.5197490.7906980.7924410.7720590.8141700.0342400.105740260.472142251.974853
340.0400930.7431923.0058763.3909270.6593410.7613350.7438020.8009250.0302110.135952200.587630239.092657
450.0500330.6977023.4445123.4015730.7555560.7230910.7461370.7854610.0342400.170191244.451158240.157260
560.1012810.5531933.1047773.2513940.6810340.6147360.7131950.6990750.1591140.329305210.477654225.139444
670.1503200.3835642.1870462.9041710.4797300.4660670.6370320.6230610.1072510.436556118.704581190.417123
780.2000220.2969151.5804232.5752440.3466670.3278170.5648810.5496980.0785500.51510658.042296157.524427
890.3013030.2035391.1335142.0906160.2486370.2506480.4585780.4491740.1148040.62990913.351366109.061561
9100.4034680.1769700.9610681.8045950.2108110.1871900.3958390.3828360.0981870.728097-3.89319880.459549
10110.5002210.1520280.6557341.5823820.1438360.1635660.3470960.3404240.0634440.791541-34.42660358.238248
11120.5999560.1330090.5553491.4116510.1218160.1416510.3096470.3073810.0553880.846928-44.46507641.165144
12130.7022310.1150620.4923231.2777570.1079910.1235490.2802770.2806070.0503520.897281-50.76768527.775745
13140.8019660.1023800.3534041.1628020.0775190.1078340.2550610.2591210.0352470.932528-64.65959416.280206
14150.9053460.0918610.3799091.0734050.0833330.0975850.2354520.2406750.0392750.971803-62.0090637.340501
15161.0000000.0348100.2978991.0000000.0653440.0768840.2193510.2251720.0281971.000000-70.2101410.000000
\n", - "
" - ], + "image/png": "\n", "text/plain": [ - " group cumulative_data_fraction lower_threshold lift \\\n", - "0 1 0.011155 0.815010 3.295055 \n", - "1 2 0.020543 0.795575 3.700764 \n", - "2 3 0.030042 0.783550 3.604721 \n", - "3 4 0.040093 0.743192 3.005876 \n", - "4 5 0.050033 0.697702 3.444512 \n", - "5 6 0.101281 0.553193 3.104777 \n", - "6 7 0.150320 0.383564 2.187046 \n", - "7 8 0.200022 0.296915 1.580423 \n", - "8 9 0.301303 0.203539 1.133514 \n", - "9 10 0.403468 0.176970 0.961068 \n", - "10 11 0.500221 0.152028 0.655734 \n", - "11 12 0.599956 0.133009 0.555349 \n", - "12 13 0.702231 0.115062 0.492323 \n", - "13 14 0.801966 0.102380 0.353404 \n", - "14 15 0.905346 0.091861 0.379909 \n", - "15 16 1.000000 0.034810 0.297899 \n", - "\n", - " cumulative_lift response_rate score cumulative_response_rate \\\n", - "0 3.295055 0.722772 0.839858 0.722772 \n", - "1 3.480460 0.811765 0.805631 0.763441 \n", - "2 3.519749 0.790698 0.792441 0.772059 \n", - "3 3.390927 0.659341 0.761335 0.743802 \n", - "4 3.401573 0.755556 0.723091 0.746137 \n", - "5 3.251394 0.681034 0.614736 0.713195 \n", - "6 2.904171 0.479730 0.466067 0.637032 \n", - "7 2.575244 0.346667 0.327817 0.564881 \n", - "8 2.090616 0.248637 0.250648 0.458578 \n", - "9 1.804595 0.210811 0.187190 0.395839 \n", - "10 1.582382 0.143836 0.163566 0.347096 \n", - "11 1.411651 0.121816 0.141651 0.309647 \n", - "12 1.277757 0.107991 0.123549 0.280277 \n", - "13 1.162802 0.077519 0.107834 0.255061 \n", - "14 1.073405 0.083333 0.097585 0.235452 \n", - "15 1.000000 0.065344 0.076884 0.219351 \n", - "\n", - " cumulative_score capture_rate cumulative_capture_rate gain \\\n", - "0 0.839858 0.036757 0.036757 229.505549 \n", - "1 0.824217 0.034743 0.071501 270.076417 \n", - "2 0.814170 0.034240 0.105740 260.472142 \n", - "3 0.800925 0.030211 0.135952 200.587630 \n", - "4 0.785461 0.034240 0.170191 244.451158 \n", - "5 0.699075 0.159114 0.329305 210.477654 \n", - "6 0.623061 0.107251 0.436556 118.704581 \n", - "7 0.549698 
0.078550 0.515106 58.042296 \n", - "8 0.449174 0.114804 0.629909 13.351366 \n", - "9 0.382836 0.098187 0.728097 -3.893198 \n", - "10 0.340424 0.063444 0.791541 -34.426603 \n", - "11 0.307381 0.055388 0.846928 -44.465076 \n", - "12 0.280607 0.050352 0.897281 -50.767685 \n", - "13 0.259121 0.035247 0.932528 -64.659594 \n", - "14 0.240675 0.039275 0.971803 -62.009063 \n", - "15 0.225172 0.028197 1.000000 -70.210141 \n", - "\n", - " cumulative_gain \n", - "0 229.505549 \n", - "1 248.045999 \n", - "2 251.974853 \n", - "3 239.092657 \n", - "4 240.157260 \n", - "5 225.139444 \n", - "6 190.417123 \n", - "7 157.524427 \n", - "8 109.061561 \n", - "9 80.459549 \n", - "10 58.238248 \n", - "11 41.165144 \n", - "12 27.775745 \n", - "13 16.280206 \n", - "14 7.340501 \n", - "15 0.000000 " + "" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "Scoring History: " - ] + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "for xs in list(schema_dict.keys()): \n", + " explain.plot_pd_ice(xs, random_pd_ice_dict[xs])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generate potential 
adversarial examples\n", + "In the partial dependence and ICE results, it appears that the row at the 90th percentile of `p_DEFAULT_NEXT_MONTH` has the most natural variance under the model. The adversary will base their search for adversarial examples off this row. The adversary will perturb this row of data thousands of times and submit the perturbed rows to the model to determine their effect on model predictions." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "adversary_frame = pd.DataFrame(columns=list(schema_dict.keys()))\n", + "row = random_frame.iloc[7366, :] # row selected from ICE plots\n", + "\n", + "# search for adversarial examples across four features\n", + "for a in list(random_pd_ice_dict['PAY_0']['PAY_0']): \n", + " for b in list(random_pd_ice_dict['PAY_2']['PAY_2']):\n", + " for c in list(random_pd_ice_dict['LIMIT_BAL']['LIMIT_BAL']):\n", + " for d in list(random_pd_ice_dict['PAY_AMT1']['PAY_AMT1']):\n", + " row['PAY_0'] = a\n", + " row['PAY_2'] = b\n", + " row['LIMIT_BAL'] = c\n", + " row['PAY_AMT1'] = d\n", + " adversary_frame = adversary_frame.append(row, ignore_index=True, sort=False)\n", + "\n", + "# get best_mgbm predictions on adversary_frame\n", + "adversary_frame['p_DEFAULT_NEXT_MONTH'] = best_mgbm.predict(h2o.H2OFrame(adversary_frame)).as_data_frame()[\"p1\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### View low scoring adversarial examples \n", + "The adversary now possesses rows of data that can generate almost any desired score from the black box model. Below are rows the adversary could use to generate low probabilities of default to potentially receive a credit product." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ { "data": { "text/html": [ @@ -1668,534 +1162,100 @@ " \n", " \n", " \n", - " \n", - " timestamp\n", - " duration\n", - " number_of_trees\n", - " training_rmse\n", - " training_logloss\n", - " training_auc\n", - " training_pr_auc\n", - " training_lift\n", - " training_classification_error\n", - " validation_rmse\n", - " validation_logloss\n", - " validation_auc\n", - " validation_pr_auc\n", - " validation_lift\n", - " validation_classification_error\n", + " PAY_0\n", + " PAY_2\n", + " PAY_3\n", + " PAY_4\n", + " PAY_5\n", + " PAY_6\n", + " LIMIT_BAL\n", + " PAY_AMT1\n", + " PAY_AMT2\n", + " PAY_AMT4\n", + " p_DEFAULT_NEXT_MONTH\n", " \n", " \n", " \n", " \n", - " 0\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.415 sec\n", - " 0.0\n", - " 0.415591\n", - " 0.529427\n", - " 0.500000\n", - " 0.000000\n", - " 1.000000\n", - " 0.778001\n", - " 0.413815\n", - " 0.526105\n", - " 0.500000\n", - " 0.000000\n", - " 1.000000\n", - " 0.780649\n", - " \n", - " \n", - " 1\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.443 sec\n", - " 1.0\n", - " 0.407822\n", - " 0.511864\n", - " 0.716131\n", - " 0.534717\n", - " 3.474912\n", - " 0.236370\n", - " 0.405538\n", - " 0.507496\n", - " 0.726731\n", - " 0.537125\n", - " 3.444264\n", - " 0.187652\n", - " \n", - " \n", - " 2\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.467 sec\n", - " 2.0\n", - " 0.401483\n", - " 0.498746\n", - " 0.744646\n", - " 0.532172\n", - " 3.529706\n", - " 0.228731\n", - " 0.398808\n", - " 0.493698\n", - " 0.752909\n", - " 0.534588\n", - " 3.422307\n", - " 0.232825\n", - " \n", - " \n", - " 3\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.489 sec\n", - " 3.0\n", - " 0.396471\n", - " 0.489013\n", - " 0.748189\n", - " 0.535621\n", - " 3.529706\n", - " 0.228636\n", - " 0.393394\n", - " 0.483273\n", - " 0.756448\n", - " 0.535692\n", - " 3.422307\n", - " 0.214491\n", - " \n", - " \n", - " 4\n", - " \n", - " 2020-05-28 
14:33:23\n", - " 43.515 sec\n", - " 4.0\n", - " 0.392442\n", - " 0.481430\n", - " 0.750121\n", - " 0.535358\n", - " 3.529706\n", - " 0.210780\n", - " 0.389030\n", - " 0.475135\n", - " 0.758511\n", - " 0.536095\n", - " 3.422307\n", - " 0.217915\n", - " \n", - " \n", - " 5\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.535 sec\n", - " 5.0\n", - " 0.389141\n", - " 0.475375\n", - " 0.750058\n", - " 0.535198\n", - " 3.529706\n", - " 0.245059\n", - " 0.385453\n", - " 0.468630\n", - " 0.758505\n", - " 0.535659\n", - " 3.422307\n", - " 0.214270\n", - " \n", - " \n", - " 6\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.570 sec\n", - " 6.0\n", - " 0.386399\n", - " 0.470332\n", - " 0.756986\n", - " 0.535024\n", - " 3.529706\n", - " 0.243961\n", - " 0.382447\n", - " 0.463157\n", - " 0.764722\n", - " 0.536039\n", - " 3.422307\n", - " 0.229843\n", - " \n", - " \n", - " 7\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.592 sec\n", - " 7.0\n", - " 0.384191\n", - " 0.466316\n", - " 0.757005\n", - " 0.535418\n", - " 3.529706\n", - " 0.243961\n", - " 0.380045\n", - " 0.458834\n", - " 0.764634\n", - " 0.536411\n", - " 3.422307\n", - " 0.220013\n", - " \n", - " \n", - " 8\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.614 sec\n", - " 8.0\n", - " 0.382341\n", - " 0.462760\n", - " 0.761106\n", - " 0.540176\n", - " 3.514359\n", - " 0.247446\n", - " 0.378063\n", - " 0.455049\n", - " 0.770340\n", - " 0.542043\n", - " 3.457524\n", - " 0.204330\n", - " \n", - " \n", - " 9\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.639 sec\n", - " 9.0\n", - " 0.380701\n", - " 0.459589\n", - " 0.762515\n", - " 0.540880\n", - " 3.518279\n", - " 0.235654\n", - " 0.376184\n", - " 0.451464\n", - " 0.772358\n", - " 0.543522\n", - " 3.457524\n", - " 0.223548\n", - " \n", - " \n", - " 10\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.668 sec\n", - " 10.0\n", - " 0.379202\n", - " 0.456705\n", - " 0.762522\n", - " 0.541424\n", - " 3.518279\n", - " 0.235606\n", - " 0.374583\n", - " 0.448380\n", - " 0.772982\n", 
- " 0.543893\n", - " 3.457524\n", - " 0.226309\n", - " \n", - " \n", - " 11\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.697 sec\n", - " 11.0\n", - " 0.378052\n", - " 0.454467\n", - " 0.761648\n", - " 0.541505\n", - " 3.521332\n", - " 0.231023\n", - " 0.373354\n", - " 0.445973\n", - " 0.772925\n", - " 0.544553\n", - " 3.460882\n", - " 0.228960\n", - " \n", - " \n", - " 12\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.729 sec\n", - " 12.0\n", - " 0.377043\n", - " 0.452420\n", - " 0.762767\n", - " 0.541658\n", - " 3.521332\n", - " 0.229972\n", - " 0.372199\n", - " 0.443670\n", - " 0.773412\n", - " 0.543195\n", - " 3.460882\n", - " 0.224542\n", - " \n", - " \n", - " 13\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.762 sec\n", - " 13.0\n", - " 0.376137\n", - " 0.450517\n", - " 0.764795\n", - " 0.543264\n", - " 3.525899\n", - " 0.234317\n", - " 0.371369\n", - " 0.441932\n", - " 0.774161\n", - " 0.543632\n", - " 3.448038\n", - " 0.227413\n", - " \n", - " \n", - " 14\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.796 sec\n", - " 14.0\n", - " 0.375357\n", - " 0.448963\n", - " 0.765145\n", - " 0.543113\n", - " 3.525899\n", - " 0.235654\n", - " 0.370549\n", - " 0.440335\n", - " 0.774176\n", - " 0.543202\n", - " 3.448038\n", - " 0.228076\n", - " \n", - " \n", - " 15\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.848 sec\n", - " 15.0\n", - " 0.374699\n", - " 0.447543\n", - " 0.766118\n", - " 0.544037\n", - " 3.528417\n", - " 0.233219\n", - " 0.369999\n", - " 0.439161\n", - " 0.774592\n", - " 0.543709\n", - " 3.448038\n", - " 0.228297\n", - " \n", - " \n", - " 16\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.903 sec\n", - " 16.0\n", - " 0.374098\n", - " 0.446341\n", - " 0.766529\n", - " 0.543896\n", - " 3.560713\n", - " 0.229161\n", - " 0.369390\n", - " 0.437926\n", - " 0.775021\n", - " 0.544851\n", - " 3.424855\n", - " 0.226751\n", - " \n", - " \n", - " 17\n", - " \n", - " 2020-05-28 14:33:23\n", - " 43.949 sec\n", - " 17.0\n", - " 0.373534\n", - " 0.445115\n", - " 
0.766312\n", - " 0.544208\n", - " 3.568370\n", - " 0.231452\n", - " 0.368810\n", - " 0.436669\n", - " 0.774927\n", - " 0.545957\n", - " 3.442929\n", - " 0.225425\n", - " \n", - " \n", - " 18\n", - " \n", - " 2020-05-28 14:33:23\n", - " 44.004 sec\n", - " 18.0\n", - " 0.373121\n", - " 0.444171\n", - " 0.766785\n", - " 0.544720\n", - " 3.568370\n", - " 0.229352\n", - " 0.368496\n", - " 0.435909\n", - " 0.775256\n", - " 0.545586\n", - " 3.442929\n", - " 0.226530\n", - " \n", - " \n", - " 19\n", - " \n", - " 2020-05-28 14:33:23\n", - " 44.054 sec\n", - " 19.0\n", - " 0.372722\n", - " 0.443360\n", - " 0.767145\n", - " 0.545059\n", - " 3.568370\n", - " 0.226439\n", - " 0.368047\n", - " 0.435006\n", - " 0.775474\n", - " 0.545922\n", - " 3.442929\n", - " 0.224652\n", + " 58099\n", + " -1.329491\n", + " -1.732135\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 8.282199e+06\n", + " 574198.935115\n", + " 38664.214167\n", + " 38664.214167\n", + " 0.230445\n", + " \n", + " \n", + " 31451\n", + " -2.537423\n", + " -0.524202\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 3.313244e+06\n", + " 618367.423894\n", + " 38664.214167\n", + " 38664.214167\n", + " 0.230445\n", + " \n", + " \n", + " 31452\n", + " -2.537423\n", + " -0.524202\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 3.313244e+06\n", + " 662535.912674\n", + " 38664.214167\n", + " 38664.214167\n", + " 0.230445\n", " \n", " \n", "\n", "" ], "text/plain": [ - " timestamp duration number_of_trees training_rmse \\\n", - "0 2020-05-28 14:33:23 43.415 sec 0.0 0.415591 \n", - "1 2020-05-28 14:33:23 43.443 sec 1.0 0.407822 \n", - "2 2020-05-28 14:33:23 43.467 sec 2.0 0.401483 \n", - "3 2020-05-28 14:33:23 43.489 sec 3.0 0.396471 \n", - "4 2020-05-28 14:33:23 43.515 sec 4.0 0.392442 \n", - "5 2020-05-28 14:33:23 43.535 sec 5.0 0.389141 \n", - "6 2020-05-28 14:33:23 43.570 sec 6.0 0.386399 \n", - "7 2020-05-28 14:33:23 43.592 sec 7.0 0.384191 \n", - "8 
2020-05-28 14:33:23 43.614 sec 8.0 0.382341 \n", - "9 2020-05-28 14:33:23 43.639 sec 9.0 0.380701 \n", - "10 2020-05-28 14:33:23 43.668 sec 10.0 0.379202 \n", - "11 2020-05-28 14:33:23 43.697 sec 11.0 0.378052 \n", - "12 2020-05-28 14:33:23 43.729 sec 12.0 0.377043 \n", - "13 2020-05-28 14:33:23 43.762 sec 13.0 0.376137 \n", - "14 2020-05-28 14:33:23 43.796 sec 14.0 0.375357 \n", - "15 2020-05-28 14:33:23 43.848 sec 15.0 0.374699 \n", - "16 2020-05-28 14:33:23 43.903 sec 16.0 0.374098 \n", - "17 2020-05-28 14:33:23 43.949 sec 17.0 0.373534 \n", - "18 2020-05-28 14:33:23 44.004 sec 18.0 0.373121 \n", - "19 2020-05-28 14:33:23 44.054 sec 19.0 0.372722 \n", + " PAY_0 PAY_2 PAY_3 PAY_4 PAY_5 PAY_6 LIMIT_BAL \\\n", + "58099 -1.329491 -1.732135 1.11824 1.11824 1.11824 1.11824 8.282199e+06 \n", + "31451 -2.537423 -0.524202 1.11824 1.11824 1.11824 1.11824 3.313244e+06 \n", + "31452 -2.537423 -0.524202 1.11824 1.11824 1.11824 1.11824 3.313244e+06 \n", "\n", - " training_logloss training_auc training_pr_auc training_lift \\\n", - "0 0.529427 0.500000 0.000000 1.000000 \n", - "1 0.511864 0.716131 0.534717 3.474912 \n", - "2 0.498746 0.744646 0.532172 3.529706 \n", - "3 0.489013 0.748189 0.535621 3.529706 \n", - "4 0.481430 0.750121 0.535358 3.529706 \n", - "5 0.475375 0.750058 0.535198 3.529706 \n", - "6 0.470332 0.756986 0.535024 3.529706 \n", - "7 0.466316 0.757005 0.535418 3.529706 \n", - "8 0.462760 0.761106 0.540176 3.514359 \n", - "9 0.459589 0.762515 0.540880 3.518279 \n", - "10 0.456705 0.762522 0.541424 3.518279 \n", - "11 0.454467 0.761648 0.541505 3.521332 \n", - "12 0.452420 0.762767 0.541658 3.521332 \n", - "13 0.450517 0.764795 0.543264 3.525899 \n", - "14 0.448963 0.765145 0.543113 3.525899 \n", - "15 0.447543 0.766118 0.544037 3.528417 \n", - "16 0.446341 0.766529 0.543896 3.560713 \n", - "17 0.445115 0.766312 0.544208 3.568370 \n", - "18 0.444171 0.766785 0.544720 3.568370 \n", - "19 0.443360 0.767145 0.545059 3.568370 \n", - "\n", - " 
training_classification_error validation_rmse validation_logloss \\\n", - "0 0.778001 0.413815 0.526105 \n", - "1 0.236370 0.405538 0.507496 \n", - "2 0.228731 0.398808 0.493698 \n", - "3 0.228636 0.393394 0.483273 \n", - "4 0.210780 0.389030 0.475135 \n", - "5 0.245059 0.385453 0.468630 \n", - "6 0.243961 0.382447 0.463157 \n", - "7 0.243961 0.380045 0.458834 \n", - "8 0.247446 0.378063 0.455049 \n", - "9 0.235654 0.376184 0.451464 \n", - "10 0.235606 0.374583 0.448380 \n", - "11 0.231023 0.373354 0.445973 \n", - "12 0.229972 0.372199 0.443670 \n", - "13 0.234317 0.371369 0.441932 \n", - "14 0.235654 0.370549 0.440335 \n", - "15 0.233219 0.369999 0.439161 \n", - "16 0.229161 0.369390 0.437926 \n", - "17 0.231452 0.368810 0.436669 \n", - "18 0.229352 0.368496 0.435909 \n", - "19 0.226439 0.368047 0.435006 \n", - "\n", - " validation_auc validation_pr_auc validation_lift \\\n", - "0 0.500000 0.000000 1.000000 \n", - "1 0.726731 0.537125 3.444264 \n", - "2 0.752909 0.534588 3.422307 \n", - "3 0.756448 0.535692 3.422307 \n", - "4 0.758511 0.536095 3.422307 \n", - "5 0.758505 0.535659 3.422307 \n", - "6 0.764722 0.536039 3.422307 \n", - "7 0.764634 0.536411 3.422307 \n", - "8 0.770340 0.542043 3.457524 \n", - "9 0.772358 0.543522 3.457524 \n", - "10 0.772982 0.543893 3.457524 \n", - "11 0.772925 0.544553 3.460882 \n", - "12 0.773412 0.543195 3.460882 \n", - "13 0.774161 0.543632 3.448038 \n", - "14 0.774176 0.543202 3.448038 \n", - "15 0.774592 0.543709 3.448038 \n", - "16 0.775021 0.544851 3.424855 \n", - "17 0.774927 0.545957 3.442929 \n", - "18 0.775256 0.545586 3.442929 \n", - "19 0.775474 0.545922 3.442929 \n", - "\n", - " validation_classification_error \n", - "0 0.780649 \n", - "1 0.187652 \n", - "2 0.232825 \n", - "3 0.214491 \n", - "4 0.217915 \n", - "5 0.214270 \n", - "6 0.229843 \n", - "7 0.220013 \n", - "8 0.204330 \n", - "9 0.223548 \n", - "10 0.226309 \n", - "11 0.228960 \n", - "12 0.224542 \n", - "13 0.227413 \n", - "14 0.228076 \n", - "15 0.228297 \n", 
- "16 0.226751 \n", - "17 0.225425 \n", - "18 0.226530 \n", - "19 0.224652 " + " PAY_AMT1 PAY_AMT2 PAY_AMT4 p_DEFAULT_NEXT_MONTH \n", + "58099 574198.935115 38664.214167 38664.214167 0.230445 \n", + "31451 618367.423894 38664.214167 38664.214167 0.230445 \n", + "31452 662535.912674 38664.214167 38664.214167 0.230445 " ] }, + "execution_count": 18, "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "See the whole table with table.as_data_frame()\n", - "\n", - "Variable Importances: " - ] - }, + "output_type": "execute_result" + } + ], + "source": [ + "adversary_frame.sort_values(by='p_DEFAULT_NEXT_MONTH').head(n=3) # 3 lowest scores" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### View high scoring adversarial examples\n", + "The adversary now possesses rows of data that can generate almost any desired score from the black box model. Below are rows the adversary could use to generate high probabilities of default to potentially deny someone the credit product." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ { "data": { "text/html": [ @@ -2217,128 +1277,87 @@ " \n", " \n", " \n", - " variable\n", - " relative_importance\n", - " scaled_importance\n", - " percentage\n", + " PAY_0\n", + " PAY_2\n", + " PAY_3\n", + " PAY_4\n", + " PAY_5\n", + " PAY_6\n", + " LIMIT_BAL\n", + " PAY_AMT1\n", + " PAY_AMT2\n", + " PAY_AMT4\n", + " p_DEFAULT_NEXT_MONTH\n", " \n", " \n", " \n", " \n", - " 0\n", - " PAY_0\n", - " 2794.444824\n", - " 1.000000\n", - " 0.693347\n", - " \n", - " \n", - " 1\n", - " PAY_2\n", - " 307.237366\n", - " 0.109946\n", - " 0.076231\n", - " \n", - " \n", - " 2\n", - " PAY_3\n", - " 215.152893\n", - " 0.076993\n", - " 0.053383\n", - " \n", - " \n", - " 3\n", - " PAY_4\n", - " 155.434448\n", - " 0.055623\n", - " 0.038566\n", - " \n", - " \n", - " 4\n", - " PAY_AMT1\n", - " 127.986313\n", - " 0.045800\n", - " 0.031755\n", - " \n", - " \n", - " 5\n", - " PAY_5\n", - " 127.538628\n", - " 0.045640\n", - " 0.031644\n", - " \n", - " \n", - " 6\n", - " PAY_6\n", - " 102.351601\n", - " 0.036627\n", - " 0.025395\n", - " \n", - " \n", - " 7\n", - " LIMIT_BAL\n", - " 82.432350\n", - " 0.029499\n", - " 0.020453\n", - " \n", - " \n", - " 8\n", - " PAY_AMT2\n", - " 58.934135\n", - " 0.021090\n", - " 0.014623\n", - " \n", - " \n", - " 9\n", - " PAY_AMT4\n", - " 58.858047\n", - " 0.021063\n", - " 0.014604\n", + " 165375\n", + " 3.099595\n", + " 3.502240\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 607.262272\n", + " 8.580982\n", + " 38664.214167\n", + " 38664.214167\n", + " 0.832092\n", + " \n", + " \n", + " 192717\n", + " 4.307528\n", + " 3.099595\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 607.262272\n", + " 8.580982\n", + " 38664.214167\n", + " 38664.214167\n", + " 0.832092\n", + " \n", + " \n", + " 172872\n", + " 3.502240\n", + " 1.891663\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 1.11824\n", + " 607.262272\n", + " 
8.580982\n", + " 38664.214167\n", + " 38664.214167\n", + " 0.832092\n", " \n", " \n", "\n", "" ], "text/plain": [ - " variable relative_importance scaled_importance percentage\n", - "0 PAY_0 2794.444824 1.000000 0.693347\n", - "1 PAY_2 307.237366 0.109946 0.076231\n", - "2 PAY_3 215.152893 0.076993 0.053383\n", - "3 PAY_4 155.434448 0.055623 0.038566\n", - "4 PAY_AMT1 127.986313 0.045800 0.031755\n", - "5 PAY_5 127.538628 0.045640 0.031644\n", - "6 PAY_6 102.351601 0.036627 0.025395\n", - "7 LIMIT_BAL 82.432350 0.029499 0.020453\n", - "8 PAY_AMT2 58.934135 0.021090 0.014623\n", - "9 PAY_AMT4 58.858047 0.021063 0.014604" + " PAY_0 PAY_2 PAY_3 PAY_4 PAY_5 PAY_6 LIMIT_BAL \\\n", + "165375 3.099595 3.502240 1.11824 1.11824 1.11824 1.11824 607.262272 \n", + "192717 4.307528 3.099595 1.11824 1.11824 1.11824 1.11824 607.262272 \n", + "172872 3.502240 1.891663 1.11824 1.11824 1.11824 1.11824 607.262272 \n", + "\n", + " PAY_AMT1 PAY_AMT2 PAY_AMT4 p_DEFAULT_NEXT_MONTH \n", + "165375 8.580982 38664.214167 38664.214167 0.832092 \n", + "192717 8.580982 38664.214167 38664.214167 0.832092 \n", + "172872 8.580982 38664.214167 38664.214167 0.832092 " ] }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [] - }, - "execution_count": 7, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# load saved best model from lecture 1 \n", - "best_mgbm = h2o.load_model('best_mgbm')\n", - "\n", - "# display model details\n", - "best_mgbm" + "adversary_frame.sort_values(by='p_DEFAULT_NEXT_MONTH').tail(n=3) # 3 highest scores" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "metadata": {}, @@ -2348,14 +1367,15 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Are you sure you want to shutdown the H2O instance 
running at http://127.0.0.1:54321 (Y/N)? n\n" + "Are you sure you want to shutdown the H2O instance running at http://127.0.0.1:54321 (Y/N)? y\n", + "H2O session _sid_bf69 closed.\n" ] } ], @@ -2382,7 +1402,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.3" + "version": "3.6.9" } }, "nbformat": 4, diff --git a/rmltk/debug.py b/rmltk/debug.py index c1f194e..b78d491 100644 --- a/rmltk/debug.py +++ b/rmltk/debug.py @@ -1,4 +1,6 @@ import pandas as pd +import numpy as np +import string """ @@ -36,34 +38,42 @@ # represent metrics as dictionary for use later METRIC_DICT = { -#### overall performance -'Prevalence': '(tp + fn) / (tp + tn +fp + fn)', # how much default actually happens for this group -'Accuracy': '(tp + tn) / (tp + tn + fp + fn)', # how often the model predicts default and non-default correctly for this group - -#### predicting default will happen -# (correctly) -'True Positive Rate': 'tp / (tp + fn)', # out of the people in the group *that did* default, how many the model predicted *correctly* would default -'Precision': 'tp / (tp + fp)', # out of the people in the group the model *predicted* would default, how many the model predicted *correctly* would default - -#### predicting default won't happen -# (correctly) -'Specificity': 'tn / (tn + fp)', # out of the people in the group *that did not* default, how many the model predicted *correctly* would not default -'Negative Predicted Value': 'tn / (tn + fn)', # out of the people in the group the model *predicted* would not default, how many the model predicted *correctly* would not default - -#### analyzing errors - type I -# false accusations -'False Positive Rate': 'fp / (tn + fp)', # out of the people in the group *that did not* default, how many the model predicted *incorrectly* would default -'False Discovery Rate': 'fp / (tp + fp)', # out of the people in the group the model *predicted* would default, how many the model predicted *incorrectly* 
would default - -#### analyzing errors - type II -# costly ommisions -'False Negative Rate': 'fn / (tp + fn)', # out of the people in the group *that did* default, how many the model predicted *incorrectly* would not default -'False Omissions Rate':'fn / (tn + fn)' # out of the people in the group the model *predicted* would not default, how many the model predicted *incorrectly* would not default + #### overall performance + 'Prevalence': '(tp + fn) / (tp + tn +fp + fn)', # how much default actually happens for this group + 'Accuracy': '(tp + tn) / (tp + tn + fp + fn)', + # how often the model predicts default and non-default correctly for this group + + #### predicting default will happen + # (correctly) + 'True Positive Rate': 'tp / (tp + fn)', + # out of the people in the group *that did* default, how many the model predicted *correctly* would default + 'Precision': 'tp / (tp + fp)', + # out of the people in the group the model *predicted* would default, how many the model predicted *correctly* would default + + #### predicting default won't happen + # (correctly) + 'Specificity': 'tn / (tn + fp)', + # out of the people in the group *that did not* default, how many the model predicted *correctly* would not default + 'Negative Predicted Value': 'tn / (tn + fn)', + # out of the people in the group the model *predicted* would not default, how many the model predicted *correctly* would not default + + #### analyzing errors - type I + # false accusations + 'False Positive Rate': 'fp / (tn + fp)', + # out of the people in the group *that did not* default, how many the model predicted *incorrectly* would default + 'False Discovery Rate': 'fp / (tp + fp)', + # out of the people in the group the model *predicted* would default, how many the model predicted *incorrectly* would default + + #### analyzing errors - type II + # costly ommisions + 'False Negative Rate': 'fn / (tp + fn)', + # out of the people in the group *that did* default, how many the model predicted 
*incorrectly* would not default + 'False Omissions Rate': 'fn / (tn + fn)' + # out of the people in the group the model *predicted* would not default, how many the model predicted *incorrectly* would not default } def get_metrics_ratios(cm_dict, _control_level): - """ Calculates confusion matrix metrics in METRIC_DICT for each level of demographic feature. Tightly coupled to cm_dict. @@ -87,12 +97,11 @@ def get_metrics_ratios(cm_dict, _control_level): for level in levels: for metric in METRIC_DICT.keys(): - # parse metric expressions into executable Pandas statements expression = METRIC_DICT[metric].replace('tp', 'cm_dict[level].iat[0, 0]') \ - .replace('fp', 'cm_dict[level].iat[0, 1]') \ - .replace('fn', 'cm_dict[level].iat[1, 0]') \ - .replace('tn', 'cm_dict[level].iat[1, 1]') + .replace('fp', 'cm_dict[level].iat[0, 1]') \ + .replace('fn', 'cm_dict[level].iat[1, 0]') \ + .replace('tn', 'cm_dict[level].iat[1, 1]') # dynamically evaluate metrics to avoid code duplication metrics_frame.loc[level, metric] = eval(expression) @@ -105,7 +114,6 @@ def get_metrics_ratios(cm_dict, _control_level): def air(cm_dict, reference, protected): - """ Calculates the adverse impact ratio as a quotient between protected and reference group acceptance rates: protected_prop/reference_prop. Prints intermediate values. Tightly coupled to cm_dict. @@ -130,11 +138,10 @@ def air(cm_dict, reference, protected): print(protected.title() + ' proportion accepted: %.3f' % protected_prop) # return adverse impact ratio - return protected_prop/reference_prop + return protected_prop / reference_prop def marginal_effect(cm_dict, reference, protected): - """ Calculates the marginal effect as a percentage difference between a reference and a protected group: reference_percent - protected_percent. Prints intermediate values. Tightly coupled to cm_dict. 
@@ -164,7 +171,6 @@ def marginal_effect(cm_dict, reference, protected): def smd(valid, x_name, yhat_name, reference, protected): - """ Calculates standardized mean difference between a protected and reference group: (mean(yhat | x_j=protected) - mean(yhat | x_j=reference))/sigma(yhat). Prints intermediate values. @@ -192,5 +198,4 @@ def smd(valid, x_name, yhat_name, reference, protected): sigma = valid[yhat_name].std() print(yhat_name.title() + ' std. dev.: %.2f' % sigma) - return (protected_yhat_mean - reference_yhat_mean) / sigma - + return (protected_yhat_mean - reference_yhat_mean) / sigma \ No newline at end of file diff --git a/rmltk/explain.py b/rmltk/explain.py index 975c1a9..7837af6 100644 --- a/rmltk/explain.py +++ b/rmltk/explain.py @@ -309,7 +309,7 @@ def get_png(model_id): _ = subprocess.call(png_args) -def get_cv_dt(x_names, y_names, frame, model_id, seed_, title): +def get_cv_dt(x_names, y_names, train, model_id, seed_, title, valid=None): """ Utility function to train decision trees. @@ -334,7 +334,10 @@ def get_cv_dt(x_names, y_names, frame, model_id, seed_, title): model_id=model_id) # gives MOJO artifact a recognizable name # train single tree model - tree.train(x=x_names, y=y_names, training_frame=h2o.H2OFrame(frame)) + if valid is not None: + tree.train(x=x_names, y=y_names, training_frame=h2o.H2OFrame(train), validation_frame=h2o.H2OFrame(valid)) + else: + tree.train(x=x_names, y=y_names, training_frame=h2o.H2OFrame(train)) # persist MOJO (compiled Java representation of trained model) # from which to generate plot of tree