From 14b91aba088701306792ddbc42dc0827dbf483fd Mon Sep 17 00:00:00 2001 From: fcakyon <34196005+fcakyon@users.noreply.github.com> Date: Tue, 13 Jul 2021 18:58:52 +0300 Subject: [PATCH] refactor predict api (#170) * refactor predict api * update notebooks --- demo/inference_for_mmdetection.ipynb | 276 ++++++++++++++++----------- demo/inference_for_yolov5.ipynb | 158 +++++++++------ sahi/model.py | 8 +- sahi/predict.py | 113 ++++++----- scripts/predict.py | 22 +-- scripts/predict_fiftyone.py | 22 +-- tests/test_mmdetectionmodel.py | 14 +- tests/test_predict.py | 42 ++-- tests/test_yolov5model.py | 6 +- 9 files changed, 372 insertions(+), 289 deletions(-) diff --git a/demo/inference_for_mmdetection.ipynb b/demo/inference_for_mmdetection.ipynb index d0c1b155d..deda7a7c1 100644 --- a/demo/inference_for_mmdetection.ipynb +++ b/demo/inference_for_mmdetection.ipynb @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -76,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -114,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -129,7 +129,7 @@ "detection_model = MmdetDetectionModel(\n", " model_path=model_path,\n", " config_path=config_path,\n", - " prediction_score_threshold=0.4,\n", + " confidence_threshold=0.4,\n", " device=\"cpu\", # or 'cuda'\n", ")" ] @@ -143,9 +143,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/fatihakyon/miniconda3/envs/sahi/lib/python3.8/site-packages/mmdet/datasets/utils.py:64: UserWarning: \"ImageToTensor\" pipeline is replaced by \"DefaultFormatBundle\" for batch inference. 
It is recommended to manually replace it in the test data pipeline in your config file.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "result = get_prediction(\"demo_data/small-vehicles1.jpeg\", detection_model)" ] @@ -159,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -211,14 +220,14 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Number of slices: 18\n" + "Number of slices: 15\n" ] } ], @@ -278,7 +287,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -287,20 +296,20 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ObjectPrediction<\n", - " bbox: BoundingBox: <(449, 310, 493, 339), w: 44, h: 29>,\n", - " mask: ,\n", - " score: PredictionScore: ,\n", + " bbox: BoundingBox: <(448, 310, 494, 340), w: 46, h: 30>,\n", + " mask: ,\n", + " score: PredictionScore: ,\n", " category: Category: >" ] }, - "execution_count": 14, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -318,15 +327,15 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'image_id': None,\n", - " 'bbox': [449, 310, 44, 29],\n", - " 'score': 0.9974353909492493,\n", + " 'bbox': [448, 310, 46, 30],\n", + " 'score': 0.9974352717399597,\n", " 'category_id': 2,\n", " 'category_name': 'car',\n", " 'segmentation': [[465,\n", @@ -335,8 +344,10 @@ " 311,\n", " 460,\n", " 311,\n", + " 459,\n", + " 312,\n", " 458,\n", - " 313,\n", + " 312,\n", " 457,\n", " 313,\n", " 457,\n", @@ -345,62 +356,80 @@ " 316,\n", " 455,\n", " 317,\n", - " 453,\n", - " 319,\n", - " 453,\n", + " 452,\n", " 320,\n", - " 451,\n", - " 322,\n", - " 451,\n", + " 452,\n", + " 321,\n", + " 450,\n", " 323,\n", " 450,\n", " 324,\n", - " 450,\n", + " 449,\n", " 325,\n", " 449,\n", - " 326,\n", + " 329,\n", + " 448,\n", + " 330,\n", + " 448,\n", + " 334,\n", + " 449,\n", + " 335,\n", " 449,\n", " 338,\n", " 450,\n", " 339,\n", - " 457,\n", + " 451,\n", + " 339,\n", + " 452,\n", + " 340,\n", + " 453,\n", + " 340,\n", + " 454,\n", " 339,\n", " 458,\n", + " 339,\n", + " 459,\n", " 338,\n", - " 460,\n", + " 466,\n", " 338,\n", - " 461,\n", + " 467,\n", " 337,\n", - " 479,\n", + " 471,\n", " 337,\n", - " 480,\n", + " 472,\n", " 338,\n", " 481,\n", " 338,\n", " 482,\n", " 339,\n", - " 484,\n", + " 483,\n", " 339,\n", - " 488,\n", + " 484,\n", + " 340,\n", + " 487,\n", + " 340,\n", + " 492,\n", " 335,\n", - " 489,\n", + " 493,\n", " 335,\n", - " 491,\n", - " 333,\n", - " 492,\n", + " 493,\n", + " 334,\n", + " 494,\n", " 333,\n", + " 494,\n", + " 321,\n", " 493,\n", - " 332,\n", + " 320,\n", " 493,\n", " 319,\n", - " 491,\n", - " 317,\n", " 490,\n", - " 317,\n", - " 488,\n", - " 315,\n", - " 487,\n", - " 315,\n", + " 316,\n", + " 489,\n", + " 316,\n", + " 486,\n", + " 313,\n", + " 485,\n", + " 313,\n", " 484,\n", " 312,\n", " 483,\n", @@ -412,10 +441,10 @@ " 476,\n", " 310]],\n", " 'iscrowd': 0,\n", - " 'area': 1050}" + " 'area': 1118}" ] }, - "execution_count": 16, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -433,15 +462,15 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'image_id': 
1,\n", - " 'bbox': [449, 310, 44, 29],\n", - " 'score': 0.9974353909492493,\n", + " 'bbox': [448, 310, 46, 30],\n", + " 'score': 0.9974352717399597,\n", " 'category_id': 2,\n", " 'category_name': 'car',\n", " 'segmentation': [[465,\n", @@ -450,8 +479,10 @@ " 311,\n", " 460,\n", " 311,\n", + " 459,\n", + " 312,\n", " 458,\n", - " 313,\n", + " 312,\n", " 457,\n", " 313,\n", " 457,\n", @@ -460,62 +491,80 @@ " 316,\n", " 455,\n", " 317,\n", - " 453,\n", - " 319,\n", - " 453,\n", + " 452,\n", " 320,\n", - " 451,\n", - " 322,\n", - " 451,\n", + " 452,\n", + " 321,\n", + " 450,\n", " 323,\n", " 450,\n", " 324,\n", - " 450,\n", + " 449,\n", " 325,\n", " 449,\n", - " 326,\n", + " 329,\n", + " 448,\n", + " 330,\n", + " 448,\n", + " 334,\n", + " 449,\n", + " 335,\n", " 449,\n", " 338,\n", " 450,\n", " 339,\n", - " 457,\n", + " 451,\n", + " 339,\n", + " 452,\n", + " 340,\n", + " 453,\n", + " 340,\n", + " 454,\n", " 339,\n", " 458,\n", + " 339,\n", + " 459,\n", " 338,\n", - " 460,\n", + " 466,\n", " 338,\n", - " 461,\n", + " 467,\n", " 337,\n", - " 479,\n", + " 471,\n", " 337,\n", - " 480,\n", + " 472,\n", " 338,\n", " 481,\n", " 338,\n", " 482,\n", " 339,\n", - " 484,\n", + " 483,\n", " 339,\n", - " 488,\n", + " 484,\n", + " 340,\n", + " 487,\n", + " 340,\n", + " 492,\n", " 335,\n", - " 489,\n", + " 493,\n", " 335,\n", - " 491,\n", - " 333,\n", - " 492,\n", + " 493,\n", + " 334,\n", + " 494,\n", " 333,\n", + " 494,\n", + " 321,\n", " 493,\n", - " 332,\n", + " 320,\n", " 493,\n", " 319,\n", - " 491,\n", - " 317,\n", " 490,\n", - " 317,\n", - " 488,\n", - " 315,\n", - " 487,\n", - " 315,\n", + " 316,\n", + " 489,\n", + " 316,\n", + " 486,\n", + " 313,\n", + " 485,\n", + " 313,\n", " 484,\n", " 312,\n", " 483,\n", @@ -527,10 +576,10 @@ " 476,\n", " 310]],\n", " 'iscrowd': 0,\n", - " 'area': 1050}" + " 'area': 1118}" ] }, - "execution_count": 17, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -548,16 +597,16 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 18, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -575,30 +624,30 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, - "execution_count": 19, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -623,20 +672,16 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "model_name = \"MmdetDetectionModel\"\n", - "model_parameters = {\n", - " \"model_path\": model_path,\n", - " \"config_path\": config_path,\n", - " \"device\": \"cpu\", # or 'cuda'\n", - " \"prediction_score_threshold\":0.4,\n", - " \"category_mapping\": None,\n", - " \"category_remapping\": None,\n", - "}\n", - "apply_sliced_prediction = True\n", + "model_type = \"mmdet\"\n", + "model_path = model_path\n", + "model_config_path = config_path\n", + "model_device = \"cpu\" # or 'cuda'\n", + "model_confidence_threshold = 0.4\n", + "\n", "slice_height = 256\n", "slice_width = 256\n", "overlap_height_ratio = 0.2\n", @@ -654,7 +699,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -669,7 +714,8 @@ "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/2 [00:00,\n", + " bbox: BoundingBox: <(447, 308, 496, 342), w: 49, h: 34>,\n", " mask: None,\n", - " score: 
PredictionScore: ,\n", + " score: PredictionScore: ,\n", " category: Category: >" ] }, - "execution_count": 7, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -301,23 +301,23 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'image_id': None,\n", - " 'bbox': [447, 308, 49, 33],\n", - " 'score': 0.9154346585273743,\n", + " 'bbox': [447, 308, 49, 34],\n", + " 'score': 0.91552734375,\n", " 'category_id': 2,\n", " 'category_name': 'car',\n", " 'segmentation': [],\n", " 'iscrowd': 0,\n", - " 'area': 1617},\n", + " 'area': 1666},\n", " {'image_id': None,\n", " 'bbox': [321, 321, 62, 41],\n", - " 'score': 0.887986958026886,\n", + " 'score': 0.8876953125,\n", " 'category_id': 2,\n", " 'category_name': 'car',\n", " 'segmentation': [],\n", @@ -325,7 +325,7 @@ " 'area': 2542},\n", " {'image_id': None,\n", " 'bbox': [382, 278, 37, 26],\n", - " 'score': 0.8796938061714172,\n", + " 'score': 0.8798828125,\n", " 'category_id': 2,\n", " 'category_name': 'car',\n", " 'segmentation': [],\n", @@ -333,7 +333,7 @@ " 'area': 962}]" ] }, - "execution_count": 9, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -351,23 +351,23 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'image_id': 1,\n", - " 'bbox': [447, 308, 49, 33],\n", - " 'score': 0.9154346585273743,\n", + " 'bbox': [447, 308, 49, 34],\n", + " 'score': 0.91552734375,\n", " 'category_id': 2,\n", " 'category_name': 'car',\n", " 'segmentation': [],\n", " 'iscrowd': 0,\n", - " 'area': 1617},\n", + " 'area': 1666},\n", " {'image_id': 1,\n", " 'bbox': [321, 321, 62, 41],\n", - " 'score': 0.887986958026886,\n", + " 'score': 0.8876953125,\n", " 'category_id': 2,\n", " 'category_name': 'car',\n", " 'segmentation': [],\n", @@ -375,7 +375,7 @@ " 'area': 2542},\n", " {'image_id': 1,\n", " 'bbox': [382, 278, 37, 26],\n", - " 'score': 0.8796938061714172,\n", + " 'score': 0.8798828125,\n", " 'category_id': 2,\n", " 'category_name': 'car',\n", " 'segmentation': [],\n", @@ -383,7 +383,7 @@ " 'area': 962}]" ] }, - "execution_count": 10, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -401,18 +401,18 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[,\n", - " ,\n", - " ]" + "[,\n", + " ,\n", + " ]" ] }, - "execution_count": 7, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -430,14 +430,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[,\n", " ,\n", " ]" ] }, - "execution_count": 8, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -508,17 +508,15 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ - "model_name = \"Yolov5DetectionModel\"\n", - "model_parameters = {\n", - " \"model_path\": yolov5_model_path,\n", - " \"device\": \"cpu\", # or 'cuda'\n", - " \"prediction_score_threshold\":0.4,\n", - "}\n", - "apply_sliced_prediction = True\n", + "model_type = \"yolov5\"\n", + "model_path = yolov5_model_path\n", + "model_device = \"cpu\" # or 'cuda'\n", + "model_confidence_threshold = 0.4\n", + "\n", "slice_height = 256\n", "slice_width = 256\n", "overlap_height_ratio = 0.2\n", @@ -536,14 +534,53 @@ }, { 
"cell_type": "code", - "execution_count": 48, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "There are 2 listed files in folder .\n" + "There are 2 listed files in folder .\n", + "\n", + " from n params module arguments \n", + " 0 -1 1 3520 yolov5.models.common.Focus [3, 32, 3] \n", + " 1 -1 1 18560 yolov5.models.common.Conv [32, 64, 3, 2] \n", + " 2 -1 1 18816 yolov5.models.common.C3 [64, 64, 1] \n", + " 3 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n", + " 4 -1 1 156928 yolov5.models.common.C3 [128, 128, 3] \n", + " 5 -1 1 295424 yolov5.models.common.Conv [128, 256, 3, 2] \n", + " 6 -1 1 625152 yolov5.models.common.C3 [256, 256, 3] \n", + " 7 -1 1 885504 yolov5.models.common.Conv [256, 384, 3, 2] \n", + " 8 -1 1 665856 yolov5.models.common.C3 [384, 384, 1] \n", + " 9 -1 1 1770496 yolov5.models.common.Conv [384, 512, 3, 2] \n", + " 10 -1 1 656896 yolov5.models.common.SPP [512, 512, [3, 5, 7]] \n", + " 11 -1 1 1182720 yolov5.models.common.C3 [512, 512, 1, False] \n", + " 12 -1 1 197376 yolov5.models.common.Conv [512, 384, 1, 1] \n", + " 13 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 14 [-1, 8] 1 0 yolov5.models.common.Concat [1] \n", + " 15 -1 1 813312 yolov5.models.common.C3 [768, 384, 1, False] \n", + " 16 -1 1 98816 yolov5.models.common.Conv [384, 256, 1, 1] \n", + " 17 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 18 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n", + " 19 -1 1 361984 yolov5.models.common.C3 [512, 256, 1, False] \n", + " 20 -1 1 33024 yolov5.models.common.Conv [256, 128, 1, 1] \n", + " 21 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 22 [-1, 4] 1 0 yolov5.models.common.Concat [1] \n", + " 23 -1 1 90880 yolov5.models.common.C3 [256, 128, 1, False] \n", + " 24 -1 1 147712 yolov5.models.common.Conv [128, 128, 3, 2] \n", + " 25 [-1, 20] 1 0 yolov5.models.common.Concat [1] \n", + " 26 -1 1 296448 yolov5.models.common.C3 [256, 256, 1, False] \n", + " 27 -1 1 590336 yolov5.models.common.Conv [256, 256, 3, 2] \n", + " 28 [-1, 16] 1 0 yolov5.models.common.Concat [1] \n", + " 29 -1 1 715008 yolov5.models.common.C3 [512, 384, 1, False] \n", + " 30 -1 1 1327872 yolov5.models.common.Conv [384, 384, 3, 2] \n", + " 31 [-1, 12] 1 0 yolov5.models.common.Concat [1] \n", + " 32 -1 1 1313792 yolov5.models.common.C3 [768, 512, 1, False] \n", + " 33 [23, 26, 29, 32] 1 327420 yolov5.models.yolo.Detect [80, [[19, 27, 44, 40, 38, 94], [96, 68, 86, 152, 180, 137], [140, 301, 303, 264, 238, 542], [436, 615, 739, 380, 925, 792]], [128, 256, 384, 512]]\n", + "Model Summary: 368 layers, 12667836 parameters, 12667836 gradients, 17.4 GFLOPS\n", + "\n", + "Adding autoShape... 
\n" ] }, { @@ -564,31 +601,31 @@ "name": "stderr", "output_type": "stream", "text": [ - " 50%|█████ | 1/2 [00:00<00:00, 1.79it/s]" + " 50%|█████ | 1/2 [00:02<00:02, 2.89s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Number of slices: 18\n" + "Number of slices: 15\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:01<00:00, 1.84it/s]" + "100%|██████████| 2/2 [00:05<00:00, 2.57s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Model loaded in 0.3756110668182373 seconds.\n", - "Slicing performed in 0.006395101547241211 seconds.\n", - "Prediction performed in 0.9911887645721436 seconds.\n", - "Exporting performed in 0.030394792556762695 seconds.\n" + "Model loaded in 0.314194917678833 seconds.\n", + "Slicing performed in 0.018460512161254883 seconds.\n", + "Prediction performed in 5.031987190246582 seconds.\n", + "Exporting performed in 0.03364849090576172 seconds.\n" ] }, { @@ -601,10 +638,11 @@ ], "source": [ "predict(\n", - " model_name=model_name,\n", - " model_parameters=model_parameters,\n", + " model_type=model_type,\n", + " model_path=model_path,\n", + " model_device=model_device,\n", + " model_confidence_threshold=model_confidence_threshold,\n", " source=source_image_dir,\n", - " apply_sliced_prediction=apply_sliced_prediction,\n", " slice_height=slice_height,\n", " slice_width=slice_width,\n", " overlap_height_ratio=overlap_height_ratio,\n", diff --git a/sahi/model.py b/sahi/model.py index ebf1d1311..91a1d1728 100644 --- a/sahi/model.py +++ b/sahi/model.py @@ -15,7 +15,7 @@ def __init__( config_path: Optional[str] = None, device: Optional[str] = None, mask_threshold: float = 0.5, - prediction_score_threshold: float = 0.3, + confidence_threshold: float = 0.3, category_mapping: Optional[Dict] = None, category_remapping: Optional[Dict] = None, load_at_init: bool = True, @@ -32,8 +32,8 @@ def __init__( Torch device, "cpu" or "cuda" mask_threshold: float Value to threshold mask pixels, should be between 0 and 1 - prediction_score_threshold: float - All predictions with score < prediction_score_threshold will be discarded + confidence_threshold: float + All predictions with score < confidence_threshold will be discarded category_mapping: dict: str to str Mapping from category id (str) to category name (str) e.g. 
{"1": "pedestrian"} category_remapping: dict: str to int @@ -46,7 +46,7 @@ def __init__( self.model = None self.device = device self.mask_threshold = mask_threshold - self.prediction_score_threshold = prediction_score_threshold + self.confidence_threshold = confidence_threshold self.category_mapping = category_mapping self.category_remapping = category_remapping self._original_predictions = None diff --git a/sahi/predict.py b/sahi/predict.py index a5c6624aa..db4e3ae41 100644 --- a/sahi/predict.py +++ b/sahi/predict.py @@ -26,6 +26,11 @@ save_pickle, ) +MODEL_TYPE_TO_MODEL_CLASS_NAME = { + "mmdet": "MmdetDetectionModel", + "yolov5": "Yolov5DetectionModel", +} + def get_prediction( image, @@ -82,7 +87,7 @@ def get_prediction( filtered_object_prediction_list = [ object_prediction for object_prediction in object_prediction_list - if object_prediction.score.value > detection_model.prediction_score_threshold + if object_prediction.score.value > detection_model.confidence_threshold ] # postprocess matching predictions if postprocess is not None: @@ -272,8 +277,13 @@ def get_sliced_prediction( def predict( - model_name: str = "MmdetDetectionModel", - model_parameters: Dict = None, + model_type: str = "mmdet", + model_path: str = None, + model_config_path: str = None, + model_confidence_threshold: float = 0.25, + model_device: str = None, + model_category_mapping: dict = None, + model_category_remapping: dict = None, source: str = None, no_standard_prediction: bool = False, no_sliced_prediction: bool = False, @@ -301,19 +311,20 @@ def predict( Performs prediction for all present images in given folder. Args: - model_name: str - Name of the implemented DetectionModel in model.py file. - model_parameter: a dict with fields: - model_path: str - Path for the instance segmentation model weight - config_path: str - Path for the mmdetection instance segmentation model config file - prediction_score_threshold: float - All predictions with score < prediction_score_threshold will be discarded. - device: str - Torch device, "cpu" or "cuda" - category_remapping: dict: str to int - Remap category ids after performing inference + model_type: str + mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'. + model_path: str + Path for the model weight + model_config_path: str + Path for the detection model config file + model_confidence_threshold: float + All predictions with score < model_confidence_threshold will be discarded. + model_device: str + Torch device, "cpu" or "cuda" + model_category_mapping: dict + Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"} + model_category_remapping: dict: str to int + Remap category ids after performing inference source: str Folder directory that contains images or path of the image to be predicted. 
no_standard_prediction: bool @@ -398,14 +409,15 @@ def predict( # init model instance time_start = time.time() - DetectionModel = import_class(model_name) + model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type] + DetectionModel = import_class(model_class_name) detection_model = DetectionModel( - model_path=model_parameters["model_path"], - config_path=model_parameters.get("config_path", None), - prediction_score_threshold=model_parameters.get("prediction_score_threshold", 0.25), - device=model_parameters.get("device", None), - category_mapping=model_parameters.get("category_mapping", None), - category_remapping=model_parameters.get("category_remapping", None), + model_path=model_path, + config_path=model_config_path, + confidence_threshold=model_confidence_threshold, + device=model_device, + category_mapping=model_category_mapping, + category_remapping=model_category_remapping, load_at_init=False, ) detection_model.load_model() @@ -568,8 +580,13 @@ def predict( def predict_fiftyone( - model_name: str = "MmdetDetectionModel", - model_parameters: Dict = None, + model_type: str = "mmdet", + model_path: str = None, + model_config_path: str = None, + model_confidence_threshold: float = 0.25, + model_device: str = None, + model_category_mapping: dict = None, + model_category_remapping: dict = None, coco_json_path: str = None, coco_image_dir: str = None, no_standard_prediction: bool = False, @@ -588,19 +605,20 @@ def predict_fiftyone( Performs prediction for all present images in given folder. Args: - model_name: str - Name of the implemented DetectionModel in model.py file. - model_parameter: a dict with fields: - model_path: str - Path for the instance segmentation model weight - config_path: str - Path for the mmdetection instance segmentation model config file - prediction_score_threshold: float - All predictions with score < prediction_score_threshold will be discarded. - device: str - Torch device, "cpu" or "cuda" - category_remapping: dict: str to int - Remap category ids after performing inference + model_type: str + mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'. + model_path: str + Path for the model weight + model_config_path: str + Path for the detection model config file + model_confidence_threshold: float + All predictions with score < model_confidence_threshold will be discarded. + model_device: str + Torch device, "cpu" or "cuda" + model_category_mapping: dict + Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"} + model_category_remapping: dict: str to int + Remap category ids after performing inference coco_json_path: str If coco file path is provided, detection results will be exported in coco json format. 
coco_image_dir: str @@ -651,14 +669,15 @@ def predict_fiftyone( # init model instance time_start = time.time() - DetectionModel = import_class(model_name) + model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type] + DetectionModel = import_class(model_class_name) detection_model = DetectionModel( - model_path=model_parameters["model_path"], - config_path=model_parameters.get("config_path", None), - prediction_score_threshold=model_parameters.get("prediction_score_threshold", 0.25), - device=model_parameters.get("device", None), - category_mapping=model_parameters.get("category_mapping", None), - category_remapping=model_parameters.get("category_remapping", None), + model_path=model_path, + config_path=model_config_path, + confidence_threshold=model_confidence_threshold, + device=model_device, + category_mapping=model_category_mapping, + category_remapping=model_category_remapping, load_at_init=False, ) detection_model.load_model() @@ -702,7 +721,7 @@ def predict_fiftyone( durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"] # Save predictions to dataset - sample[model_name] = fo.Detections(detections=prediction_result.to_fiftyone_detections()) + sample[model_type] = fo.Detections(detections=prediction_result.to_fiftyone_detections()) sample.save() # print prediction duration @@ -728,7 +747,7 @@ def predict_fiftyone( session.dataset = dataset # Evaluate the predictions results = dataset.evaluate_detections( - model_name, + model_type, gt_field="ground_truth", eval_key="eval", iou=postprocess_match_threshold, diff --git a/scripts/predict.py b/scripts/predict.py index 153defe85..956887618 100644 --- a/scripts/predict.py +++ b/scripts/predict.py @@ -78,22 +78,14 @@ opt = parser.parse_args() - model_type_to_model_name = { - "mmdet": "MmdetDetectionModel", - "yolov5": "Yolov5DetectionModel", - } - - model_parameters = { - "model_path": opt.model_path, - "config_path": opt.config_path, - "prediction_score_threshold": opt.conf_thresh, - "device": opt.device, - "category_mapping": opt.category_mapping, - "category_remapping": opt.category_remapping, - } predict( - model_name=model_type_to_model_name[opt.model_type], - model_parameters=model_parameters, + model_type=opt.model_type, + model_path=opt.model_path, + model_config_path=opt.config_path, + model_confidence_threshold=opt.conf_thresh, + model_device=opt.device, + model_category_mapping=opt.category_mapping, + model_category_remapping=opt.category_remapping, source=opt.source, project=opt.project, name=opt.name, diff --git a/scripts/predict_fiftyone.py b/scripts/predict_fiftyone.py index 5155cc9b3..a85abfe62 100644 --- a/scripts/predict_fiftyone.py +++ b/scripts/predict_fiftyone.py @@ -71,22 +71,14 @@ ) opt = parser.parse_args() - model_type_to_model_name = { - "mmdet": "MmdetDetectionModel", - "yolov5": "Yolov5DetectionModel", - } - - model_parameters = { - "model_path": opt.model_path, - "config_path": opt.config_path, - "prediction_score_threshold": opt.conf_thresh, - "device": opt.device, - "category_mapping": opt.category_mapping, - "category_remapping": opt.category_remapping, - } predict_fiftyone( - model_name=model_type_to_model_name[opt.model_type], - model_parameters=model_parameters, + model_type=opt.model_type, + model_path=opt.model_path, + model_config_path=opt.config_path, + model_confidence_threshold=opt.conf_thresh, + model_device=opt.device, + model_category_mapping=opt.category_mapping, + model_category_remapping=opt.category_remapping, coco_json_path=opt.coco_json_path, 
coco_image_dir=opt.coco_image_dir, no_standard_prediction=opt.no_standard_pred, diff --git a/tests/test_mmdetectionmodel.py b/tests/test_mmdetectionmodel.py index 8e595659a..b780fa406 100644 --- a/tests/test_mmdetectionmodel.py +++ b/tests/test_mmdetectionmodel.py @@ -22,7 +22,7 @@ def test_load_model(self): mmdet_detection_model = MmdetDetectionModel( model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH, config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH, - prediction_score_threshold=0.3, + confidence_threshold=0.3, device=None, category_remapping=None, load_at_init=True, @@ -39,7 +39,7 @@ def test_perform_inference_with_mask_output(self): mmdet_detection_model = MmdetDetectionModel( model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH, config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH, - prediction_score_threshold=0.5, + confidence_threshold=0.5, device=None, category_remapping=None, load_at_init=True, @@ -77,7 +77,7 @@ def test_perform_inference_without_mask_output(self): mmdet_detection_model = MmdetDetectionModel( model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH, config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH, - prediction_score_threshold=0.5, + confidence_threshold=0.5, device=None, category_remapping=None, load_at_init=True, @@ -113,7 +113,7 @@ def test_convert_original_predictions_with_mask_output(self): mmdet_detection_model = MmdetDetectionModel( model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH, config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH, - prediction_score_threshold=0.5, + confidence_threshold=0.5, device=None, category_remapping=None, load_at_init=True, @@ -160,7 +160,7 @@ def test_convert_original_predictions_without_mask_output(self): mmdet_detection_model = MmdetDetectionModel( model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH, config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH, - prediction_score_threshold=0.5, + confidence_threshold=0.5, device=None, category_remapping=None, load_at_init=True, @@ -203,7 +203,7 @@ def test_create_original_predictions_from_object_prediction_list_with_mask_outpu mmdet_detection_model = MmdetDetectionModel( model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH, config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH, - prediction_score_threshold=0.5, + confidence_threshold=0.5, device=None, category_remapping=None, load_at_init=True, @@ -248,7 +248,7 @@ def test_create_original_predictions_from_object_prediction_list_without_mask_ou mmdet_detection_model = MmdetDetectionModel( model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH, config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH, - prediction_score_threshold=0.5, + confidence_threshold=0.5, device=None, category_remapping=None, load_at_init=True, diff --git a/tests/test_predict.py b/tests/test_predict.py index 9931e58da..caa1ac26d 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -35,7 +35,7 @@ def test_get_prediction_mmdet(self): mmdet_detection_model = MmdetDetectionModel( model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH, config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH, - prediction_score_threshold=0.3, + confidence_threshold=0.3, device=None, category_remapping=None, ) @@ -85,7 +85,7 @@ def test_get_prediction_yolov5(self): yolov5_detection_model = Yolov5DetectionModel( model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH, - prediction_score_threshold=0.3, + confidence_threshold=0.3, device=None, 
category_remapping=None, load_at_init=False, @@ -134,7 +134,7 @@ def test_get_sliced_prediction_mmdet(self): mmdet_detection_model = MmdetDetectionModel( model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH, config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH, - prediction_score_threshold=0.3, + confidence_threshold=0.3, device=None, category_remapping=None, load_at_init=False, @@ -201,7 +201,7 @@ def test_get_sliced_prediction_yolov5(self): yolov5_detection_model = Yolov5DetectionModel( model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH, - prediction_score_threshold=0.3, + confidence_threshold=0.3, device=None, category_remapping=None, load_at_init=False, @@ -268,14 +268,6 @@ def test_coco_json_prediction(self): # init model download_mmdet_cascade_mask_rcnn_model() - model_parameters = { - "model_path": MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH, - "config_path": MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH, - "prediction_score_threshold": 0.4, - "device": None, # cpu or cuda - "category_mapping": None, - "category_remapping": None, # {"0": 1, "1": 2, "2": 3} - } postprocess_type = "UNIONMERGE" match_metric = "IOS" match_threshold = 0.5 @@ -290,8 +282,13 @@ def test_coco_json_prediction(self): if os.path.isdir(project_dir): shutil.rmtree(project_dir) predict( - model_name="MmdetDetectionModel", - model_parameters=model_parameters, + model_type="mmdet", + model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH, + model_config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH, + model_confidence_threshold=0.4, + model_device=None, + model_category_mapping=None, + model_category_remapping=None, source=source, no_sliced_prediction=False, no_standard_prediction=True, @@ -315,14 +312,6 @@ def test_coco_json_prediction(self): # init model download_yolov5s6_model() - model_parameters = { - "model_path": Yolov5TestConstants.YOLOV5S6_MODEL_PATH, - "prediction_score_threshold": 0.4, - "device": None, # cpu or cuda - "category_mapping": None, - "category_remapping": None, # {"0": 1, "1": 2, "2": 3} - } - # prepare paths coco_file_path = "tests/data/coco_utils/terrain_all_coco.json" source = "tests/data/coco_utils/" @@ -332,8 +321,13 @@ def test_coco_json_prediction(self): if os.path.isdir(project_dir): shutil.rmtree(project_dir) predict( - model_name="Yolov5DetectionModel", - model_parameters=model_parameters, + model_type="yolov5", + model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH, + model_config_path=None, + model_confidence_threshold=0.4, + model_device=None, + model_category_mapping=None, + model_category_remapping=None, source=source, no_sliced_prediction=False, no_standard_prediction=True, diff --git a/tests/test_yolov5model.py b/tests/test_yolov5model.py index daafd8c1e..e8279f288 100644 --- a/tests/test_yolov5model.py +++ b/tests/test_yolov5model.py @@ -20,7 +20,7 @@ def test_load_model(self): yolov5_detection_model = Yolov5DetectionModel( model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH, - prediction_score_threshold=0.3, + confidence_threshold=0.3, device=None, category_remapping=None, load_at_init=True, @@ -36,7 +36,7 @@ def test_perform_inference(self): yolov5_detection_model = Yolov5DetectionModel( model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH, - prediction_score_threshold=0.5, + confidence_threshold=0.5, device=None, category_remapping=None, load_at_init=True, @@ -74,7 +74,7 @@ def test_convert_original_predictions(self): yolov5_detection_model = Yolov5DetectionModel( model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH, - 
prediction_score_threshold=0.5, + confidence_threshold=0.5, device=None, category_remapping=None, load_at_init=True,
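
Usage sketches for the refactored API follow. First, the central rename: every `DetectionModel` constructor now takes `confidence_threshold` instead of `prediction_score_threshold`. The signature is taken from the `sahi/model.py` hunk above, assuming `MmdetDetectionModel` is importable from `sahi.model` as the tests suggest; the checkpoint and config paths are placeholders, not files shipped with this patch.

```python
from sahi.model import MmdetDetectionModel

# placeholder paths; point these at your own checkpoint and config
model_path = "checkpoints/cascade_mask_rcnn.pth"
config_path = "configs/cascade_mask_rcnn.py"

detection_model = MmdetDetectionModel(
    model_path=model_path,
    config_path=config_path,
    confidence_threshold=0.4,  # was: prediction_score_threshold=0.4
    device="cpu",  # or 'cuda'
)
```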
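Single-image inference keeps its two entry points. The `get_prediction` call is copied from the mmdetection notebook cell above; the `get_sliced_prediction` keyword names are an assumption pieced together from the slice parameters the notebooks define, since the full signature is not visible in this diff. `detection_model` is the instance from the previous sketch.

```python
from sahi.predict import get_prediction, get_sliced_prediction

# standard inference over the full image (as in the mmdetection notebook)
result = get_prediction("demo_data/small-vehicles1.jpeg", detection_model)

# sliced inference; slice sizes and overlap ratios match the notebook cells
result = get_sliced_prediction(
    "demo_data/small-vehicles1.jpeg",
    detection_model,
    slice_height=256,
    slice_width=256,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)

# each element is an ObjectPrediction carrying bbox, score and category,
# filtered by detection_model.confidence_threshold as in get_prediction
object_prediction_list = result.object_prediction_list
```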
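Batch prediction no longer takes a `model_name` plus a `model_parameters` dict; the fields are flattened into `model_*` keyword arguments, and `model_type` ("mmdet" or "yolov5") is resolved to a class through `MODEL_TYPE_TO_MODEL_CLASS_NAME` and `import_class`, as the `sahi/predict.py` hunk shows. The call below mirrors the updated yolov5 notebook cell; the weight path and source folder are placeholders.

```python
from sahi.predict import predict

predict(
    model_type="yolov5",  # or "mmdet" (then also pass model_config_path)
    model_path="weights/yolov5s6.pt",  # placeholder weight path
    model_device="cpu",  # or 'cuda'
    model_confidence_threshold=0.4,
    source="demo_data/",  # folder of images or a single image path
    slice_height=256,
    slice_width=256,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
```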
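`predict_fiftyone` is flattened the same way. Note from its hunk that predictions are now stored under `sample[model_type]` and evaluated with `dataset.evaluate_detections(model_type, ...)`, so the FiftyOne field name changes from the long class name to the short type string. A minimal sketch, assuming a working FiftyOne installation and using placeholder COCO paths:

```python
from sahi.predict import predict_fiftyone

predict_fiftyone(
    model_type="mmdet",
    model_path="checkpoints/cascade_mask_rcnn.pth",  # placeholder
    model_config_path="configs/cascade_mask_rcnn.py",  # placeholder
    model_confidence_threshold=0.4,
    model_device="cpu",  # or 'cuda'
    coco_json_path="annotations/coco.json",  # placeholder dataset json
    coco_image_dir="images/",  # placeholder image folder
)
```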