diff --git a/notebooks/training.ipynb b/notebooks/training.ipynb index 957b283..d6b270b 100644 --- a/notebooks/training.ipynb +++ b/notebooks/training.ipynb @@ -1 +1 @@ -{"cells":[{"source":"\"Kaggle\"","metadata":{},"cell_type":"markdown"},{"cell_type":"markdown","id":"a01cdf54","metadata":{"id":"0goBcwsXEl7q","papermill":{"duration":0.00558,"end_time":"2024-12-07T19:22:48.016797","exception":false,"start_time":"2024-12-07T19:22:48.011217","status":"completed"},"tags":[]},"source":["# Training"]},{"cell_type":"markdown","id":"5c57ad82","metadata":{"id":"_AciCyGkEpkC","papermill":{"duration":0.005853,"end_time":"2024-12-07T19:22:48.027518","exception":false,"start_time":"2024-12-07T19:22:48.021665","status":"completed"},"tags":[]},"source":["## Setup"]},{"cell_type":"code","execution_count":1,"id":"41b18fb0","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:22:48.038604Z","iopub.status.busy":"2024-12-07T19:22:48.037721Z","iopub.status.idle":"2024-12-07T19:22:59.264516Z","shell.execute_reply":"2024-12-07T19:22:59.26332Z"},"id":"KVoOmsrBExwK","outputId":"266204f1-044a-49f4-cf32-5ce95a872f9b","papermill":{"duration":11.23476,"end_time":"2024-12-07T19:22:59.266901","exception":false,"start_time":"2024-12-07T19:22:48.032141","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["Requirement already satisfied: keras in /opt/conda/lib/python3.10/site-packages (3.3.3)\r\n","Collecting keras\r\n"," Downloading keras-3.7.0-py3-none-any.whl.metadata (5.8 kB)\r\n","Collecting kimm\r\n"," Downloading kimm-0.2.5-py3-none-any.whl.metadata (12 kB)\r\n","Requirement already satisfied: absl-py in /opt/conda/lib/python3.10/site-packages (from keras) (1.4.0)\r\n","Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from keras) (1.26.4)\r\n","Requirement already satisfied: rich in /opt/conda/lib/python3.10/site-packages (from keras) (13.7.1)\r\n","Requirement already satisfied: namex in 
/opt/conda/lib/python3.10/site-packages (from keras) (0.0.8)\r\n","Requirement already satisfied: h5py in /opt/conda/lib/python3.10/site-packages (from keras) (3.11.0)\r\n","Requirement already satisfied: optree in /opt/conda/lib/python3.10/site-packages (from keras) (0.11.0)\r\n","Requirement already satisfied: ml-dtypes in /opt/conda/lib/python3.10/site-packages (from keras) (0.3.2)\r\n","Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from keras) (21.3)\r\n","Requirement already satisfied: typing-extensions>=4.0.0 in /opt/conda/lib/python3.10/site-packages (from optree->keras) (4.12.2)\r\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging->keras) (3.1.2)\r\n","Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/conda/lib/python3.10/site-packages (from rich->keras) (3.0.0)\r\n","Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/conda/lib/python3.10/site-packages (from rich->keras) (2.18.0)\r\n","Requirement already satisfied: mdurl~=0.1 in /opt/conda/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich->keras) (0.1.2)\r\n","Downloading keras-3.7.0-py3-none-any.whl (1.2 MB)\r\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n","\u001b[?25hDownloading kimm-0.2.5-py3-none-any.whl (123 kB)\r\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m123.4/123.4 kB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n","\u001b[?25hInstalling collected packages: keras, kimm\r\n"," Attempting uninstall: keras\r\n"," Found existing installation: keras 3.3.3\r\n"," Uninstalling keras-3.3.3:\r\n"," Successfully uninstalled keras-3.3.3\r\n","Successfully installed keras-3.7.0 kimm-0.2.5\r\n"]}],"source":["!pip install keras kimm 
-U"]},{"cell_type":"code","execution_count":2,"id":"98bdcda3","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:22:59.281705Z","iopub.status.busy":"2024-12-07T19:22:59.281388Z","iopub.status.idle":"2024-12-07T19:22:59.287978Z","shell.execute_reply":"2024-12-07T19:22:59.28706Z"},"papermill":{"duration":0.016167,"end_time":"2024-12-07T19:22:59.289944","exception":false,"start_time":"2024-12-07T19:22:59.273777","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["/\n"]}],"source":["%cd ../../"]},{"cell_type":"code","execution_count":3,"id":"6f602cf4","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:22:59.303987Z","iopub.status.busy":"2024-12-07T19:22:59.303669Z","iopub.status.idle":"2024-12-07T19:23:02.317243Z","shell.execute_reply":"2024-12-07T19:23:02.315962Z"},"id":"4bedI2X2QDdf","papermill":{"duration":3.022795,"end_time":"2024-12-07T19:23:02.319276","exception":false,"start_time":"2024-12-07T19:22:59.296481","status":"completed"},"tags":[]},"outputs":[],"source":["# set seed\n","seed = 42\n","\n","# set n train and test images\n","n_train_images = 4000\n","n_test_images = 300\n","\n","# set batch size\n","batch_size = 16\n","\n","# set num classes\n","num_classes = 5\n","\n","# set class names\n","class_names = ['tiger', 'lynx', 'bear', 'deer', 'bird']\n","\n","# define paths to train and test images\n","images_input_dir = 'kaggle/input/preprocess-images/images'\n","images_sampled_dir = 'images'\n","!mkdir -p \"$images_sampled_dir\"\n","train_dir = images_input_dir + '/train'\n","train_dir_sampled = images_sampled_dir + '/train_sampled'\n","test_dir = images_input_dir + '/test'\n","test_dir_sampled = images_sampled_dir + '/test_sampled'\n","test2_dir = images_input_dir + '/test2'\n","\n","# define path to model dir\n","model_dir = 'kaggle/working/model'\n","!mkdir -p \"$model_dir\"\n","model_path = model_dir + '/model.keras'\n","model_constructor = 'EfficientNetV2S'\n","\n","# define path to media 
dir\n","media_dir = 'kaggle/working/media'\n","!mkdir -p \"$media_dir\"\n","\n","# decide whether to log run or not\n","logging = True"]},{"cell_type":"code","execution_count":4,"id":"35997ae0","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:23:02.332315Z","iopub.status.busy":"2024-12-07T19:23:02.332011Z","iopub.status.idle":"2024-12-07T19:23:17.216174Z","shell.execute_reply":"2024-12-07T19:23:17.215511Z"},"id":"WNSWKPrzEzy_","papermill":{"duration":14.892974,"end_time":"2024-12-07T19:23:17.21819","exception":false,"start_time":"2024-12-07T19:23:02.325216","status":"completed"},"tags":[]},"outputs":[],"source":["import os\n","import time\n","import shutil\n","import random\n","import numpy as np\n","\n","import keras\n","from keras import layers, optimizers, losses, callbacks, saving\n","import kimm\n","import tensorflow as tf\n","import tensorflow_datasets as tfds\n","from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, classification_report\n","\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","\n","import wandb\n","from kaggle_secrets import UserSecretsClient"]},{"cell_type":"code","execution_count":5,"id":"da7142a4","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:23:17.231242Z","iopub.status.busy":"2024-12-07T19:23:17.230617Z","iopub.status.idle":"2024-12-07T19:23:17.23753Z","shell.execute_reply":"2024-12-07T19:23:17.236892Z"},"id":"RjGj2LDNR15z","papermill":{"duration":0.014797,"end_time":"2024-12-07T19:23:17.238976","exception":false,"start_time":"2024-12-07T19:23:17.224179","status":"completed"},"tags":[]},"outputs":[],"source":["def sample_images(source_dir, target_dir, samples_per_class, seed=42):\n"," \"\"\"\n"," Samples a fixed number of images per class from a directory structure.\n","\n"," Args:\n"," source_dir (str): Path to the source dataset directory.\n"," target_dir (str): Path to the target dataset directory to store sampled data.\n"," samples_per_class (int): Number of 
images to sample per class.\n"," seed (int): Random seed for reproducibility.\n"," \"\"\"\n"," random.seed(seed)\n","\n"," if not os.path.exists(target_dir):\n"," os.makedirs(target_dir)\n","\n"," for class_name in os.listdir(source_dir):\n"," class_path = os.path.join(source_dir, class_name)\n"," if os.path.isdir(class_path):\n"," sampled_class_dir = os.path.join(target_dir, class_name)\n"," os.makedirs(sampled_class_dir, exist_ok=True)\n","\n"," # list and shuffle all files in class directory\n"," all_images = os.listdir(class_path)\n"," random.shuffle(all_images)\n","\n"," # select desired number of samples\n"," sampled_images = all_images[:samples_per_class]\n","\n"," # copy sampled images to new directory\n"," for image_name in sampled_images:\n"," source_image_path = os.path.join(class_path, image_name)\n"," target_image_path = os.path.join(sampled_class_dir, image_name)\n"," shutil.copy(source_image_path, target_image_path)"]},{"cell_type":"code","execution_count":6,"id":"17442848","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:23:17.250584Z","iopub.status.busy":"2024-12-07T19:23:17.250334Z","iopub.status.idle":"2024-12-07T19:23:20.099866Z","shell.execute_reply":"2024-12-07T19:23:20.098637Z"},"papermill":{"duration":2.857779,"end_time":"2024-12-07T19:23:20.102151","exception":false,"start_time":"2024-12-07T19:23:17.244372","status":"completed"},"tags":[]},"outputs":[{"name":"stderr","output_type":"stream","text":["/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. 
os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n"," pid, fd = os.forkpty()\n"]},{"name":"stdout","output_type":"stream","text":["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\r\n","\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\r\n","\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\r\n"]}],"source":["# log in to w&b using api key\n","if logging:\n"," user_secrets = UserSecretsClient()\n"," key = user_secrets.get_secret('wandb')\n"," !wandb login $key"]},{"cell_type":"markdown","id":"e3f5cf4a","metadata":{"id":"o-oFY9SuE8dT","papermill":{"duration":0.00566,"end_time":"2024-12-07T19:23:20.114694","exception":false,"start_time":"2024-12-07T19:23:20.109034","status":"completed"},"tags":[]},"source":["## Prepare train and test datasets"]},{"cell_type":"code","execution_count":7,"id":"0f5c4ce9","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:23:20.127095Z","iopub.status.busy":"2024-12-07T19:23:20.126765Z","iopub.status.idle":"2024-12-07T19:26:23.35635Z","shell.execute_reply":"2024-12-07T19:26:23.355619Z"},"id":"KewO2flaRgLu","papermill":{"duration":183.238269,"end_time":"2024-12-07T19:26:23.358449","exception":false,"start_time":"2024-12-07T19:23:20.12018","status":"completed"},"tags":[]},"outputs":[],"source":["# create new directory with sampled train images\n","sample_images(train_dir, train_dir_sampled, 
n_train_images)"]},{"cell_type":"code","execution_count":8,"id":"31e44e0a","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:26:23.371441Z","iopub.status.busy":"2024-12-07T19:26:23.371181Z","iopub.status.idle":"2024-12-07T19:26:36.621111Z","shell.execute_reply":"2024-12-07T19:26:36.620372Z"},"id":"feQggHlSWEh4","papermill":{"duration":13.258546,"end_time":"2024-12-07T19:26:36.623185","exception":false,"start_time":"2024-12-07T19:26:23.364639","status":"completed"},"tags":[]},"outputs":[],"source":["# create new directory with sampled test images\n","sample_images(test_dir, test_dir_sampled, n_test_images)"]},{"cell_type":"code","execution_count":9,"id":"cf57d2a4","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:26:36.636141Z","iopub.status.busy":"2024-12-07T19:26:36.635832Z","iopub.status.idle":"2024-12-07T19:26:41.334438Z","shell.execute_reply":"2024-12-07T19:26:41.333522Z"},"id":"J5i-u-VSE74K","outputId":"cda0d93f-4445-41f3-bb1b-5b1942d838b6","papermill":{"duration":4.707104,"end_time":"2024-12-07T19:26:41.336198","exception":false,"start_time":"2024-12-07T19:26:36.629094","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["Found 20000 files belonging to 5 classes.\n","Found 1500 files belonging to 5 classes.\n","Found 303 files belonging to 1 classes.\n","Number of train samples: -2\n","Number of test samples: -2\n"]}],"source":["# create train dataset\n","train_ds = tf.keras.preprocessing.image_dataset_from_directory(\n"," train_dir_sampled,\n"," label_mode='categorical',\n"," shuffle=True,\n"," seed=seed,\n",")\n","\n","# create test dataset\n","test_ds = tf.keras.preprocessing.image_dataset_from_directory(\n"," test_dir_sampled,\n"," label_mode='categorical',\n"," shuffle=False,\n",")\n","\n","# create test2 dataset\n","test2_ds = tf.keras.preprocessing.image_dataset_from_directory(\n"," test2_dir,\n"," label_mode='categorical',\n"," shuffle=False,\n",")\n","\n","# we need to unbatch because 
there's somehow an unwanted additional dimension\n","train_ds = train_ds.unbatch()\n","test_ds = test_ds.unbatch()\n","test2_ds = test2_ds.unbatch()\n","\n","print(f'Number of train samples: {train_ds.cardinality()}')\n","print(f'Number of test samples: {test_ds.cardinality()}')"]},{"cell_type":"code","execution_count":10,"id":"786f5351","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:26:41.349284Z","iopub.status.busy":"2024-12-07T19:26:41.349005Z","iopub.status.idle":"2024-12-07T19:26:41.35361Z","shell.execute_reply":"2024-12-07T19:26:41.352723Z"},"id":"myqLQZuqFZkx","outputId":"1636afdc-4dcd-4ad9-9041-bedfc196f415","papermill":{"duration":0.012969,"end_time":"2024-12-07T19:26:41.355256","exception":false,"start_time":"2024-12-07T19:26:41.342287","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["(TensorSpec(shape=(256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(5,), dtype=tf.float32, name=None)) (TensorSpec(shape=(256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(5,), dtype=tf.float32, name=None))\n"]}],"source":["# check dimensions\n","print(train_ds.element_spec, test_ds.element_spec)"]},{"cell_type":"code","execution_count":11,"id":"bf67849a","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:26:41.36775Z","iopub.status.busy":"2024-12-07T19:26:41.36751Z","iopub.status.idle":"2024-12-07T19:26:41.450258Z","shell.execute_reply":"2024-12-07T19:26:41.449603Z"},"id":"YCc38anvFgJc","papermill":{"duration":0.09093,"end_time":"2024-12-07T19:26:41.452007","exception":false,"start_time":"2024-12-07T19:26:41.361077","status":"completed"},"tags":[]},"outputs":[],"source":["# setup dataset with tf.data\n","resize_fn = keras.layers.Resizing(224, 224)\n","\n","train_ds = train_ds.map(lambda x, y: (resize_fn(x), y))\n","test_ds = test_ds.map(lambda x, y: (resize_fn(x), y))\n","test2_ds = test2_ds.map(lambda x, y: (resize_fn(x), y))\n","\n","train_ds = 
train_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE).cache()\n","test_ds = test_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE).cache()\n","test2_ds = test2_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE).cache()"]},{"cell_type":"markdown","id":"06cec097","metadata":{"id":"CdVE1QLMFtA5","papermill":{"duration":0.005797,"end_time":"2024-12-07T19:26:41.463946","exception":false,"start_time":"2024-12-07T19:26:41.458149","status":"completed"},"tags":[]},"source":["## Prepare model"]},{"cell_type":"code","execution_count":12,"id":"1486af37","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:26:41.476701Z","iopub.status.busy":"2024-12-07T19:26:41.476453Z","iopub.status.idle":"2024-12-07T19:26:59.579324Z","shell.execute_reply":"2024-12-07T19:26:59.578374Z"},"id":"3mmwlJBDFf_6","outputId":"73087a5a-85e9-43f0-bdee-b067a8168bcb","papermill":{"duration":18.111613,"end_time":"2024-12-07T19:26:59.581211","exception":false,"start_time":"2024-12-07T19:26:41.469598","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["Downloading data from https://github.com/james77777778/keras-image-models/releases/download/0.1.0/efficientnetv2s_tf_efficientnetv2_s.in21k_ft_in1k.keras\n","\u001b[1m87666342/87666342\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 0us/step\n"]},{"data":{"text/html":["
Model: \"functional\"\n","
\n"],"text/plain":["\u001b[1mModel: \"functional\"\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓\n","┃ Layer (type)                 Output Shape              Param #  Trai… ┃\n","┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩\n","│ input_layer_1 (InputLayer)  │ (None, 224, 224, 3)   │          0-   │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ EfficientNetV2S             │ (None, 7, 7, 1280)    │ 20,331,360N   │\n","│ (EfficientNetV2S)           │                       │            │       │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ global_average_pooling2d    │ (None, 1280)          │          0-   │\n","│ (GlobalAveragePooling2D)    │                       │            │       │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ dropout (Dropout)           │ (None, 1280)          │          0-   │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ dense (Dense)               │ (None, 5)             │      6,405Y   │\n","└─────────────────────────────┴───────────────────────┴────────────┴───────┘\n","
\n"],"text/plain":["┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓\n","┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mTrai…\u001b[0m\u001b[1m \u001b[0m┃\n","┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩\n","│ input_layer_1 (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ \u001b[1m-\u001b[0m │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ EfficientNetV2S │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m1280\u001b[0m) │ \u001b[38;5;34m20,331,360\u001b[0m │ \u001b[1;91mN\u001b[0m │\n","│ (\u001b[38;5;33mEfficientNetV2S\u001b[0m) │ │ │ │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ global_average_pooling2d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1280\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ \u001b[1m-\u001b[0m │\n","│ (\u001b[38;5;33mGlobalAveragePooling2D\u001b[0m) │ │ │ │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1280\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ \u001b[1m-\u001b[0m │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m5\u001b[0m) │ \u001b[38;5;34m6,405\u001b[0m │ \u001b[1;38;5;34mY\u001b[0m │\n","└─────────────────────────────┴───────────────────────┴────────────┴───────┘\n"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
 Total params: 20,337,765 (77.58 MB)\n","
\n"],"text/plain":["\u001b[1m Total params: \u001b[0m\u001b[38;5;34m20,337,765\u001b[0m (77.58 MB)\n"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
 Trainable params: 6,405 (25.02 KB)\n","
\n"],"text/plain":["\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m6,405\u001b[0m (25.02 KB)\n"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
 Non-trainable params: 20,331,360 (77.56 MB)\n","
\n"],"text/plain":["\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20,331,360\u001b[0m (77.56 MB)\n"]},"metadata":{},"output_type":"display_data"}],"source":["# create base model\n","if model_constructor == 'EfficientNetV2B0':\n"," base_model = kimm.models.EfficientNetV2B0(\n"," input_shape=(224, 224, 3),\n"," include_preprocessing=True,\n"," include_top=False,\n"," )\n"," model_params = '5 mio'\n"," frozen_file_size = '26 mb'\n","elif model_constructor == 'EfficientNetV2B2':\n"," base_model = kimm.models.EfficientNetV2B2(\n"," input_shape=(224, 224, 3),\n"," include_preprocessing=True,\n"," include_top=False,\n"," )\n"," model_params = '9 mio'\n"," frozen_file_size = '37 mb'\n","elif model_constructor == 'EfficientNetV2S':\n"," base_model = kimm.models.EfficientNetV2S(\n"," input_shape=(224, 224, 3),\n"," include_preprocessing=True,\n"," include_top=False,\n"," )\n"," model_params = '21 mio'\n"," frozen_file_size = '84 mb'\n","elif model_constructor == 'EfficientNetV2M':\n"," base_model = kimm.models.EfficientNetV2M(\n"," input_shape=(224, 224, 3),\n"," include_preprocessing=True,\n"," include_top=False,\n"," )\n"," model_params = '54 mio'\n"," frozen_file_size = '216 mb'\n","elif model_constructor == 'EfficientNetV2L':\n"," base_model = kimm.models.EfficientNetV2L(\n"," input_shape=(224, 224, 3),\n"," include_preprocessing=True,\n"," include_top=False,\n"," )\n"," model_params = '119 mio'\n"," frozen_file_size = '475 mb'\n","elif model_constructor == 'EfficientNetV2XL':\n"," base_model = kimm.models.EfficientNetV2XL(\n"," input_shape=(224, 224, 3),\n"," include_preprocessing=True,\n"," include_top=False,\n"," )\n"," model_params = '208 mio'\n"," frozen_file_size = '835 mb'\n","else:\n"," raise Exception('Please select a valid model constructor.') \n","\n","# freeze base model\n","base_model.trainable = False\n","\n","# create new model on top\n","inputs = keras.Input(shape=(224, 224, 3))\n","x = inputs\n","\n","# The base model contains batchnorm layers. 
We want to keep them in inference mode\n","# when we unfreeze the base model for fine-tuning, so we make sure that the\n","# base_model is running in inference mode here.\n","x = base_model(x, training=False)\n","x = keras.layers.GlobalAveragePooling2D()(x)\n","x = keras.layers.Dropout(0.2)(x) # regularize with dropout\n","outputs = keras.layers.Dense(num_classes, activation='softmax')(x)\n","model = keras.Model(inputs, outputs)\n","\n","model.summary(show_trainable=True)"]},{"cell_type":"markdown","id":"9c5a62ab","metadata":{"id":"YcYVdF_1F9EQ","papermill":{"duration":0.007995,"end_time":"2024-12-07T19:26:59.597746","exception":false,"start_time":"2024-12-07T19:26:59.589751","status":"completed"},"tags":[]},"source":["## Training\n","\n","Follow [mewc-train](https://github.com/zaandahl/mewc-train)"]},{"cell_type":"code","execution_count":13,"id":"fd55cdff","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:26:59.615155Z","iopub.status.busy":"2024-12-07T19:26:59.614794Z","iopub.status.idle":"2024-12-07T19:26:59.623898Z","shell.execute_reply":"2024-12-07T19:26:59.623184Z"},"id":"Ho-88Fe_GxVd","papermill":{"duration":0.019787,"end_time":"2024-12-07T19:26:59.625562","exception":false,"start_time":"2024-12-07T19:26:59.605775","status":"completed"},"tags":[]},"outputs":[],"source":["df_size = int(n_train_images * num_classes)\n","\n","epochs = 30\n","\n","lr_init = 1e-4\n","min_lr_frac = 1/5 # default minimum learning rate fraction of initial learning rate\n","steps_per_epoch = df_size // batch_size\n","total_steps = epochs * steps_per_epoch # total number of steps for monotonic exponential decay across all epochs\n","lr = optimizers.schedules.ExponentialDecay(initial_learning_rate=lr_init, decay_steps=total_steps, decay_rate=min_lr_frac, staircase=False)\n","amsgrad = True\n","weight_decay = 1e-4\n","optimizer = optimizers.AdamW(learning_rate=lr, amsgrad=amsgrad, weight_decay=weight_decay)\n","\n","if num_classes == 2:\n"," loss_f = 
losses.BinaryFocalCrossentropy() # use for binary classification tasks\n"," act_f = 'sigmoid' # use for binary classification tasks\n","else:\n"," loss_f = losses.CategoricalFocalCrossentropy() # use for unbalanced multi-class tasks (typical for wildlife datasets)\n"," act_f = 'softmax' # use for multi-class classification tasks\n","\n","metrics = ['accuracy']\n","\n","callbacks = [callbacks.EarlyStopping(monitor='loss', mode='min', min_delta=0.001, patience=5, restore_best_weights=True)]"]},{"cell_type":"code","execution_count":14,"id":"cc249647","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:26:59.642655Z","iopub.status.busy":"2024-12-07T19:26:59.642361Z","iopub.status.idle":"2024-12-07T19:26:59.652558Z","shell.execute_reply":"2024-12-07T19:26:59.651897Z"},"id":"fAHCE7OrGCGF","papermill":{"duration":0.020834,"end_time":"2024-12-07T19:26:59.654438","exception":false,"start_time":"2024-12-07T19:26:59.633604","status":"completed"},"tags":[]},"outputs":[],"source":["model.compile(\n"," optimizer=optimizer,\n"," loss=loss_f,\n"," metrics=metrics,\n",")"]},{"cell_type":"code","execution_count":15,"id":"c70ccd37","metadata":{"execution":{"iopub.execute_input":"2024-12-07T19:26:59.671834Z","iopub.status.busy":"2024-12-07T19:26:59.671542Z","iopub.status.idle":"2024-12-07T19:44:01.366448Z","shell.execute_reply":"2024-12-07T19:44:01.365359Z"},"id":"GZKd2EixF8yO","outputId":"e8424e62-1116-4b52-90af-5fe25403927f","papermill":{"duration":1021.705664,"end_time":"2024-12-07T19:44:01.36823","exception":false,"start_time":"2024-12-07T19:26:59.662566","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["Epoch 1/30\n"]},{"name":"stderr","output_type":"stream","text":["WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n","I0000 00:00:1733599634.252028 103 service.cc:145] XLA service 0x7a2f68004d40 initialized for platform CUDA (this does not guarantee that XLA will be used). 
Devices:\n","I0000 00:00:1733599634.252100 103 service.cc:153] StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0\n"]},{"name":"stdout","output_type":"stream","text":[" 3/Unknown \u001b[1m24s\u001b[0m 42ms/step - accuracy: 0.2188 - loss: 0.3537"]},{"name":"stderr","output_type":"stream","text":["I0000 00:00:1733599644.024647 103 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n"]},{"name":"stdout","output_type":"stream","text":[" 1250/Unknown \u001b[1m82s\u001b[0m 46ms/step - accuracy: 0.6078 - loss: 0.1701"]},{"name":"stderr","output_type":"stream","text":["/opt/conda/lib/python3.10/site-packages/keras/src/trainers/epoch_iterator.py:151: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.\n"," self._interrupted_warning()\n"]},{"name":"stdout","output_type":"stream","text":["\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m83s\u001b[0m 47ms/step - accuracy: 0.6079 - loss: 0.1700\n","Epoch 2/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.8546 - loss: 0.0580\n","Epoch 3/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.8789 - loss: 0.0448\n","Epoch 4/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.8899 - loss: 0.0396\n","Epoch 5/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9010 - loss: 0.0362\n","Epoch 6/30\n","\u001b[1m1250/1250\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9080 - loss: 0.0329\n","Epoch 7/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9106 - loss: 0.0308\n","Epoch 8/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9173 - loss: 0.0293\n","Epoch 9/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9180 - loss: 0.0278\n","Epoch 10/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 32ms/step - accuracy: 0.9212 - loss: 0.0272\n","Epoch 11/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9236 - loss: 0.0261\n","Epoch 12/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9213 - loss: 0.0267\n","Epoch 13/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9290 - loss: 0.0244\n","Epoch 14/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9240 - loss: 0.0249\n","Epoch 15/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9334 - loss: 0.0232\n","Epoch 16/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - accuracy: 0.9302 - loss: 0.0233\n","Epoch 17/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 33ms/step - 
# Train the model and measure how long the fit takes (wall-clock minutes).
start_time = time.time()
model.fit(train_ds, epochs=epochs, callbacks=callbacks)
end_time = time.time()
elapsed_mins = (end_time - start_time) / 60
training_time_mins = round(elapsed_mins, 2)
print(f'Training time mins: {training_time_mins}')

# Persist the trained model to the path chosen in the config cell.
saving.save_model(model, model_path)

# evaluate() returns [loss, accuracy]; index 1 is the accuracy metric that the
# progress bar reports next to the loss.
test_accuracy = model.evaluate(test_ds)
print(f'Test accuracy: {test_accuracy[1] * 100:3.2f}%')

# Same evaluation on the second test set.
test2_accuracy = model.evaluate(test2_ds)
print(f'Test2 accuracy: {test2_accuracy[1] * 100:3.2f}%')
71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 
66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 64ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 63ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 63ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 64ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 
65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 77ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 64ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 63ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 64ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 
63ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 
66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 65ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 
# Collect ground-truth and predicted class indices for every batch of the test
# set, then derive macro-averaged metrics, the confusion matrix and the
# per-class report from them.
true_labels = []
predicted_labels = []

for images, labels in test_ds:
    # Labels with more than one dimension are one-hot rows; reduce them to
    # integer class indices. One-dimensional labels are already indices.
    if len(labels.shape) > 1:
        true_labels.append(np.argmax(labels.numpy(), axis=1))
    else:
        true_labels.append(labels.numpy())

    # verbose=0: predict() is called once per batch here, and with the default
    # verbosity every call prints its own progress bar, flooding the cell
    # output with one line per batch.
    predictions = model.predict(images, verbose=0)
    predicted_labels.append(np.argmax(predictions, axis=1))

# Combine all batches into single flat arrays.
true_labels = np.concatenate(true_labels)
predicted_labels = np.concatenate(predicted_labels)

# Macro average: every class contributes equally, regardless of its support.
precision = precision_score(true_labels, predicted_labels, average='macro')
recall = recall_score(true_labels, predicted_labels, average='macro')
f1 = f1_score(true_labels, predicted_labels, average='macro')
print(f'Precision: {precision * 100:3.2f}%; Recall: {recall * 100:3.2f}%; F1: {f1 * 100:3.2f}%')

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(true_labels, predicted_labels)
print('Confusion matrix:\n', cm)

# NOTE(review): target_names assumes class index i corresponds to
# class_names[i]. The indices come from how the dataset was built from the
# image directories -- confirm that ordering matches class_names.
report = classification_report(true_labels, predicted_labels, target_names=class_names)
print('Classification report:\n', report)
# Draw the confusion matrix as an annotated heatmap, save it under the media
# directory and display it inline.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',  # cells hold integer counts
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion matrix')

confusion_matrix_path = media_dir + '/confusion_matrix.png'
fig.savefig(confusion_matrix_path, dpi=300, bbox_inches='tight')
plt.show()
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n","\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.18.7\n","\u001b[34m\u001b[1mwandb\u001b[0m: Run data is saved locally in \u001b[35m\u001b[1m/wandb/run-20241207_194511-18cb8f9u\u001b[0m\n","\u001b[34m\u001b[1mwandb\u001b[0m: Run \u001b[1m`wandb offline`\u001b[0m to turn off syncing.\n","\u001b[34m\u001b[1mwandb\u001b[0m: Syncing run \u001b[33mlively-pond-12\u001b[0m\n","\u001b[34m\u001b[1mwandb\u001b[0m: ⭐️ View project at \u001b[34m\u001b[4mhttps://wandb.ai/alexvmt/tiger_classification\u001b[0m\n","\u001b[34m\u001b[1mwandb\u001b[0m: 🚀 View run at \u001b[34m\u001b[4mhttps://wandb.ai/alexvmt/tiger_classification/runs/18cb8f9u\u001b[0m\n","\u001b[34m\u001b[1mwandb\u001b[0m: \n","\u001b[34m\u001b[1mwandb\u001b[0m: \n","\u001b[34m\u001b[1mwandb\u001b[0m: Run history:\n","\u001b[34m\u001b[1mwandb\u001b[0m: f1 ▁\n","\u001b[34m\u001b[1mwandb\u001b[0m: precision ▁\n","\u001b[34m\u001b[1mwandb\u001b[0m: recall ▁\n","\u001b[34m\u001b[1mwandb\u001b[0m: test2_accuracy ▁\n","\u001b[34m\u001b[1mwandb\u001b[0m: test_accuracy ▁\n","\u001b[34m\u001b[1mwandb\u001b[0m: training_time_mins ▁\n","\u001b[34m\u001b[1mwandb\u001b[0m: \n","\u001b[34m\u001b[1mwandb\u001b[0m: Run summary:\n","\u001b[34m\u001b[1mwandb\u001b[0m: f1 0.9321\n","\u001b[34m\u001b[1mwandb\u001b[0m: precision 0.9322\n","\u001b[34m\u001b[1mwandb\u001b[0m: recall 0.932\n","\u001b[34m\u001b[1mwandb\u001b[0m: test2_accuracy 0.868\n","\u001b[34m\u001b[1mwandb\u001b[0m: test_accuracy 0.932\n","\u001b[34m\u001b[1mwandb\u001b[0m: training_time_mins 17.03\n","\u001b[34m\u001b[1mwandb\u001b[0m: \n","\u001b[34m\u001b[1mwandb\u001b[0m: 🚀 View run \u001b[33mlively-pond-12\u001b[0m at: \u001b[34m\u001b[4mhttps://wandb.ai/alexvmt/tiger_classification/runs/18cb8f9u\u001b[0m\n","\u001b[34m\u001b[1mwandb\u001b[0m: ⭐️ View project at: \u001b[34m\u001b[4mhttps://wandb.ai/alexvmt/tiger_classification\u001b[0m\n","\u001b[34m\u001b[1mwandb\u001b[0m: 
# Record this run in Weights & Biases: the run configuration goes into the
# run's config, the evaluation results are logged as metrics.
if logging:
    run_config = {
        'model_constructor': model_constructor,
        'model_params': model_params,
        'frozen_file_size': frozen_file_size,
        'num_classes': num_classes,
        'n_train_images': n_train_images,
        'n_test_images': n_test_images,
        'batch_size': batch_size,
        'epochs': epochs,
    }
    run = wandb.init(project='tiger_classification', config=run_config)

    # Index 1 of the evaluate() results is the accuracy; everything is rounded
    # to four decimals before logging.
    run_metrics = {
        'training_time_mins': training_time_mins,
        'test_accuracy': round(test_accuracy[1], 4),
        'test2_accuracy': round(test2_accuracy[1], 4),
        'precision': round(precision, 4),
        'recall': round(recall, 4),
        'f1': round(f1, 4),
    }
    wandb.log(run_metrics)
    wandb.finish()
+{"cells":[{"source":"\"Kaggle\"","metadata":{},"cell_type":"markdown"},{"cell_type":"markdown","id":"bc835be3","metadata":{"id":"0goBcwsXEl7q","papermill":{"duration":0.006992,"end_time":"2024-12-07T20:43:12.668766","exception":false,"start_time":"2024-12-07T20:43:12.661774","status":"completed"},"tags":[]},"source":["# Training"]},{"cell_type":"markdown","id":"3a0e368d","metadata":{"id":"_AciCyGkEpkC","papermill":{"duration":0.006929,"end_time":"2024-12-07T20:43:12.681656","exception":false,"start_time":"2024-12-07T20:43:12.674727","status":"completed"},"tags":[]},"source":["## Setup"]},{"cell_type":"code","execution_count":1,"id":"2b7ac84c","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:43:12.694222Z","iopub.status.busy":"2024-12-07T20:43:12.693883Z","iopub.status.idle":"2024-12-07T20:43:25.016455Z","shell.execute_reply":"2024-12-07T20:43:25.015356Z"},"id":"KVoOmsrBExwK","outputId":"266204f1-044a-49f4-cf32-5ce95a872f9b","papermill":{"duration":12.331731,"end_time":"2024-12-07T20:43:25.018878","exception":false,"start_time":"2024-12-07T20:43:12.687147","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["Requirement already satisfied: keras in /opt/conda/lib/python3.10/site-packages (3.3.3)\r\n","Collecting keras\r\n"," Downloading keras-3.7.0-py3-none-any.whl.metadata (5.8 kB)\r\n","Collecting kimm\r\n"," Downloading kimm-0.2.5-py3-none-any.whl.metadata (12 kB)\r\n","Requirement already satisfied: absl-py in /opt/conda/lib/python3.10/site-packages (from keras) (1.4.0)\r\n","Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from keras) (1.26.4)\r\n","Requirement already satisfied: rich in /opt/conda/lib/python3.10/site-packages (from keras) (13.7.1)\r\n","Requirement already satisfied: namex in /opt/conda/lib/python3.10/site-packages (from keras) (0.0.8)\r\n","Requirement already satisfied: h5py in /opt/conda/lib/python3.10/site-packages (from keras) (3.11.0)\r\n","Requirement 
already satisfied: optree in /opt/conda/lib/python3.10/site-packages (from keras) (0.11.0)\r\n","Requirement already satisfied: ml-dtypes in /opt/conda/lib/python3.10/site-packages (from keras) (0.3.2)\r\n","Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from keras) (21.3)\r\n","Requirement already satisfied: typing-extensions>=4.0.0 in /opt/conda/lib/python3.10/site-packages (from optree->keras) (4.12.2)\r\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging->keras) (3.1.2)\r\n","Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/conda/lib/python3.10/site-packages (from rich->keras) (3.0.0)\r\n","Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/conda/lib/python3.10/site-packages (from rich->keras) (2.18.0)\r\n","Requirement already satisfied: mdurl~=0.1 in /opt/conda/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich->keras) (0.1.2)\r\n","Downloading keras-3.7.0-py3-none-any.whl (1.2 MB)\r\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n","\u001b[?25hDownloading kimm-0.2.5-py3-none-any.whl (123 kB)\r\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m123.4/123.4 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n","\u001b[?25hInstalling collected packages: keras, kimm\r\n"," Attempting uninstall: keras\r\n"," Found existing installation: keras 3.3.3\r\n"," Uninstalling keras-3.3.3:\r\n"," Successfully uninstalled keras-3.3.3\r\n","Successfully installed keras-3.7.0 kimm-0.2.5\r\n"]}],"source":["!pip install keras kimm 
-U"]},{"cell_type":"code","execution_count":2,"id":"b4baf665","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:43:25.034238Z","iopub.status.busy":"2024-12-07T20:43:25.033564Z","iopub.status.idle":"2024-12-07T20:43:25.039887Z","shell.execute_reply":"2024-12-07T20:43:25.03892Z"},"papermill":{"duration":0.015851,"end_time":"2024-12-07T20:43:25.041802","exception":false,"start_time":"2024-12-07T20:43:25.025951","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["/\n"]}],"source":["%cd ../../"]},{"cell_type":"code","execution_count":3,"id":"05626c98","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:43:25.055758Z","iopub.status.busy":"2024-12-07T20:43:25.055507Z","iopub.status.idle":"2024-12-07T20:43:28.270506Z","shell.execute_reply":"2024-12-07T20:43:28.269376Z"},"id":"4bedI2X2QDdf","papermill":{"duration":3.224493,"end_time":"2024-12-07T20:43:28.272712","exception":false,"start_time":"2024-12-07T20:43:25.048219","status":"completed"},"tags":[]},"outputs":[],"source":["# set seed\n","seed = 42\n","\n","# set num classes and class names\n","num_classes = 5\n","class_names = ['tiger', 'lynx', 'bear', 'deer', 'bird']\n","yaml_file = 'class_list.yaml'\n","\n","# set n train and test images\n","n_train_images = 4000\n","n_test_images = 300\n","\n","# set batch size and epochs\n","batch_size = 16\n","epochs = 30\n","\n","# define paths to train and test images\n","images_input_dir = 'kaggle/input/preprocess-images/images'\n","images_sampled_dir = 'images'\n","!mkdir -p \"$images_sampled_dir\"\n","train_dir = images_input_dir + '/train'\n","train_dir_sampled = images_sampled_dir + '/train_sampled'\n","test_dir = images_input_dir + '/test'\n","test_dir_sampled = images_sampled_dir + '/test_sampled'\n","test2_dir = images_input_dir + '/test2'\n","\n","# define path to model dir\n","model_dir = 'kaggle/working/model'\n","!mkdir -p \"$model_dir\"\n","model_path = model_dir + '/model.h5'\n","model_constructor = 
'EfficientNetV2S'\n","\n","# define path to media dir\n","media_dir = 'kaggle/working/media'\n","!mkdir -p \"$media_dir\"\n","\n","# decide whether to log run or not\n","logging = False"]},{"cell_type":"code","execution_count":4,"id":"18634c53","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:43:28.287265Z","iopub.status.busy":"2024-12-07T20:43:28.286961Z","iopub.status.idle":"2024-12-07T20:43:43.474074Z","shell.execute_reply":"2024-12-07T20:43:43.473036Z"},"id":"WNSWKPrzEzy_","papermill":{"duration":15.197335,"end_time":"2024-12-07T20:43:43.476378","exception":false,"start_time":"2024-12-07T20:43:28.279043","status":"completed"},"tags":[]},"outputs":[],"source":["import os\n","import time\n","import yaml\n","import shutil\n","import random\n","import numpy as np\n","\n","import keras\n","from keras import layers, optimizers, losses, callbacks, saving\n","import kimm\n","import tensorflow as tf\n","import tensorflow_datasets as tfds\n","from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, classification_report\n","\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","\n","import wandb\n","from kaggle_secrets import UserSecretsClient"]},{"cell_type":"code","execution_count":5,"id":"e43d2962","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:43:43.4905Z","iopub.status.busy":"2024-12-07T20:43:43.490231Z","iopub.status.idle":"2024-12-07T20:43:43.497025Z","shell.execute_reply":"2024-12-07T20:43:43.496319Z"},"id":"RjGj2LDNR15z","papermill":{"duration":0.015275,"end_time":"2024-12-07T20:43:43.498601","exception":false,"start_time":"2024-12-07T20:43:43.483326","status":"completed"},"tags":[]},"outputs":[],"source":["def sample_images(source_dir, target_dir, samples_per_class, seed=42):\n"," \"\"\"\n"," Samples a fixed number of images per class from a directory structure.\n","\n"," Args:\n"," source_dir (str): Path to the source dataset directory.\n"," target_dir (str): Path to the target dataset 
def sample_images(source_dir, target_dir, samples_per_class, seed=42):
    """
    Copy a reproducible random sample of images per class into a new directory tree.

    `source_dir` is expected to contain one sub-directory per class. For each
    class directory, up to `samples_per_class` files are chosen at random and
    copied to `target_dir/<class_name>/`. Entries of `source_dir` that are not
    directories, and entries of a class directory that are not regular files,
    are ignored.

    Args:
        source_dir (str): Path to the source dataset directory.
        target_dir (str): Path to the target dataset directory to store sampled data.
            Created if it does not exist.
        samples_per_class (int): Number of images to sample per class. If a class
            holds fewer files, all of them are copied.
        seed (int): Random seed for reproducibility. Seeds the global `random`
            module generator.
    """
    random.seed(seed)

    os.makedirs(target_dir, exist_ok=True)

    # os.listdir() returns entries in arbitrary order. The generator state carries
    # over from one class to the next, so the class order must be fixed as well as
    # the file order for the same seed to select the same images.
    for class_name in sorted(os.listdir(source_dir)):
        class_path = os.path.join(source_dir, class_name)
        if not os.path.isdir(class_path):
            continue

        sampled_class_dir = os.path.join(target_dir, class_name)
        os.makedirs(sampled_class_dir, exist_ok=True)

        # Sort before shuffling so the shuffle starts from a deterministic
        # sequence; keep regular files only, since shutil.copy cannot copy a
        # directory.
        all_images = sorted(
            name
            for name in os.listdir(class_path)
            if os.path.isfile(os.path.join(class_path, name))
        )
        random.shuffle(all_images)

        # Copy the first `samples_per_class` files of the shuffled list.
        for image_name in all_images[:samples_per_class]:
            source_image_path = os.path.join(class_path, image_name)
            target_image_path = os.path.join(sampled_class_dir, image_name)
            shutil.copy(source_image_path, target_image_path)
def create_class_list_yaml_file(num_classes, class_names, file_path):
    """
    Write a YAML file mapping 1-based class indices to class names.

    The keys are the strings '1' through str(num_classes) and the values are
    the entries of `class_names` in the order given. Any missing parent
    directory of `file_path` is created before the file is written.

    Args:
        num_classes (int): The number of classes. Must match the length of `class_names`.
        class_names (list of str): A list of class names to include in the YAML file.
        file_path (str): The full file path (including directories and file name)
            where the YAML file will be saved.

    Raises:
        ValueError: If `len(class_names)` differs from `num_classes`.
    """
    # validate before touching the file system
    if num_classes != len(class_names):
        raise ValueError('The number of class names must match num_classes.')

    # a bare file name has no parent directory to create
    parent_dir = os.path.dirname(file_path)
    if parent_dir and not os.path.exists(parent_dir):
        os.makedirs(parent_dir)

    # keys are 1-based indices stored as strings
    index_to_name = {}
    for position, name in enumerate(class_names, start=1):
        index_to_name[str(position)] = name

    with open(file_path, 'w') as handle:
        yaml.dump(index_to_name, handle, default_flow_style=False)
api key\n","if logging:\n","    user_secrets = UserSecretsClient()\n","    key = user_secrets.get_secret('wandb')\n","    # hand the key to wandb through its environment variable instead of\n","    # interpolating the secret into a shell command line\n","    os.environ['WANDB_API_KEY'] = key"]},{"cell_type":"markdown","id":"1e0c6fcc","metadata":{"id":"o-oFY9SuE8dT","papermill":{"duration":0.00609,"end_time":"2024-12-07T20:43:43.570923","exception":false,"start_time":"2024-12-07T20:43:43.564833","status":"completed"},"tags":[]},"source":["## Prepare train and test datasets"]},{"cell_type":"code","execution_count":9,"id":"2f75663a","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:43:43.584575Z","iopub.status.busy":"2024-12-07T20:43:43.584285Z","iopub.status.idle":"2024-12-07T20:46:00.949168Z","shell.execute_reply":"2024-12-07T20:46:00.948217Z"},"id":"KewO2flaRgLu","papermill":{"duration":137.37435,"end_time":"2024-12-07T20:46:00.95147","exception":false,"start_time":"2024-12-07T20:43:43.57712","status":"completed"},"tags":[]},"outputs":[],"source":["# create new directory with sampled train images\n","sample_images(train_dir, train_dir_sampled, n_train_images)"]},{"cell_type":"code","execution_count":10,"id":"cada74ee","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:46:00.967413Z","iopub.status.busy":"2024-12-07T20:46:00.966711Z","iopub.status.idle":"2024-12-07T20:46:12.056292Z","shell.execute_reply":"2024-12-07T20:46:12.055188Z"},"id":"feQggHlSWEh4","papermill":{"duration":11.09987,"end_time":"2024-12-07T20:46:12.058425","exception":false,"start_time":"2024-12-07T20:46:00.958555","status":"completed"},"tags":[]},"outputs":[],"source":["# create new directory with sampled test images\n",
"sample_images(test_dir, test_dir_sampled, n_test_images)"]},{"cell_type":"code","execution_count":11,"id":"5b43eb82","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:46:12.071919Z","iopub.status.busy":"2024-12-07T20:46:12.071657Z","iopub.status.idle":"2024-12-07T20:46:17.055815Z","shell.execute_reply":"2024-12-07T20:46:17.054666Z"},"id":"J5i-u-VSE74K","outputId":"cda0d93f-4445-41f3-bb1b-5b1942d838b6","papermill":{"duration":4.993142,"end_time":"2024-12-07T20:46:17.057823","exception":false,"start_time":"2024-12-07T20:46:12.064681","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["Found 20000 files belonging to 5 classes.\n","Found 1500 files belonging to 5 classes.\n","Found 303 files belonging to 1 classes.\n","Number of train samples: -2\n","Number of test samples: -2\n"]}],"source":["# batch_size=None makes the loader yield individual (image, label) samples;\n","# the default would batch them (the extra leading dimension) and a later\n","# unbatch() would leave the dataset cardinality unknown (printed as -2)\n","\n","# create train dataset\n","train_ds = tf.keras.preprocessing.image_dataset_from_directory(\n","    train_dir_sampled,\n","    label_mode='categorical',\n","    batch_size=None,\n","    shuffle=True,\n","    seed=seed,\n",")\n","\n","# create test dataset\n","test_ds = tf.keras.preprocessing.image_dataset_from_directory(\n","    test_dir_sampled,\n","    label_mode='categorical',\n","    batch_size=None,\n","    shuffle=False,\n",")\n","\n","# create test2 dataset\n","# NOTE(review): the recorded output of this cell reports test2 as belonging to\n","# 1 class; with label_mode='categorical' its one-hot labels would then be 1-wide\n","# instead of num_classes-wide - confirm the test2 directory layout\n","test2_ds = tf.keras.preprocessing.image_dataset_from_directory(\n","    test2_dir,\n","    label_mode='categorical',\n","    batch_size=None,\n","    shuffle=False,\n",")\n","\n","print(f'Number of train samples: {train_ds.cardinality()}')\n","print(f'Number of test samples: {test_ds.cardinality()}')"]},
{"cell_type":"code","execution_count":12,"id":"d792fb98","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:46:17.073395Z","iopub.status.busy":"2024-12-07T20:46:17.073065Z","iopub.status.idle":"2024-12-07T20:46:17.077966Z","shell.execute_reply":"2024-12-07T20:46:17.076939Z"},"id":"myqLQZuqFZkx","outputId":"1636afdc-4dcd-4ad9-9041-bedfc196f415","papermill":{"duration":0.015047,"end_time":"2024-12-07T20:46:17.079804","exception":false,"start_time":"2024-12-07T20:46:17.064757","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["(TensorSpec(shape=(256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(5,), dtype=tf.float32, name=None)) (TensorSpec(shape=(256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(5,), dtype=tf.float32, name=None))\n"]}],"source":["# check dimensions\n","print(train_ds.element_spec, test_ds.element_spec)"]},{"cell_type":"code","execution_count":13,"id":"9d363dbb","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:46:17.094914Z","iopub.status.busy":"2024-12-07T20:46:17.09395Z","iopub.status.idle":"2024-12-07T20:46:17.190972Z","shell.execute_reply":"2024-12-07T20:46:17.19022Z"},"id":"YCc38anvFgJc","papermill":{"duration":0.106983,"end_time":"2024-12-07T20:46:17.193179","exception":false,"start_time":"2024-12-07T20:46:17.086196","status":"completed"},"tags":[]},"outputs":[],"source":["# setup dataset with tf.data\n","resize_fn = keras.layers.Resizing(224, 224)\n","\n","\n","def prepare_dataset(ds):\n","    \"\"\"Resize the images of `ds` to 224x224, then batch, cache and prefetch it.\"\"\"\n","    ds = ds.map(lambda x, y: (resize_fn(x), y), num_parallel_calls=tf.data.AUTOTUNE)\n","    # cache before prefetch: prefetch has to be the last step so that fetching\n","    # the (cached) batches overlaps with the training step in every epoch\n","    return ds.batch(batch_size).cache().prefetch(tf.data.AUTOTUNE)\n","\n","\n","train_ds = prepare_dataset(train_ds)\n","test_ds = prepare_dataset(test_ds)\n","test2_ds = prepare_dataset(test2_ds)"]},
{"cell_type":"markdown","id":"2604dc32","metadata":{"id":"CdVE1QLMFtA5","papermill":{"duration":0.006709,"end_time":"2024-12-07T20:46:17.207251","exception":false,"start_time":"2024-12-07T20:46:17.200542","status":"completed"},"tags":[]},"source":["## Prepare model"]},{"cell_type":"code","execution_count":14,"id":"77ed5544","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:46:17.222726Z","iopub.status.busy":"2024-12-07T20:46:17.222396Z","iopub.status.idle":"2024-12-07T20:46:36.854257Z","shell.execute_reply":"2024-12-07T20:46:36.853244Z"},"id":"3mmwlJBDFf_6","outputId":"73087a5a-85e9-43f0-bdee-b067a8168bcb","papermill":{"duration":19.641991,"end_time":"2024-12-07T20:46:36.856279","exception":false,"start_time":"2024-12-07T20:46:17.214288","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["Downloading data from https://github.com/james77777778/keras-image-models/releases/download/0.1.0/efficientnetv2s_tf_efficientnetv2_s.in21k_ft_in1k.keras\n","\u001b[1m87666342/87666342\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 0us/step\n"]},{"data":{"text/html":["
Model: \"functional\"\n","
\n"],"text/plain":["\u001b[1mModel: \"functional\"\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓\n","┃ Layer (type)                 Output Shape              Param #  Trai… ┃\n","┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩\n","│ input_layer_1 (InputLayer)  │ (None, 224, 224, 3)   │          0-   │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ EfficientNetV2S             │ (None, 7, 7, 1280)    │ 20,331,360N   │\n","│ (EfficientNetV2S)           │                       │            │       │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ global_average_pooling2d    │ (None, 1280)          │          0-   │\n","│ (GlobalAveragePooling2D)    │                       │            │       │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ dropout (Dropout)           │ (None, 1280)          │          0-   │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ dense (Dense)               │ (None, 5)             │      6,405Y   │\n","└─────────────────────────────┴───────────────────────┴────────────┴───────┘\n","
\n"],"text/plain":["┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓\n","┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mTrai…\u001b[0m\u001b[1m \u001b[0m┃\n","┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩\n","│ input_layer_1 (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ \u001b[1m-\u001b[0m │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ EfficientNetV2S │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m1280\u001b[0m) │ \u001b[38;5;34m20,331,360\u001b[0m │ \u001b[1;91mN\u001b[0m │\n","│ (\u001b[38;5;33mEfficientNetV2S\u001b[0m) │ │ │ │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ global_average_pooling2d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1280\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ \u001b[1m-\u001b[0m │\n","│ (\u001b[38;5;33mGlobalAveragePooling2D\u001b[0m) │ │ │ │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1280\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ \u001b[1m-\u001b[0m │\n","├─────────────────────────────┼───────────────────────┼────────────┼───────┤\n","│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m5\u001b[0m) │ \u001b[38;5;34m6,405\u001b[0m │ \u001b[1;38;5;34mY\u001b[0m │\n","└─────────────────────────────┴───────────────────────┴────────────┴───────┘\n"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
 Total params: 20,337,765 (77.58 MB)\n","
\n"],"text/plain":["\u001b[1m Total params: \u001b[0m\u001b[38;5;34m20,337,765\u001b[0m (77.58 MB)\n"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
 Trainable params: 6,405 (25.02 KB)\n","
\n"],"text/plain":["\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m6,405\u001b[0m (25.02 KB)\n"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
 Non-trainable params: 20,331,360 (77.56 MB)\n","
\n"],"text/plain":["\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20,331,360\u001b[0m (77.56 MB)\n"]},"metadata":{},"output_type":"display_data"}],"source":["# create base model\n","# supported kimm constructors -> (approx. parameter count, approx. frozen file size)\n","model_specs = {\n","    'EfficientNetV2B0': ('5 mio', '26 mb'),\n","    'EfficientNetV2B2': ('9 mio', '37 mb'),\n","    'EfficientNetV2S': ('21 mio', '84 mb'),\n","    'EfficientNetV2M': ('54 mio', '216 mb'),\n","    'EfficientNetV2L': ('119 mio', '475 mb'),\n","    'EfficientNetV2XL': ('208 mio', '835 mb'),\n","}\n","if model_constructor not in model_specs:\n","    raise ValueError('Please select a valid model constructor.')\n","\n","model_params, frozen_file_size = model_specs[model_constructor]\n","# every supported constructor takes the same arguments, so look it up by name\n","base_model = getattr(kimm.models, model_constructor)(\n","    input_shape=(224, 224, 3),\n","    include_preprocessing=True,\n","    include_top=False,\n",")\n","\n","# freeze base model\n","base_model.trainable = False\n","\n","# create new model on top\n","inputs = keras.Input(shape=(224, 224, 3))\n","x = inputs\n","\n","# The base model contains batchnorm layers. We want to keep them in inference mode\n","# when we unfreeze the base model for fine-tuning, so we make sure that the\n","# base_model is running in inference mode here.\n","x = base_model(x, training=False)\n","x = keras.layers.GlobalAveragePooling2D()(x)\n","x = keras.layers.Dropout(0.2)(x)  # regularize with dropout\n","outputs = keras.layers.Dense(num_classes, activation='softmax')(x)\n","model = keras.Model(inputs, outputs)\n","\n","model.summary(show_trainable=True)"]},
{"cell_type":"markdown","id":"fd0a52c2","metadata":{"id":"YcYVdF_1F9EQ","papermill":{"duration":0.008162,"end_time":"2024-12-07T20:46:36.873383","exception":false,"start_time":"2024-12-07T20:46:36.865221","status":"completed"},"tags":[]},"source":["## Training\n","\n","Follow [mewc-train](https://github.com/zaandahl/mewc-train)"]},{"cell_type":"code","execution_count":15,"id":"3d9af3dd","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:46:36.891185Z","iopub.status.busy":"2024-12-07T20:46:36.890335Z","iopub.status.idle":"2024-12-07T20:46:36.900936Z","shell.execute_reply":"2024-12-07T20:46:36.900228Z"},"id":"Ho-88Fe_GxVd","papermill":{"duration":0.021364,"end_time":"2024-12-07T20:46:36.902614","exception":false,"start_time":"2024-12-07T20:46:36.88125","status":"completed"},"tags":[]},"outputs":[],"source":["df_size = int(n_train_images * num_classes)\n","\n","lr_init = 1e-4\n","min_lr_frac = 1/5  # default minimum learning rate fraction of initial learning rate\n","steps_per_epoch = df_size // batch_size\n","total_steps = epochs * steps_per_epoch  # total number of steps for monotonic exponential decay across all epochs\n","lr = optimizers.schedules.ExponentialDecay(initial_learning_rate=lr_init, decay_steps=total_steps, decay_rate=min_lr_frac, staircase=False)\n","amsgrad = True\n","weight_decay = 1e-4\n","optimizer = optimizers.AdamW(learning_rate=lr, amsgrad=amsgrad, weight_decay=weight_decay)\n","\n","if num_classes == 2:\n","    loss_f = losses.BinaryFocalCrossentropy()  # use for binary classification tasks\n","    act_f = 'sigmoid'  # use for binary classification tasks\n","else:\n","    loss_f = losses.CategoricalFocalCrossentropy()  # use for unbalanced multi-class tasks (typical for wildlife datasets)\n","    act_f = 'softmax'  # use for multi-class classification tasks\n","\n","metrics = ['accuracy']\n","\n","# build EarlyStopping via keras.callbacks: this cell binds the name `callbacks` to a\n","# list, so `callbacks.EarlyStopping` would fail as soon as the cell is run a second time\n","callbacks = [keras.callbacks.EarlyStopping(monitor='loss', mode='min', min_delta=0.001, patience=5, restore_best_weights=True)]"]},
{"cell_type":"code","execution_count":16,"id":"9674fe77","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:46:36.920071Z","iopub.status.busy":"2024-12-07T20:46:36.919753Z","iopub.status.idle":"2024-12-07T20:46:36.93005Z","shell.execute_reply":"2024-12-07T20:46:36.929353Z"},"id":"fAHCE7OrGCGF","papermill":{"duration":0.021003,"end_time":"2024-12-07T20:46:36.931892","exception":false,"start_time":"2024-12-07T20:46:36.910889","status":"completed"},"tags":[]},"outputs":[],"source":["model.compile(\n","    optimizer=optimizer,\n","    loss=loss_f,\n","    metrics=metrics,\n",")"]},{"cell_type":"code","execution_count":17,"id":"9ff7d098","metadata":{"execution":{"iopub.execute_input":"2024-12-07T20:46:36.949224Z","iopub.status.busy":"2024-12-07T20:46:36.9489Z","iopub.status.idle":"2024-12-07T21:04:57.659925Z","shell.execute_reply":"2024-12-07T21:04:57.658896Z"},"id":"GZKd2EixF8yO","outputId":"e8424e62-1116-4b52-90af-5fe25403927f","papermill":{"duration":1100.721879,"end_time":"2024-12-07T21:04:57.66185","exception":false,"start_time":"2024-12-07T20:46:36.939971","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["Epoch 1/30\n"]},{"name":"stderr","output_type":"stream","text":["WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n","I0000 00:00:1733604413.541465 84 service.cc:145] XLA service 0x7c3668001be0 initialized for platform CUDA (this does not guarantee that XLA will be used).
Devices:\n","I0000 00:00:1733604413.541536 84 service.cc:153] StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0\n"]},{"name":"stdout","output_type":"stream","text":[" 3/Unknown \u001b[1m27s\u001b[0m 42ms/step - accuracy: 0.2188 - loss: 0.3537"]},{"name":"stderr","output_type":"stream","text":["I0000 00:00:1733604424.214460 84 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n"]},{"name":"stdout","output_type":"stream","text":[" 1250/Unknown \u001b[1m90s\u001b[0m 51ms/step - accuracy: 0.6078 - loss: 0.1701"]},{"name":"stderr","output_type":"stream","text":["/opt/conda/lib/python3.10/site-packages/keras/src/trainers/epoch_iterator.py:151: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.\n"," self._interrupted_warning()\n"]},{"name":"stdout","output_type":"stream","text":["\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m91s\u001b[0m 51ms/step - accuracy: 0.6079 - loss: 0.1700\n","Epoch 2/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.8546 - loss: 0.0580\n","Epoch 3/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.8789 - loss: 0.0448\n","Epoch 4/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.8899 - loss: 0.0396\n","Epoch 5/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9010 - loss: 0.0362\n","Epoch 6/30\n","\u001b[1m1250/1250\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9080 - loss: 0.0329\n","Epoch 7/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9106 - loss: 0.0308\n","Epoch 8/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9173 - loss: 0.0293\n","Epoch 9/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9180 - loss: 0.0278\n","Epoch 10/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m43s\u001b[0m 35ms/step - accuracy: 0.9212 - loss: 0.0272\n","Epoch 11/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9236 - loss: 0.0261\n","Epoch 12/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9213 - loss: 0.0267\n","Epoch 13/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 36ms/step - accuracy: 0.9290 - loss: 0.0244\n","Epoch 14/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9240 - loss: 0.0249\n","Epoch 15/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9334 - loss: 0.0232\n","Epoch 16/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9302 - loss: 0.0233\n","Epoch 17/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - 
accuracy: 0.9300 - loss: 0.0240\n","Epoch 18/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9330 - loss: 0.0225\n","Epoch 19/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9347 - loss: 0.0218\n","Epoch 20/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9325 - loss: 0.0225\n","Epoch 21/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9329 - loss: 0.0220\n","Epoch 22/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 35ms/step - accuracy: 0.9359 - loss: 0.0217\n","Epoch 23/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m43s\u001b[0m 35ms/step - accuracy: 0.9355 - loss: 0.0221\n","Epoch 24/30\n","\u001b[1m1250/1250\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m43s\u001b[0m 34ms/step - accuracy: 0.9335 - loss: 0.0218\n","Training time mins: 18.35\n"]}],"source":["start_time = time.time()\n","# keep the History object so the per-epoch loss/accuracy stay available after training\n","history = model.fit(train_ds, epochs=epochs, callbacks=callbacks)\n","end_time = time.time()\n","training_time_mins = round((end_time - start_time) / 60, 2)\n","print(f'Training time mins: {training_time_mins}')"]},
{"cell_type":"code","execution_count":18,"id":"c60025ac","metadata":{"execution":{"iopub.execute_input":"2024-12-07T21:04:59.171125Z","iopub.status.busy":"2024-12-07T21:04:59.170754Z","iopub.status.idle":"2024-12-07T21:04:59.693603Z","shell.execute_reply":"2024-12-07T21:04:59.692839Z"},"id":"ik4KqmJdUbKd","outputId":"d168634b-b029-4bf1-cf87-e280ace945fb","papermill":{"duration":1.306508,"end_time":"2024-12-07T21:04:59.695598","exception":false,"start_time":"2024-12-07T21:04:58.38909","status":"completed"},"tags":[]},"outputs":[],"source":["saving.save_model(model, model_path)"]},{"cell_type":"markdown","id":"38d78f4c","metadata":{"id":"qIGU25KNO1sp","papermill":{"duration":0.733873,"end_time":"2024-12-07T21:05:01.220996","exception":false,"start_time":"2024-12-07T21:05:00.487123","status":"completed"},"tags":[]},"source":["## Evaluation"]},{"cell_type":"code","execution_count":19,"id":"b146e98f","metadata":{"execution":{"iopub.execute_input":"2024-12-07T21:05:02.765595Z","iopub.status.busy":"2024-12-07T21:05:02.765206Z","iopub.status.idle":"2024-12-07T21:05:19.074074Z","shell.execute_reply":"2024-12-07T21:05:19.072447Z"},"papermill":{"duration":17.086515,"end_time":"2024-12-07T21:05:19.076192","exception":false,"start_time":"2024-12-07T21:05:01.989677","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["\u001b[1m94/94\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 97ms/step - accuracy: 0.9544 - loss: 0.0154\n","Test accuracy: 93.20%\n"]}],"source":["# evaluate() returns [loss, accuracy] for the compiled loss and metrics,\n","# so index 1 is the accuracy\n","test_accuracy = model.evaluate(test_ds)\n","print(f'Test accuracy: {test_accuracy[1] * 100:3.2f}%')"]},
{"cell_type":"code","execution_count":20,"id":"9dee9061","metadata":{"execution":{"iopub.execute_input":"2024-12-07T21:05:20.580226Z","iopub.status.busy":"2024-12-07T21:05:20.579867Z","iopub.status.idle":"2024-12-07T21:05:33.525015Z","shell.execute_reply":"2024-12-07T21:05:33.523948Z"},"id":"zyDqRJSEGM4L","outputId":"09e78303-ee89-418f-cef9-f1a4900de3f2","papermill":{"duration":13.722087,"end_time":"2024-12-07T21:05:33.52678","exception":false,"start_time":"2024-12-07T21:05:19.804693","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["\u001b[1m19/19\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 307ms/step - accuracy: 0.8592 - loss: 3.2432\n","Test2 accuracy: 86.80%\n"]}],"source":["# NOTE(review): the dataset cell's recorded output reports test2 as belonging to\n","# 1 class; if its one-hot labels are 1-wide, this loss/accuracy may not be\n","# comparable to the test set above - confirm\n","test2_accuracy = model.evaluate(test2_ds)\n","print(f'Test2 accuracy: {test2_accuracy[1] * 100:3.2f}%')"]},{"cell_type":"code","execution_count":21,"id":"4add1328","metadata":{"execution":{"iopub.execute_input":"2024-12-07T21:05:35.032328Z","iopub.status.busy":"2024-12-07T21:05:35.031351Z","iopub.status.idle":"2024-12-07T21:05:59.609545Z","shell.execute_reply":"2024-12-07T21:05:59.60868Z"},"id":"hpMkylclPBAB","outputId":"909e4490-b5b4-48ff-a2a6-71f42226d86e","papermill":{"duration":25.301667,"end_time":"2024-12-07T21:05:59.611203","exception":false,"start_time":"2024-12-07T21:05:34.309536","status":"completed"},"tags":[]},"outputs":[{"name":"stdout","output_type":"stream","text":["\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 7s/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 66ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 67ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 76ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 75ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 79ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 78ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 79ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 77ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 75ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 76ms/step\n","\u001b[1m1/1\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 83ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 79ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 76ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 76ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 75ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 69ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 68ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 76ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 77ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 75ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 77ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 72ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 70ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 74ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 71ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 73ms/step\n","\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 76ms/step\n","\u001b[1m1/1\u001b[0m 
# %%
# Collect ground-truth and predicted class indices for the whole test set.
true_batches = []
predicted_batches = []

for images, labels in test_ds:
    labels_np = labels.numpy()
    if len(labels.shape) > 1:
        # More than one axis: treated as one-hot rows, reduced to class indices.
        true_batches.append(np.argmax(labels_np, axis=1))
    else:
        # Already integer class indices.
        true_batches.append(labels_np)

    # predict() is called once per batch, so its default progress bar would
    # print one line per batch and flood the cell output; verbose=0 mutes it.
    predictions = model.predict(images, verbose=0)
    predicted_batches.append(np.argmax(predictions, axis=1))

# Merge the per-batch arrays into one flat array each.
true_labels = np.concatenate(true_batches)
predicted_labels = np.concatenate(predicted_batches)

# %%
# Macro-averaged metrics: every class contributes equally to the mean.
precision = precision_score(true_labels, predicted_labels, average='macro')
recall = recall_score(true_labels, predicted_labels, average='macro')
f1 = f1_score(true_labels, predicted_labels, average='macro')
print(f'Precision: {precision * 100:3.2f}%; Recall: {recall * 100:3.2f}%; F1: {f1 * 100:3.2f}%')

# %%
cm = confusion_matrix(true_labels, predicted_labels)
print('Confusion matrix:\n', cm)

# %%
report = classification_report(true_labels, predicted_labels, target_names=class_names)
print('Classification report:\n', report)

# %%
# Draw the confusion matrix as an annotated heatmap and save it to disk.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion matrix')
confusion_matrix_path = media_dir + '/confusion_matrix.png'
fig.savefig(confusion_matrix_path, dpi=300, bbox_inches='tight')
plt.show()

# %% [markdown]
# ## Log run to W&B

# %%
# `logging` is defined in an earlier cell; presumably a boolean toggle for
# W&B reporting -- verify there.
if logging:
    run_config = {
        'model_constructor': model_constructor,
        'model_params': model_params,
        'frozen_file_size': frozen_file_size,
        'num_classes': num_classes,
        'n_train_images': n_train_images,
        'n_test_images': n_test_images,
        'batch_size': batch_size,
        'epochs': epochs,
    }
    # Using the run as a context manager finishes it even if logging raises,
    # replacing the manual wandb.finish() call.
    with wandb.init(project='tiger_classification', config=run_config) as run:
        run.log({
            'training_time_mins': training_time_mins,
            # NOTE(review): index 1 is presumably the accuracy entry of the
            # evaluate() result computed in an earlier cell -- confirm.
            'test_accuracy': round(test_accuracy[1], 4),
            'test2_accuracy': round(test2_accuracy[1], 4),
            'precision': round(precision, 4),
            'recall': round(recall, 4),
            'f1': round(f1, 4),
        })