From 007dc9250c074d4b55d017211c59dc2880f1d1c7 Mon Sep 17 00:00:00 2001
From: fcakyon <34196005+fcakyon@users.noreply.github.com>
Date: Thu, 29 Jul 2021 00:19:55 +0300
Subject: [PATCH] add black and isort check to ci + reformat codebase (#179)
* add black check to ci workflow
* update dev package versions
* add isort check for ci
* reformat with isort
* reformat with updated black and isort
* reformat config files with black
* update readme with new contributing guidelines
* update project toml with isort config
---
.github/workflows/ci.yml | 6 +-
README.md | 16 ++
pyproject.toml | 6 +-
sahi/annotation.py | 74 ++++++--
sahi/model.py | 38 ++--
sahi/postprocess/combine.py | 36 +++-
sahi/postprocess/legacy/match.py | 4 +-
sahi/postprocess/legacy/merge.py | 30 ++--
sahi/postprocess/legacy/ops.py | 5 +-
sahi/predict.py | 82 +++++----
sahi/prediction.py | 6 +-
sahi/slicing.py | 15 +-
sahi/utils/coco.py | 150 +++++++++++++---
sahi/utils/cv.py | 43 ++++-
sahi/utils/fiftyone.py | 29 ++-
sahi/utils/file.py | 13 +-
sahi/utils/mmdet.py | 22 ++-
sahi/utils/mot.py | 12 +-
sahi/utils/yolov5.py | 3 +-
scripts/coco2yolov5.py | 19 +-
scripts/coco_error_analysis.py | 19 +-
scripts/coco_evaluation.py | 8 +-
scripts/predict.py | 24 ++-
scripts/predict_fiftyone.py | 24 ++-
scripts/slice_coco.py | 15 +-
setup.py | 2 +-
.../cascade_mask_rcnn_r50_fpn.py | 52 +++++-
.../cascade_mask_rcnn_r50_fpn_v280.py | 74 ++++++--
tests/test_cocoutils.py | 167 ++++++++++++------
tests/test_filter.py | 5 +-
tests/test_mmdetectionmodel.py | 31 ++--
tests/test_predict.py | 37 ++--
tests/test_shapelyutils.py | 78 +++++---
tests/test_slicing.py | 16 +-
tests/test_yolov5model.py | 14 +-
35 files changed, 858 insertions(+), 317 deletions(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index f22079520..a05a912b4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -71,11 +71,13 @@ jobs:
run: >
pip install mmcv-full==1.3.7 mmdet==2.13.0 yolov5==5.0.6 norfair==0.3.0
- - name: Lint with flake8
+ - name: Lint with flake8, black and isort
run: |
- pip install flake8
+ pip install -e .[dev]
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+ black . --check --config pyproject.toml
+ isort -c .
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
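For reference, the checks above can be reproduced locally before pushing; a minimal sketch, assuming the `dev` extras from `setup.py` are installed:
```bash
# install the pinned dev tools (black, flake8, isort) declared in setup.py
pip install -e .[dev]
# fail fast on syntax errors / undefined names, then verify formatting
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
black . --check --config pyproject.toml
isort -c .
```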
diff --git a/README.md b/README.md
index c181e535a..e3dd3cd82 100644
--- a/README.md
+++ b/README.md
@@ -405,6 +405,22 @@ mot_video.export(export_dir="mot_gt", type="gt")
All you need to do is create a new class in [model.py](sahi/model.py) that implements the [DetectionModel class](https://github.com/obss/sahi/blob/651f8e6cdb20467815748764bb198dd50241ab2b/sahi/model.py#L10). You can take the [MMDetection wrapper](https://github.com/obss/sahi/blob/651f8e6cdb20467815748764bb198dd50241ab2b/sahi/model.py#L164) or [YOLOv5 wrapper](https://github.com/obss/sahi/blob/ffa168fc38b75a002a0117f1fdde9470e1a9ce8c/sahi/model.py#L363) as a reference.
+Before opening a PR:
+
+- Install required development packages:
+
+```bash
+pip install -U -e .[dev]
+```
+
+- Reformat with black and isort:
+
+```bash
+black . --config pyproject.toml
+isort .
+```
+
+
##
Contributors
diff --git a/pyproject.toml b/pyproject.toml
index e34796ec5..b8bffa81a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,2 +1,6 @@
[tool.black]
-line-length = 120
\ No newline at end of file
+line-length = 120
+
+[tool.isort]
+line_length = 120
+profile = "black"
\ No newline at end of file
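Setting `profile = "black"` makes isort produce black-compatible import formatting, and matching `line_length` keeps the two tools from re-wrapping each other's output. A quick idempotence check, assuming both tools are installed:
```bash
# run both formatters, then re-run in check mode;
# with profile = "black" the second pass should report no changes
black . --config pyproject.toml && isort .
black . --check --config pyproject.toml && isort -c .
```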
diff --git a/sahi/annotation.py b/sahi/annotation.py
index f1780a5c4..a0d334824 100644
--- a/sahi/annotation.py
+++ b/sahi/annotation.py
@@ -2,7 +2,7 @@
# Code written by Fatih C Akyon, 2020.
import copy
-from typing import List, Optional, Dict
+from typing import Dict, List, Optional
import numpy as np
@@ -111,7 +111,11 @@ def __repr__(self):
class Mask:
@classmethod
def from_float_mask(
- cls, mask, full_shape=None, mask_threshold: float = 0.5, shift_amount: list = [0, 0],
+ cls,
+ mask,
+ full_shape=None,
+ mask_threshold: float = 0.5,
+ shift_amount: list = [0, 0],
):
"""
Args:
@@ -126,11 +130,18 @@ def from_float_mask(
Size of the full image after shifting, should be in the form of [height, width]
"""
bool_mask = mask > mask_threshold
- return cls(bool_mask=bool_mask, shift_amount=shift_amount, full_shape=full_shape,)
+ return cls(
+ bool_mask=bool_mask,
+ shift_amount=shift_amount,
+ full_shape=full_shape,
+ )
@classmethod
def from_coco_segmentation(
- cls, segmentation, full_shape=None, shift_amount: list = [0, 0],
+ cls,
+ segmentation,
+ full_shape=None,
+ shift_amount: list = [0, 0],
):
"""
Init Mask from coco segmentation representation.
@@ -152,10 +163,17 @@ def from_coco_segmentation(
assert full_shape is not None, "full_shape must be provided"
bool_mask = get_bool_mask_from_coco_segmentation(segmentation, height=full_shape[0], width=full_shape[1])
- return cls(bool_mask=bool_mask, shift_amount=shift_amount, full_shape=full_shape,)
+ return cls(
+ bool_mask=bool_mask,
+ shift_amount=shift_amount,
+ full_shape=full_shape,
+ )
def __init__(
- self, bool_mask=None, full_shape=None, shift_amount: list = [0, 0],
+ self,
+ bool_mask=None,
+ full_shape=None,
+ shift_amount: list = [0, 0],
):
"""
Args:
@@ -216,7 +234,14 @@ def get_shifted_mask(self):
# Confirm full_shape is specified
assert (self.full_shape_height is not None) and (self.full_shape_width is not None), "full_shape is None"
# init full mask
- mask_fullsized = np.full((self.full_shape_height, self.full_shape_width,), 0, dtype="float32",)
+ mask_fullsized = np.full(
+ (
+ self.full_shape_height,
+ self.full_shape_width,
+ ),
+ 0,
+ dtype="float32",
+ )
# arrange starting ending indexes
starting_pixel = [self.shift_x, self.shift_y]
@@ -230,7 +255,11 @@ def get_shifted_mask(self):
: ending_pixel[1] - starting_pixel[1], : ending_pixel[0] - starting_pixel[0]
]
- return Mask(mask_fullsized, shift_amount=[0, 0], full_shape=self.full_shape,)
+ return Mask(
+ mask_fullsized,
+ shift_amount=[0, 0],
+ full_shape=self.full_shape,
+ )
def to_coco_segmentation(self):
"""
@@ -441,7 +470,10 @@ def from_shapely_annotation(
@classmethod
def from_imantics_annotation(
- cls, annotation, shift_amount: Optional[List[int]] = [0, 0], full_shape: Optional[List[int]] = None,
+ cls,
+ annotation,
+ shift_amount: Optional[List[int]] = [0, 0],
+ full_shape: Optional[List[int]] = None,
):
"""
Creates ObjectAnnotation from imantics.annotation.Annotation
@@ -495,11 +527,18 @@ def __init__(
self.mask = None
self.bbox = BoundingBox(bbox, shift_amount)
else:
- self.mask = Mask(bool_mask=bool_mask, shift_amount=shift_amount, full_shape=full_shape,)
+ self.mask = Mask(
+ bool_mask=bool_mask,
+ shift_amount=shift_amount,
+ full_shape=full_shape,
+ )
bbox = get_bbox_from_bool_mask(bool_mask)
self.bbox = BoundingBox(bbox, shift_amount)
category_name = category_name if category_name else str(category_id)
- self.category = Category(id=category_id, name=category_name,)
+ self.category = Category(
+ id=category_id,
+ name=category_name,
+ )
self.merged = None
@@ -515,7 +554,9 @@ def to_coco_annotation(self):
)
else:
coco_annotation = CocoAnnotation.from_coco_bbox(
- bbox=self.bbox.to_coco_bbox(), category_id=self.category.id, category_name=self.category.name,
+ bbox=self.bbox.to_coco_bbox(),
+ category_id=self.category.id,
+ category_name=self.category.name,
)
return coco_annotation
@@ -532,7 +573,10 @@ def to_coco_prediction(self):
)
else:
coco_prediction = CocoPrediction.from_coco_bbox(
- bbox=self.bbox.to_coco_bbox(), category_id=self.category.id, category_name=self.category.name, score=1,
+ bbox=self.bbox.to_coco_bbox(),
+ category_id=self.category.id,
+ category_name=self.category.name,
+ score=1,
)
return coco_prediction
@@ -545,7 +589,9 @@ def to_shapely_annotation(self):
segmentation=self.mask.to_coco_segmentation(),
)
else:
- shapely_annotation = ShapelyAnnotation.from_coco_bbox(bbox=self.bbox.to_coco_bbox(),)
+ shapely_annotation = ShapelyAnnotation.from_coco_bbox(
+ bbox=self.bbox.to_coco_bbox(),
+ )
return shapely_annotation
def to_imantics_annotation(self):
diff --git a/sahi/model.py b/sahi/model.py
index b793103d4..20d7787df 100644
--- a/sahi/model.py
+++ b/sahi/model.py
@@ -1,11 +1,12 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.
+from typing import Dict, List, Optional, Union
+
import numpy as np
from sahi.prediction import ObjectPrediction
from sahi.utils.torch import cuda_is_available, empty_cuda_cache
-from typing import List, Dict, Optional, Union
class DetectionModel:
@@ -89,7 +90,9 @@ def perform_inference(self, image: np.ndarray, image_size: int = None):
NotImplementedError()
def _create_object_prediction_list_from_original_predictions(
- self, shift_amount: Optional[List[int]] = [0, 0], full_shape: Optional[List[int]] = None,
+ self,
+ shift_amount: Optional[List[int]] = [0, 0],
+ full_shape: Optional[List[int]] = None,
):
"""
This function should be implemented in a way that self._original_predictions should
@@ -117,7 +120,9 @@ def _apply_category_remapping(self):
object_prediction.category.id = new_category_id_int
def convert_original_predictions(
- self, shift_amount: Optional[List[int]] = [0, 0], full_shape: Optional[List[int]] = None,
+ self,
+ shift_amount: Optional[List[int]] = [0, 0],
+ full_shape: Optional[List[int]] = None,
):
"""
Converts original predictions of the detection model to a list of
@@ -130,7 +135,8 @@ def convert_original_predictions(
Size of the full image after shifting, should be in the form of [height, width]
"""
self._create_object_prediction_list_from_original_predictions(
- shift_amount=shift_amount, full_shape=full_shape,
+ shift_amount=shift_amount,
+ full_shape=full_shape,
)
if self.category_remapping:
self._apply_category_remapping()
@@ -143,7 +149,9 @@ def object_prediction_list(self):
def original_predictions(self):
return self._original_predictions
- def _create_predictions_from_object_prediction_list(object_prediction_list: List[ObjectPrediction],):
+ def _create_predictions_from_object_prediction_list(
+ object_prediction_list: List[ObjectPrediction],
+ ):
"""
This function should be implemented in a way that it converts a list of
prediction.ObjectPrediction instance to detection model's original prediction format.
@@ -172,7 +180,11 @@ def load_model(self):
from mmdet.apis import init_detector
# set model
- model = init_detector(config=self.config_path, checkpoint=self.model_path, device=self.device,)
+ model = init_detector(
+ config=self.config_path,
+ checkpoint=self.model_path,
+ device=self.device,
+ )
self.model = model
# set category_mapping
@@ -239,7 +251,9 @@ def category_names(self):
return self.model.CLASSES
def _create_object_prediction_list_from_original_predictions(
- self, shift_amount: Optional[List[int]] = [0, 0], full_shape: Optional[List[int]] = None,
+ self,
+ shift_amount: Optional[List[int]] = [0, 0],
+ full_shape: Optional[List[int]] = None,
):
"""
self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
@@ -294,7 +308,8 @@ def _create_object_prediction_list_from_original_predictions(
self._object_prediction_list = object_prediction_list
def _create_original_predictions_from_object_prediction_list(
- self, object_prediction_list: List[ObjectPrediction],
+ self,
+ object_prediction_list: List[ObjectPrediction],
):
"""
Converts a list of prediction.ObjectPrediction instance to detection model's original prediction format.
@@ -412,7 +427,9 @@ def category_names(self):
return self.model.names
def _create_object_prediction_list_from_original_predictions(
- self, shift_amount: Optional[List[int]] = [0, 0], full_shape: Optional[List[int]] = None,
+ self,
+ shift_amount: Optional[List[int]] = [0, 0],
+ full_shape: Optional[List[int]] = None,
):
"""
self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
@@ -456,7 +473,8 @@ def _create_object_prediction_list_from_original_predictions(
self._object_prediction_list = object_prediction_list
def _create_original_predictions_from_object_prediction_list(
- self, object_prediction_list: List[ObjectPrediction],
+ self,
+ object_prediction_list: List[ObjectPrediction],
):
"""
Converts a list of prediction.ObjectPrediction instance to detection model's original
diff --git a/sahi/postprocess/combine.py b/sahi/postprocess/combine.py
index 40c79e443..332cb329b 100644
--- a/sahi/postprocess/combine.py
+++ b/sahi/postprocess/combine.py
@@ -1,11 +1,13 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2021.
+import copy
from typing import List, Union
-from sahi.prediction import ObjectPrediction
-from sahi.annotation import Mask, BoundingBox, Category
+
import numpy as np
-import copy
+
+from sahi.annotation import BoundingBox, Category, Mask
+from sahi.prediction import ObjectPrediction
def calculate_area(box: Union[List[int], np.ndarray]) -> float:
@@ -45,7 +47,10 @@ class PostprocessPredictions:
"""Combines predictions using NMS elimination utilizing provided match metric ('IOU' or 'IOS')"""
def __init__(
- self, match_threshold: float = 0.5, match_metric: str = "IOU", class_agnostic: bool = True,
+ self,
+ match_threshold: float = 0.5,
+ match_metric: str = "IOU",
+ class_agnostic: bool = True,
):
self.match_threshold = match_threshold
self.class_agnostic = class_agnostic
@@ -97,7 +102,8 @@ def __call__(self):
class NMSPostprocess(PostprocessPredictions):
def __call__(
- self, object_predictions: List[ObjectPrediction],
+ self,
+ object_predictions: List[ObjectPrediction],
):
source_object_predictions: List[ObjectPrediction] = copy.deepcopy(object_predictions)
selected_object_predictions: List[ObjectPrediction] = []
@@ -120,7 +126,8 @@ def __call__(
class UnionMergePostprocess(PostprocessPredictions):
def __call__(
- self, object_predictions: List[ObjectPrediction],
+ self,
+ object_predictions: List[ObjectPrediction],
):
source_object_predictions: List[ObjectPrediction] = copy.deepcopy(object_predictions)
selected_object_predictions: List[ObjectPrediction] = []
@@ -145,7 +152,11 @@ def __call__(
selected_object_predictions.append(selected_object_prediction)
return selected_object_predictions
- def _merge_object_prediction_pair(self, pred1: ObjectPrediction, pred2: ObjectPrediction,) -> ObjectPrediction:
+ def _merge_object_prediction_pair(
+ self,
+ pred1: ObjectPrediction,
+ pred2: ObjectPrediction,
+ ) -> ObjectPrediction:
shift_amount = pred1.bbox.shift_amount
merged_bbox: BoundingBox = self._get_merged_bbox(pred1, pred2)
merged_score: float = self._get_merged_score(pred1, pred2)
@@ -182,7 +193,10 @@ def _get_merged_bbox(pred1: ObjectPrediction, pred2: ObjectPrediction) -> Boundi
return bbox
@staticmethod
- def _get_merged_score(pred1: ObjectPrediction, pred2: ObjectPrediction,) -> float:
+ def _get_merged_score(
+ pred1: ObjectPrediction,
+ pred2: ObjectPrediction,
+ ) -> float:
scores: List[float] = [pred.score.value for pred in (pred1, pred2)]
return max(scores)
@@ -191,4 +205,8 @@ def _get_merged_mask(pred1: ObjectPrediction, pred2: ObjectPrediction) -> Mask:
mask1 = pred1.mask
mask2 = pred2.mask
union_mask = np.logical_or(mask1.bool_mask, mask2.bool_mask)
- return Mask(bool_mask=union_mask, full_shape=mask1.full_shape, shift_amount=mask1.shift_amount,)
+ return Mask(
+ bool_mask=union_mask,
+ full_shape=mask1.full_shape,
+ shift_amount=mask1.shift_amount,
+ )
diff --git a/sahi/postprocess/legacy/match.py b/sahi/postprocess/legacy/match.py
index 921af3cf8..1b5717a7c 100644
--- a/sahi/postprocess/legacy/match.py
+++ b/sahi/postprocess/legacy/match.py
@@ -46,7 +46,9 @@ class PredictionMatcher:
"""
def __init__(
- self, threshold: float = 0.5, scorer: Callable[[BoxArray, BoxArray], float] = box_ios,
+ self,
+ threshold: float = 0.5,
+ scorer: Callable[[BoxArray, BoxArray], float] = box_ios,
):
self._threshold = threshold
self._scorer = scorer
diff --git a/sahi/postprocess/legacy/merge.py b/sahi/postprocess/legacy/merge.py
index f32632ae8..c861904dd 100644
--- a/sahi/postprocess/legacy/merge.py
+++ b/sahi/postprocess/legacy/merge.py
@@ -6,15 +6,10 @@
from typing import Callable, List
import numpy as np
+
from sahi.annotation import Mask
from sahi.postprocess.legacy.match import PredictionList, PredictionMatcher
-from sahi.postprocess.legacy.ops import (
- BoxArray,
- box_union,
- calculate_area,
- extract_box,
- have_same_class,
-)
+from sahi.postprocess.legacy.ops import BoxArray, box_union, calculate_area, extract_box, have_same_class
from sahi.prediction import ObjectPrediction
@@ -77,7 +72,8 @@ def merge_batch(
"""
if merge_type not in ["merge", "ensemble"]:
raise ValueError(
- 'Unknown merge type. Supported types are ["merge", "ensemble"], got type: ', merge_type,
+ 'Unknown merge type. Supported types are ["merge", "ensemble"], got type: ',
+ merge_type,
)
unions = matcher.find_matched_predictions(predictions, ignore_class_label)
@@ -115,7 +111,11 @@ def _store_merging_info(count, prediction, merge_type):
else:
prediction.merged = False
- def _merge_pair(self, pred1: ObjectPrediction, pred2: ObjectPrediction,) -> ObjectPrediction:
+ def _merge_pair(
+ self,
+ pred1: ObjectPrediction,
+ pred2: ObjectPrediction,
+ ) -> ObjectPrediction:
box1 = extract_box(pred1)
box2 = extract_box(pred2)
merged_box = list(self._merge_box(box1, box2))
@@ -154,7 +154,11 @@ def _assert_equal_labels(pred1: ObjectPrediction, pred2: ObjectPrediction):
def _merge_box(self, box1: BoxArray, box2: BoxArray) -> BoxArray:
return self._box_merger(box1, box2)
- def _merge_score(self, pred1: ObjectPrediction, pred2: ObjectPrediction,) -> float:
+ def _merge_score(
+ self,
+ pred1: ObjectPrediction,
+ pred2: ObjectPrediction,
+ ) -> float:
scores = [pred.score.value for pred in (pred1, pred2)]
policy = self._score_merging_method
if policy == ScoreMergingPolicy.SMALLER_SCORE:
@@ -176,7 +180,11 @@ def _merge_mask(pred1: ObjectPrediction, pred2: ObjectPrediction) -> Mask:
mask1 = pred1.mask
mask2 = pred2.mask
union_mask = np.logical_or(mask1.bool_mask, mask2.bool_mask)
- return Mask(bool_mask=union_mask, full_shape=mask1.full_shape, shift_amount=mask1.shift_amount,)
+ return Mask(
+ bool_mask=union_mask,
+ full_shape=mask1.full_shape,
+ shift_amount=mask1.shift_amount,
+ )
def _validate_box_merger(self, box_merger: Callable):
if box_merger.__name__ not in self.BOX_MERGERS:
diff --git a/sahi/postprocess/legacy/ops.py b/sahi/postprocess/legacy/ops.py
index 48e3e9f1c..b64a6c3c9 100644
--- a/sahi/postprocess/legacy/ops.py
+++ b/sahi/postprocess/legacy/ops.py
@@ -3,6 +3,7 @@
# Modified by Fatih C Akyon, 2020.
import numpy as np
+
from sahi.prediction import ObjectPrediction
BoxArray = np.ndarray
@@ -28,7 +29,7 @@ def have_same_class(pred1: ObjectPrediction, pred2: ObjectPrediction) -> bool:
def box_iou(box1: BoxArray, box2: BoxArray) -> float:
- """ Returns the ratio of intersection area to the union """
+ """Returns the ratio of intersection area to the union"""
area1 = calculate_area(box1)
area2 = calculate_area(box2)
intersect = intersection_area(box1, box2)
@@ -36,7 +37,7 @@ def box_iou(box1: BoxArray, box2: BoxArray) -> float:
def box_ios(box1: BoxArray, box2: BoxArray) -> float:
- """ Returns the ratio of intersection area to the smaller box's area """
+ """Returns the ratio of intersection area to the smaller box's area"""
area1 = calculate_area(box1)
area2 = calculate_area(box2)
intersect = intersection_area(box1, box2)
diff --git a/sahi/predict.py b/sahi/predict.py
index ffe9cbf70..d9e2634ae 100644
--- a/sahi/predict.py
+++ b/sahi/predict.py
@@ -3,28 +3,17 @@
import os
import time
-from typing import Dict, Optional, List
+from typing import Dict, List, Optional
import numpy as np
from tqdm import tqdm
+from sahi.postprocess.combine import NMSPostprocess, PostprocessPredictions, UnionMergePostprocess
from sahi.prediction import ObjectPrediction, PredictionResult
-from sahi.postprocess.combine import UnionMergePostprocess, PostprocessPredictions, NMSPostprocess
from sahi.slicing import slice_image
from sahi.utils.coco import Coco, CocoImage
-from sahi.utils.cv import (
- crop_object_predictions,
- read_image_as_pil,
- visualize_object_predictions,
-)
-from sahi.utils.file import (
- Path,
- import_class,
- increment_path,
- list_files,
- save_json,
- save_pickle,
-)
+from sahi.utils.cv import crop_object_predictions, read_image_as_pil, visualize_object_predictions
+from sahi.utils.file import Path, import_class, increment_path, list_files, save_json, save_pickle
MODEL_TYPE_TO_MODEL_CLASS_NAME = {
"mmdet": "MmdetDetectionModel",
@@ -79,7 +68,8 @@ def get_prediction(
time_start = time.time()
# works only with 1 batch
detection_model.convert_original_predictions(
- shift_amount=shift_amount, full_shape=full_shape,
+ shift_amount=shift_amount,
+ full_shape=full_shape,
)
object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list
# filter out predictions with lower score
@@ -102,7 +92,9 @@ def get_prediction(
if verbose == 1:
print(
- "Prediction performed in", durations_in_seconds["prediction"], "seconds.",
+ "Prediction performed in",
+ durations_in_seconds["prediction"],
+ "seconds.",
)
return PredictionResult(
@@ -224,7 +216,10 @@ def get_sliced_prediction(
detection_model=detection_model,
image_size=image_size,
shift_amount=shift_amount_list[0],
- full_shape=[slice_image_result.original_image_height, slice_image_result.original_image_width,],
+ full_shape=[
+ slice_image_result.original_image_height,
+ slice_image_result.original_image_width,
+ ],
)
object_prediction_list.extend(prediction_result.object_prediction_list)
if num_slices > 1 and perform_standard_pred:
@@ -252,10 +247,14 @@ def get_sliced_prediction(
if verbose == 2:
print(
- "Slicing performed in", durations_in_seconds["slice"], "seconds.",
+ "Slicing performed in",
+ durations_in_seconds["slice"],
+ "seconds.",
)
print(
- "Prediction performed in", durations_in_seconds["prediction"], "seconds.",
+ "Prediction performed in",
+ durations_in_seconds["prediction"],
+ "seconds.",
)
# merge matching predictions
@@ -378,7 +377,11 @@ def predict(
coco_json = []
elif os.path.isdir(source):
time_start = time.time()
- image_path_list = list_files(directory=source, contains=[".jpg", ".jpeg", ".png"], verbose=verbose,)
+ image_path_list = list_files(
+ directory=source,
+ contains=[".jpg", ".jpeg", ".png"],
+ verbose=verbose,
+ )
time_end = time.time() - time_start
durations_in_seconds["list_files"] = time_end
else:
@@ -543,17 +546,25 @@ def predict(
# print prediction duration
if verbose == 1:
print(
- "Model loaded in", durations_in_seconds["model_load"], "seconds.",
+ "Model loaded in",
+ durations_in_seconds["model_load"],
+ "seconds.",
)
print(
- "Slicing performed in", durations_in_seconds["slice"], "seconds.",
+ "Slicing performed in",
+ durations_in_seconds["slice"],
+ "seconds.",
)
print(
- "Prediction performed in", durations_in_seconds["prediction"], "seconds.",
+ "Prediction performed in",
+ durations_in_seconds["prediction"],
+ "seconds.",
)
if export_visual:
print(
- "Exporting performed in", durations_in_seconds["export_files"], "seconds.",
+ "Exporting performed in",
+ durations_in_seconds["export_files"],
+ "seconds.",
)
@@ -632,9 +643,10 @@ def predict_fiftyone(
0: no print
1: print slice/prediction durations, number of slices, model loading/file exporting durations
"""
- from sahi.utils.fiftyone import create_fiftyone_dataset_from_coco_file
import fiftyone as fo
+ from sahi.utils.fiftyone import create_fiftyone_dataset_from_coco_file
+
# assert prediction type
assert (
no_standard_prediction and no_sliced_prediction
@@ -705,13 +717,19 @@ def predict_fiftyone(
# print prediction duration
if verbose == 1:
print(
- "Model loaded in", durations_in_seconds["model_load"], "seconds.",
+ "Model loaded in",
+ durations_in_seconds["model_load"],
+ "seconds.",
)
print(
- "Slicing performed in", durations_in_seconds["slice"], "seconds.",
+ "Slicing performed in",
+ durations_in_seconds["slice"],
+ "seconds.",
)
print(
- "Prediction performed in", durations_in_seconds["prediction"], "seconds.",
+ "Prediction performed in",
+ durations_in_seconds["prediction"],
+ "seconds.",
)
# visualize results
@@ -719,7 +737,11 @@ def predict_fiftyone(
session.dataset = dataset
# Evaluate the predictions
results = dataset.evaluate_detections(
- model_type, gt_field="ground_truth", eval_key="eval", iou=postprocess_match_threshold, compute_mAP=True,
+ model_type,
+ gt_field="ground_truth",
+ eval_key="eval",
+ iou=postprocess_match_threshold,
+ compute_mAP=True,
)
# Get the 10 most common classes in the dataset
counts = dataset.count_values("ground_truth.detections.label")
diff --git a/sahi/prediction.py b/sahi/prediction.py
index af501f046..6805caadc 100644
--- a/sahi/prediction.py
+++ b/sahi/prediction.py
@@ -2,15 +2,15 @@
# Code written by Fatih C Akyon, 2020.
import copy
-from sahi.utils.coco import CocoPrediction, CocoAnnotation
-from typing import List, Optional, Union, Dict
+from typing import Dict, List, Optional, Union
import numpy as np
from PIL import Image
from sahi.annotation import ObjectAnnotation
+from sahi.utils.coco import CocoAnnotation, CocoPrediction
+from sahi.utils.cv import read_image_as_pil, visualize_object_predictions
from sahi.utils.file import create_dir
-from sahi.utils.cv import visualize_object_predictions, read_image_as_pil
class PredictionScore:
diff --git a/sahi/slicing.py b/sahi/slicing.py
index 41e666cfc..103fde34f 100644
--- a/sahi/slicing.py
+++ b/sahi/slicing.py
@@ -318,12 +318,19 @@ def slice_image(
# create sliced image and append to sliced_image_result
sliced_image = SlicedImage(
- image=np.asarray(image_pil_slice), coco_image=coco_image, starting_pixel=[slice_bbox[0], slice_bbox[1]],
+ image=np.asarray(image_pil_slice),
+ coco_image=coco_image,
+ starting_pixel=[slice_bbox[0], slice_bbox[1]],
)
sliced_image_result.add_sliced_image(sliced_image)
verboseprint(
- "Num slices:", n_ims, "slice_height", slice_height, "slice_width", slice_width,
+ "Num slices:",
+ n_ims,
+ "slice_height",
+ slice_height,
+ "slice_width",
+ slice_width,
)
verboseprint("Time to slice", image, time.time() - t0, "seconds")
@@ -409,7 +416,9 @@ def slice_coco(
# create and save coco dict
coco_dict = create_coco_dict(
- sliced_coco_images, coco_dict["categories"], ignore_negative_samples=ignore_negative_samples,
+ sliced_coco_images,
+ coco_dict["categories"],
+ ignore_negative_samples=ignore_negative_samples,
)
save_path = ""
if output_coco_annotation_file_name and output_dir:
diff --git a/sahi/utils/coco.py b/sahi/utils/coco.py
index 505e306ee..f4c7244c7 100644
--- a/sahi/utils/coco.py
+++ b/sahi/utils/coco.py
@@ -11,9 +11,10 @@
from typing import Dict, List, Optional, Set, Union
import numpy as np
+from tqdm import tqdm
+
from sahi.utils.file import load_json, save_json
from sahi.utils.shapely import ShapelyAnnotation, box, get_shapely_multipolygon
-from tqdm import tqdm
class CocoCategory:
@@ -35,7 +36,11 @@ def from_coco_category(cls, category):
category: Dict
{"supercategory": "person", "id": 1, "name": "person"},
"""
- return cls(id=category["id"], name=category["name"], supercategory=category["supercategory"],)
+ return cls(
+ id=category["id"],
+ name=category["name"],
+ supercategory=category["supercategory"],
+ )
@property
def json(self):
@@ -72,7 +77,12 @@ def from_coco_segmentation(cls, segmentation, category_id, category_name, iscrow
iscrowd: int
0 or 1
"""
- return cls(segmentation=segmentation, category_id=category_id, category_name=category_name, iscrowd=iscrowd,)
+ return cls(
+ segmentation=segmentation,
+ category_id=category_id,
+ category_name=category_name,
+ iscrowd=iscrowd,
+ )
@classmethod
def from_coco_bbox(cls, bbox, category_id, category_name, iscrowd=0):
@@ -89,7 +99,12 @@ def from_coco_bbox(cls, bbox, category_id, category_name, iscrowd=0):
iscrowd: int
0 or 1
"""
- return cls(bbox=bbox, category_id=category_id, category_name=category_name, iscrowd=iscrowd,)
+ return cls(
+ bbox=bbox,
+ category_id=category_id,
+ category_name=category_name,
+ iscrowd=iscrowd,
+ )
@classmethod
def from_coco_annotation_dict(cls, annotation_dict: Dict, category_name: Optional[str] = None):
@@ -111,12 +126,18 @@ def from_coco_annotation_dict(cls, annotation_dict: Dict, category_name: Optiona
)
else:
return cls(
- bbox=annotation_dict["bbox"], category_id=annotation_dict["category_id"], category_name=category_name,
+ bbox=annotation_dict["bbox"],
+ category_id=annotation_dict["category_id"],
+ category_name=category_name,
)
@classmethod
def from_shapely_annotation(
- cls, shapely_annotation: ShapelyAnnotation, category_id: int, category_name: str, iscrowd: int,
+ cls,
+ shapely_annotation: ShapelyAnnotation,
+ category_id: int,
+ category_name: str,
+ iscrowd: int,
):
"""
Creates CocoAnnotation object from ShapelyAnnotation object.
@@ -127,13 +148,24 @@ def from_shapely_annotation(
category_name (str): Category name of the annotation
iscrowd (int): 0 or 1
"""
- coco_annotation = cls(bbox=[0, 0, 0, 0], category_id=category_id, category_name=category_name, iscrowd=iscrowd,)
+ coco_annotation = cls(
+ bbox=[0, 0, 0, 0],
+ category_id=category_id,
+ category_name=category_name,
+ iscrowd=iscrowd,
+ )
coco_annotation._segmentation = shapely_annotation.to_coco_segmentation()
coco_annotation._shapely_annotation = shapely_annotation
return coco_annotation
def __init__(
- self, segmentation=None, bbox=None, category_id=None, category_name=None, image_id=None, iscrowd=0,
+ self,
+ segmentation=None,
+ bbox=None,
+ category_id=None,
+ category_name=None,
+ image_id=None,
+ iscrowd=0,
):
"""
Creates coco annotation object using bbox or segmentation
@@ -360,7 +392,14 @@ def from_coco_annotation_dict(cls, category_name, annotation_dict, score, image_
)
def __init__(
- self, segmentation=None, bbox=None, category_id=None, category_name=None, image_id=None, score=None, iscrowd=0,
+ self,
+ segmentation=None,
+ bbox=None,
+ category_id=None,
+ category_name=None,
+ image_id=None,
+ score=None,
+ iscrowd=0,
):
"""
@@ -425,7 +464,14 @@ class CocoVidAnnotation(CocoAnnotation):
"""
def __init__(
- self, bbox=None, category_id=None, category_name=None, image_id=None, instance_id=None, iscrowd=0, id=None,
+ self,
+ bbox=None,
+ category_id=None,
+ category_name=None,
+ image_id=None,
+ instance_id=None,
+ iscrowd=0,
+ id=None,
):
"""
Args:
@@ -445,7 +491,11 @@ def __init__(
Annotation id
"""
super(CocoVidAnnotation, self).__init__(
- bbox=bbox, category_id=category_id, category_name=category_name, image_id=image_id, iscrowd=iscrowd,
+ bbox=bbox,
+ category_id=category_id,
+ category_name=category_name,
+ image_id=image_id,
+ iscrowd=iscrowd,
)
self.instance_id = instance_id
self.id = id
@@ -549,7 +599,13 @@ class CocoVidImage(CocoImage):
"""
def __init__(
- self, file_name, height, width, video_id=None, frame_id=None, id=None,
+ self,
+ file_name,
+ height,
+ width,
+ video_id=None,
+ frame_id=None,
+ id=None,
):
"""
Creates CocoVidImage object
@@ -631,7 +687,12 @@ class CocoVideo:
"""
def __init__(
- self, name: str, id: int = None, fps: float = None, height: int = None, width: int = None,
+ self,
+ name: str,
+ id: int = None,
+ fps: float = None,
+ height: int = None,
+ width: int = None,
):
"""
Creates CocoVideo object
@@ -699,7 +760,11 @@ def __repr__(self):
class Coco:
def __init__(
- self, name=None, image_dir=None, remapping_dict=None, ignore_negative_samples=False,
+ self,
+ name=None,
+ image_dir=None,
+ remapping_dict=None,
+ ignore_negative_samples=False,
):
"""
Creates Coco object.
@@ -869,7 +934,8 @@ def merge(self, coco, desired_name2id=None, verbose=1):
# print categories
if verbose:
print(
- "Categories are formed as:\n", self.json_categories,
+ "Categories are formed as:\n",
+ self.json_categories,
)
@classmethod
@@ -899,7 +965,11 @@ def from_coco_dict_or_path(
category_mapping: dict
"""
# init coco object
- coco = cls(image_dir=image_dir, remapping_dict=remapping_dict, ignore_negative_samples=ignore_negative_samples,)
+ coco = cls(
+ image_dir=image_dir,
+ remapping_dict=remapping_dict,
+ ignore_negative_samples=ignore_negative_samples,
+ )
assert (
type(coco_dict_or_path) == str or type(coco_dict_or_path) == dict
@@ -966,7 +1036,9 @@ def category_mapping(self):
@property
def json(self):
return create_coco_dict(
- images=self.images, categories=self.json_categories, ignore_negative_samples=self.ignore_negative_samples,
+ images=self.images,
+ categories=self.json_categories,
+ ignore_negative_samples=self.ignore_negative_samples,
)
@property
@@ -1072,7 +1144,10 @@ def split_coco_as_train_val(self, train_split_rate=0.9, numpy_seed=0):
val_images = shuffled_images[num_train:]
# form train val coco objects
- train_coco = Coco(name=self.name if self.name else "split" + "_train", image_dir=self.image_dir,)
+ train_coco = Coco(
+ name=self.name if self.name else "split" + "_train",
+ image_dir=self.image_dir,
+ )
train_coco.images = train_images
train_coco.categories = self.categories
@@ -1123,7 +1198,10 @@ def export_as_yolov5(self, output_dir, train_split_rate=1, numpy_seed=0, mp=Fals
# split dataset
if split_mode == "TRAINVAL":
- result = self.split_coco_as_train_val(train_split_rate=train_split_rate, numpy_seed=numpy_seed,)
+ result = self.split_coco_as_train_val(
+ train_split_rate=train_split_rate,
+ numpy_seed=numpy_seed,
+ )
train_coco = result["train_coco"]
val_coco = result["val_coco"]
elif split_mode == "TRAIN":
@@ -1146,11 +1224,17 @@ def export_as_yolov5(self, output_dir, train_split_rate=1, numpy_seed=0, mp=Fals
# create image symlinks and annotation txts
if split_mode in ["TRAINVAL", "TRAIN"]:
export_yolov5_images_and_txts_from_coco_object(
- output_dir=train_dir, coco=train_coco, ignore_negative_samples=self.ignore_negative_samples, mp=mp,
+ output_dir=train_dir,
+ coco=train_coco,
+ ignore_negative_samples=self.ignore_negative_samples,
+ mp=mp,
)
if split_mode in ["TRAINVAL", "VAL"]:
export_yolov5_images_and_txts_from_coco_object(
- output_dir=val_dir, coco=val_coco, ignore_negative_samples=self.ignore_negative_samples, mp=mp,
+ output_dir=val_dir,
+ coco=val_coco,
+ ignore_negative_samples=self.ignore_negative_samples,
+ mp=mp,
)
# create yolov5 data yaml
@@ -1248,7 +1332,8 @@ def export_yolov5_images_and_txts_from_coco_object(output_dir, coco, ignore_nega
with Pool(processes=48) as pool:
args = [(coco_image, coco.image_dir, output_dir, ignore_negative_samples) for coco_image in coco.images]
pool.starmap(
- export_single_yolov5_image_and_corresponding_txt, tqdm(args, total=len(args)),
+ export_single_yolov5_image_and_corresponding_txt,
+ tqdm(args, total=len(args)),
)
else:
for coco_image in tqdm(coco.images):
@@ -1292,7 +1377,8 @@ def export_single_yolov5_image_and_corresponding_txt(
name_increment = 2
while Path(yolo_image_path).is_file():
yolo_image_path = yolo_image_path_temp.replace(
- Path(coco_image.file_name).stem, Path(coco_image.file_name).stem + "_" + str(name_increment),
+ Path(coco_image.file_name).stem,
+ Path(coco_image.file_name).stem + "_" + str(name_increment),
)
name_increment += 1
# create a symbolic link pointing to coco_image_path named yolo_image_path
@@ -1500,7 +1586,8 @@ def merge_from_list(coco_dict_list, desired_name2id=None, verbose=1):
# print categories
if verbose:
print(
- "Categories are formed as:\n", merged_coco_dict["categories"],
+ "Categories are formed as:\n",
+ merged_coco_dict["categories"],
)
return merged_coco_dict
@@ -1531,7 +1618,9 @@ def merge_from_file(coco_path1: str, coco_path2: str, save_path: str):
save_json(merged_coco_dict, save_path)
-def get_imageid2annotationlist_mapping(coco_dict: dict,) -> Dict[int, List[CocoAnnotation]]:
+def get_imageid2annotationlist_mapping(
+ coco_dict: dict,
+) -> Dict[int, List[CocoAnnotation]]:
"""
Get image_id to annotationlist mapping for faster indexing.
@@ -1638,7 +1727,11 @@ def create_coco_dict(images, categories, ignore_negative_samples=False):
def split_coco_as_train_val(
- coco_file_path_or_dict, file_name=None, target_dir=None, train_split_rate=0.9, numpy_seed=0,
+ coco_file_path_or_dict,
+ file_name=None,
+ target_dir=None,
+ train_split_rate=0.9,
+ numpy_seed=0,
):
"""
Takes a single coco dataset file path, splits images into train-val and saves them as separate coco dataset files.
@@ -1732,7 +1825,10 @@ def split_coco_as_train_val(
def add_bbox_and_area_to_coco(
- source_coco_path: str = "", target_coco_path: str = "", add_bbox: bool = True, add_area: bool = True,
+ source_coco_path: str = "",
+ target_coco_path: str = "",
+ add_bbox: bool = True,
+ add_area: bool = True,
) -> dict:
"""
Takes a single coco dataset file path, calculates and fills bbox and area fields of the annotations
diff --git a/sahi/utils/cv.py b/sahi/utils/cv.py
index eb5f8ffc0..892523dcc 100644
--- a/sahi/utils/cv.py
+++ b/sahi/utils/cv.py
@@ -5,10 +5,12 @@
import os
import random
import time
-from typing import Union, Optional, List
+from typing import List, Optional, Union
+
import cv2
import numpy as np
from PIL import Image
+
from sahi.utils.file import create_dir
@@ -37,9 +39,16 @@ def crop_object_predictions(
category_id = object_prediction.category.id
# crop detections
# deepcopy crops so that original is not altered
- cropped_img = copy.deepcopy(image[int(bbox[1]) : int(bbox[3]), int(bbox[0]) : int(bbox[2]), :,])
+ cropped_img = copy.deepcopy(
+ image[
+ int(bbox[1]) : int(bbox[3]),
+ int(bbox[0]) : int(bbox[2]),
+ :,
+ ]
+ )
save_path = os.path.join(
- output_dir, file_name + "_box" + str(ind) + "_class" + str(category_id) + "." + export_format,
+ output_dir,
+ file_name + "_box" + str(ind) + "_class" + str(category_id) + "." + export_format,
)
cv2.imwrite(save_path, cv2.cvtColor(cropped_img, cv2.COLOR_RGB2BGR))
@@ -178,7 +187,11 @@ def visualize_prediction(
image = cv2.addWeighted(image, 1, rgb_mask, 0.7, 0)
# visualize boxes
cv2.rectangle(
- image, tuple(box[0:2]), tuple(box[2:4]), color=color, thickness=rect_th,
+ image,
+ tuple(box[0:2]),
+ tuple(box[2:4]),
+ color=color,
+ thickness=rect_th,
)
# arrange bounding box text location
if box[1] - 10 > 10:
@@ -187,7 +200,13 @@ def visualize_prediction(
box[1] += 10
# add bounding box text
cv2.putText(
- image, class_, tuple(box[0:2]), cv2.FONT_HERSHEY_SIMPLEX, text_size, color, thickness=text_th,
+ image,
+ class_,
+ tuple(box[0:2]),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ text_size,
+ color,
+ thickness=text_th,
)
if output_dir:
# create output folder if not present
@@ -247,7 +266,11 @@ def visualize_object_predictions(
image = cv2.addWeighted(image, 1, rgb_mask, 0.4, 0)
# visualize boxes
cv2.rectangle(
- image, tuple(bbox[0:2]), tuple(bbox[2:4]), color=color, thickness=rect_th,
+ image,
+ tuple(bbox[0:2]),
+ tuple(bbox[2:4]),
+ color=color,
+ thickness=rect_th,
)
# arrange bounding box text location
if bbox[1] - 5 > 5:
@@ -257,7 +280,13 @@ def visualize_object_predictions(
# add bounding box text
label = "%s %.2f" % (category_name, score)
cv2.putText(
- image, label, tuple(bbox[0:2]), cv2.FONT_HERSHEY_SIMPLEX, text_size, color, thickness=text_th,
+ image,
+ label,
+ tuple(bbox[0:2]),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ text_size,
+ color,
+ thickness=text_th,
)
if output_dir:
# create output folder if not present
diff --git a/sahi/utils/fiftyone.py b/sahi/utils/fiftyone.py
index 782137e6d..253521400 100644
--- a/sahi/utils/fiftyone.py
+++ b/sahi/utils/fiftyone.py
@@ -3,7 +3,7 @@
try:
import fiftyone as fo
from fiftyone.utils.coco import COCODetectionDatasetImporter as BaseCOCODetectionDatasetImporter
- from fiftyone.utils.coco import load_coco_detection_annotations, _parse_label_types, _get_matching_image_ids
+ from fiftyone.utils.coco import _get_matching_image_ids, _parse_label_types, load_coco_detection_annotations
except ImportError:
raise ImportError('Please run "pip install -U fiftyone" to install fiftyone first for fiftyone utilities.')
@@ -26,9 +26,17 @@ def __init__(
seed=None,
max_samples=None,
):
- data_path = self._parse_data_path(dataset_dir=dataset_dir, data_path=data_path, default="data/",)
+ data_path = self._parse_data_path(
+ dataset_dir=dataset_dir,
+ data_path=data_path,
+ default="data/",
+ )
- labels_path = self._parse_labels_path(dataset_dir=dataset_dir, labels_path=labels_path, default="labels.json",)
+ labels_path = self._parse_labels_path(
+ dataset_dir=dataset_dir,
+ labels_path=labels_path,
+ default="labels.json",
+ )
label_types = _parse_label_types(label_types)
@@ -36,7 +44,10 @@ def __init__(
label_types.append("coco_id")
super().__init__(
- dataset_dir=dataset_dir, shuffle=shuffle, seed=seed, max_samples=max_samples,
+ dataset_dir=dataset_dir,
+ shuffle=shuffle,
+ seed=seed,
+ max_samples=max_samples,
)
self.data_path = data_path
@@ -60,9 +71,13 @@ def __init__(
def setup(self):
if self.labels_path is not None and os.path.isfile(self.labels_path):
- (info, classes, supercategory_map, images, annotations,) = load_coco_detection_annotations(
- self.labels_path, extra_attrs=self.extra_attrs
- )
+ (
+ info,
+ classes,
+ supercategory_map,
+ images,
+ annotations,
+ ) = load_coco_detection_annotations(self.labels_path, extra_attrs=self.extra_attrs)
if classes is not None:
info["classes"] = classes
diff --git a/sahi/utils/file.py b/sahi/utils/file.py
index 35d4f32d0..98cbeb924 100644
--- a/sahi/utils/file.py
+++ b/sahi/utils/file.py
@@ -7,10 +7,10 @@
import os
import pickle
import re
+import urllib.request
import zipfile
-from pathlib import Path
from os import path
-import urllib.request
+from pathlib import Path
import numpy as np
@@ -77,7 +77,11 @@ def load_json(load_path):
return data
-def list_files(directory: str, contains: list = [".json"], verbose: int = 1,) -> list:
+def list_files(
+ directory: str,
+ contains: list = [".json"],
+ verbose: int = 1,
+) -> list:
"""
Walk the given directory and return a list of file paths with the desired extension
@@ -234,5 +238,6 @@ def download_from_url(from_url: str, to_path: str):
if not path.exists(to_path):
urllib.request.urlretrieve(
- from_url, to_path,
+ from_url,
+ to_path,
)
diff --git a/sahi/utils/mmdet.py b/sahi/utils/mmdet.py
index d4c2ed4c0..00e68dc17 100644
--- a/sahi/utils/mmdet.py
+++ b/sahi/utils/mmdet.py
@@ -1,10 +1,10 @@
+import shutil
+import sys
import urllib.request
+from importlib import import_module
from os import path
from pathlib import Path
from typing import Optional
-from importlib import import_module
-import shutil
-import sys
def mmdet_version_as_integer():
@@ -45,7 +45,8 @@ def download_mmdet_cascade_mask_rcnn_model(destination_path: Optional[str] = Non
if not path.exists(destination_path):
urllib.request.urlretrieve(
- MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_URL, destination_path,
+ MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_URL,
+ destination_path,
)
@@ -58,7 +59,8 @@ def download_mmdet_retinanet_model(destination_path: Optional[str] = None):
if not path.exists(destination_path):
urllib.request.urlretrieve(
- MmdetTestConstants.MMDET_RETINANET_MODEL_URL, destination_path,
+ MmdetTestConstants.MMDET_RETINANET_MODEL_URL,
+ destination_path,
)
@@ -106,7 +108,8 @@ def download_mmdet_config(
# download main config file
urllib.request.urlretrieve(
- main_config_url, main_config_path,
+ main_config_url,
+ main_config_path,
)
# read main config file
@@ -127,7 +130,8 @@ def download_mmdet_config(
# download secondary config files
urllib.request.urlretrieve(
- config_url, str(config_path),
+ config_url,
+ str(config_path),
)
# set final config dirs
@@ -161,5 +165,7 @@ def download_mmdet_config(
if __name__ == "__main__":
download_mmdet_config(
- model_name="cascade_rcnn", config_file_name="cascade_mask_rcnn_r50_fpn_1x_coco.py", verbose=False,
+ model_name="cascade_rcnn",
+ config_file_name="cascade_mask_rcnn_r50_fpn_1x_coco.py",
+ verbose=False,
)
diff --git a/sahi/utils/mot.py b/sahi/utils/mot.py
index 6ccaa3639..437bb8b89 100644
--- a/sahi/utils/mot.py
+++ b/sahi/utils/mot.py
@@ -1,17 +1,17 @@
-import os
import copy
+import os
from pathlib import Path
-from typing import Optional, List, Dict
+from typing import Dict, List, Optional
import numpy as np
-from sahi.utils.file import increment_path
+from sahi.utils.file import increment_path
try:
import norfair
- from norfair import Tracker, Detection
- from norfair.tracker import TrackedObject, FilterSetup
- from norfair.metrics import PredictionsTextFile, InformationFile
+ from norfair import Detection, Tracker
+ from norfair.metrics import InformationFile, PredictionsTextFile
+ from norfair.tracker import FilterSetup, TrackedObject
except ImportError:
raise ImportError('Please run "pip install -U norfair" to install norfair first for MOT format handling.')
diff --git a/sahi/utils/yolov5.py b/sahi/utils/yolov5.py
index 4835196c4..fffc5757c 100644
--- a/sahi/utils/yolov5.py
+++ b/sahi/utils/yolov5.py
@@ -21,5 +21,6 @@ def download_yolov5s6_model(destination_path: Optional[str] = None):
if not path.exists(destination_path):
urllib.request.urlretrieve(
- Yolov5TestConstants.YOLOV5S6_MODEL_URL, destination_path,
+ Yolov5TestConstants.YOLOV5S6_MODEL_URL,
+ destination_path,
)
diff --git a/scripts/coco2yolov5.py b/scripts/coco2yolov5.py
index 39148b2c4..e82aa24b2 100644
--- a/scripts/coco2yolov5.py
+++ b/scripts/coco2yolov5.py
@@ -7,10 +7,16 @@
parser = argparse.ArgumentParser()
parser.add_argument("--source", type=str, default="", help="directory for coco images")
parser.add_argument(
- "--coco_file", type=str, default=None, help="file path for the coco file to be converted",
+ "--coco_file",
+ type=str,
+ default=None,
+ help="file path for the coco file to be converted",
)
parser.add_argument(
- "--train_split", type=float, default=0.9, help="set the training split ratio",
+ "--train_split",
+ type=float,
+ default=0.9,
+ help="set the training split ratio",
)
parser.add_argument("--project", default="runs/coco2yolov5", help="save results to project/name")
parser.add_argument("--name", default="exp", help="save results to project/name")
@@ -21,8 +27,13 @@
# increment run
save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=False))
# load coco dict
- coco = Coco.from_coco_dict_or_path(coco_dict_or_path=opt.coco_file, image_dir=opt.source,)
+ coco = Coco.from_coco_dict_or_path(
+ coco_dict_or_path=opt.coco_file,
+ image_dir=opt.source,
+ )
# export as yolov5
coco.export_as_yolov5(
- output_dir=str(save_dir), train_split_rate=opt.train_split, numpy_seed=opt.seed,
+ output_dir=str(save_dir),
+ train_split_rate=opt.train_split,
+ numpy_seed=opt.seed,
)
diff --git a/scripts/coco_error_analysis.py b/scripts/coco_error_analysis.py
index f8c7fcce9..9a4753937 100644
--- a/scripts/coco_error_analysis.py
+++ b/scripts/coco_error_analysis.py
@@ -41,7 +41,11 @@ def makeplot(rs, ps, outDir, class_name, iou_type):
for k in range(len(types)):
ax.plot(rs, ps_curve[k + 1], color=[0, 0, 0], linewidth=0.5)
ax.fill_between(
- rs, ps_curve[k], ps_curve[k + 1], color=cs[k], label=str(f"[{aps[k]:.3f}]" + types[k]),
+ rs,
+ ps_curve[k],
+ ps_curve[k + 1],
+ color=cs[k],
+ label=str(f"[{aps[k]:.3f}]" + types[k]),
)
plt.xlabel("recall")
plt.ylabel("precision")
@@ -85,7 +89,12 @@ def makebarplot(rs, ps, outDir, class_name, iou_type):
type_ps = ps[i, ..., 0]
aps = [ps_.mean() for ps_ in type_ps.T]
rects_list.append(
- ax.bar(x - width / 2 + (i + 1) * width / len(types), aps, width / len(types), label=types[i],)
+ ax.bar(
+ x - width / 2 + (i + 1) * width / len(types),
+ aps,
+ width / len(types),
+ label=types[i],
+ )
)
# Add some text for labels, title and custom x-axis tick labels, etc.
@@ -315,7 +324,11 @@ def main():
parser.add_argument("--types", type=str, nargs="+", default=["bbox"], help="result types")
parser.add_argument("--extraplots", action="store_true", help="export extra bar/stat plots")
parser.add_argument(
- "--areas", type=int, nargs="+", default=[1024, 9216, 10000000000], help="area regions",
+ "--areas",
+ type=int,
+ nargs="+",
+ default=[1024, 9216, 10000000000],
+ help="area regions",
)
args = parser.parse_args()
diff --git a/scripts/coco_evaluation.py b/scripts/coco_evaluation.py
index 81f74aeff..c88bb9eee 100644
--- a/scripts/coco_evaluation.py
+++ b/scripts/coco_evaluation.py
@@ -1,12 +1,12 @@
-from argparse import ArgumentParser
-from pathlib import Path
-import json
import itertools
+import json
import warnings
+from argparse import ArgumentParser
from collections import OrderedDict
-from terminaltables import AsciiTable
+from pathlib import Path
import numpy as np
+from terminaltables import AsciiTable
try:
from pycocotools.coco import COCO
diff --git a/scripts/predict.py b/scripts/predict.py
index 7426483ca..956887618 100644
--- a/scripts/predict.py
+++ b/scripts/predict.py
@@ -12,16 +12,28 @@
help="mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'",
)
parser.add_argument(
- "--model_path", type=str, default="", help="path for the model",
+ "--model_path",
+ type=str,
+ default="",
+ help="path for the model",
)
parser.add_argument(
- "--config_path", type=str, default="", help="path for the model config",
+ "--config_path",
+ type=str,
+ default="",
+ help="path for the model config",
)
parser.add_argument(
- "--conf_thresh", type=float, default=0.25, help="all predictions with score < conf_thresh will be discarded",
+ "--conf_thresh",
+ type=float,
+ default=0.25,
+ help="all predictions with score < conf_thresh will be discarded",
)
parser.add_argument(
- "--device", type=str, default=None, help="cpu or cuda",
+ "--device",
+ type=str,
+ default=None,
+ help="cpu or cuda",
)
parser.add_argument(
"--category_mapping",
@@ -58,7 +70,9 @@
parser.add_argument("--match_metric", type=str, default="IOS", help="match metric for postprocess: 'IOU' or 'IOS'")
parser.add_argument("--match_thresh", type=float, default=0.5, help="match threshold for postprocess")
parser.add_argument(
- "--class_agnostic", action="store_true", help="Postprocess will ignore category ids.",
+ "--class_agnostic",
+ action="store_true",
+ help="Postprocess will ignore category ids.",
)
parser.add_argument("--visual_export_format", type=str, default="png")
diff --git a/scripts/predict_fiftyone.py b/scripts/predict_fiftyone.py
index 3fc5e7d65..a85abfe62 100644
--- a/scripts/predict_fiftyone.py
+++ b/scripts/predict_fiftyone.py
@@ -18,16 +18,28 @@
help="mmdet for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'",
)
parser.add_argument(
- "--model_path", type=str, default="", help="path for the model",
+ "--model_path",
+ type=str,
+ default="",
+ help="path for the model",
)
parser.add_argument(
- "--config_path", type=str, default="", help="path for the model config",
+ "--config_path",
+ type=str,
+ default="",
+ help="path for the model config",
)
parser.add_argument(
- "--conf_thresh", type=float, default=0.25, help="all predictions with score < conf_thresh will be discarded",
+ "--conf_thresh",
+ type=float,
+ default=0.25,
+ help="all predictions with score < conf_thresh will be discarded",
)
parser.add_argument(
- "--device", type=str, default=None, help="cpu or cuda",
+ "--device",
+ type=str,
+ default=None,
+ help="cpu or cuda",
)
parser.add_argument(
"--category_mapping",
@@ -53,7 +65,9 @@
parser.add_argument("--match_metric", type=str, default="IOS", help="match metric for postprocess: 'IOU' or 'IOS'")
parser.add_argument("--match_thresh", type=float, default=0.5, help="match threshold for postprocess")
parser.add_argument(
- "--class_agnostic", action="store_true", help="Postprocess will ignore category ids.",
+ "--class_agnostic",
+ action="store_true",
+ help="Postprocess will ignore category ids.",
)
opt = parser.parse_args()
diff --git a/scripts/slice_coco.py b/scripts/slice_coco.py
index 269c0a853..89c004636 100644
--- a/scripts/slice_coco.py
+++ b/scripts/slice_coco.py
@@ -1,15 +1,17 @@
-import os
import argparse
+import os
from sahi.slicing import slice_coco
-from sahi.utils.coco import split_coco_as_train_val, Coco
-from sahi.utils.file import get_base_filename, save_json, Path, increment_path
-
+from sahi.utils.coco import Coco, split_coco_as_train_val
+from sahi.utils.file import Path, get_base_filename, increment_path, save_json
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- "coco_json_path", type=str, default=None, help="path to coco annotation json file",
+ "coco_json_path",
+ type=str,
+ default=None,
+ help="path to coco annotation json file",
)
parser.add_argument("coco_image_dir", type=str, default="", help="folder containing coco images")
parser.add_argument("--slice_size", type=int, nargs="+", default=[512], help="slice size")
@@ -54,5 +56,6 @@
output_coco_annotation_file_path = os.path.join(output_dir, sliced_coco_name + ".json")
save_json(coco_dict, output_coco_annotation_file_path)
print(
- f"Sliced 'slice_size: {slice_size}' coco file is saved to", output_coco_annotation_file_path,
+ f"Sliced 'slice_size: {slice_size}' coco file is saved to",
+ output_coco_annotation_file_path,
)
diff --git a/setup.py b/setup.py
index 8f6aa3e37..d19f4897f 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@ def get_version():
install_requires=get_requirements(),
extras_require={
"tests": ["pytest", "mmdet", "norfair"],
- "dev": ["black==21.5b1", "flake==3.9.2", "isort==5.8.0", "jupyterlab==3.0.14"],
+ "dev": ["black==21.7b0", "flake8==3.9.2", "isort==5.9.2", "jupyterlab==3.0.14"],
},
classifiers=[
"License :: OSI Approved :: MIT License",
diff --git a/tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn.py b/tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn.py
index 033a88843..2f6cba0f6 100644
--- a/tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn.py
+++ b/tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn.py
@@ -17,9 +17,16 @@
type="RPNHead",
in_channels=256,
feat_channels=256,
- anchor_generator=dict(type="AnchorGenerator", scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64],),
+ anchor_generator=dict(
+ type="AnchorGenerator",
+ scales=[8],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64],
+ ),
bbox_coder=dict(
- type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0],
+ type="DeltaXYWHBBoxCoder",
+ target_means=[0.0, 0.0, 0.0, 0.0],
+ target_stds=[1.0, 1.0, 1.0, 1.0],
),
loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type="SmoothL1Loss", beta=1.0 / 9.0, loss_weight=1.0),
@@ -42,7 +49,9 @@
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
- type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2],
+ type="DeltaXYWHBBoxCoder",
+ target_means=[0.0, 0.0, 0.0, 0.0],
+ target_stds=[0.1, 0.1, 0.2, 0.2],
),
reg_class_agnostic=True,
loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0),
@@ -55,7 +64,9 @@
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
- type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.05, 0.05, 0.1, 0.1],
+ type="DeltaXYWHBBoxCoder",
+ target_means=[0.0, 0.0, 0.0, 0.0],
+ target_stds=[0.05, 0.05, 0.1, 0.1],
),
reg_class_agnostic=True,
loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0),
@@ -103,7 +114,13 @@
match_low_quality=True,
ignore_iof_thr=-1,
),
- sampler=dict(type="RandomSampler", num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False,),
+ sampler=dict(
+ type="RandomSampler",
+ num=256,
+ pos_fraction=0.5,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=False,
+ ),
allowed_border=0,
pos_weight=-1,
debug=False,
@@ -127,7 +144,11 @@
ignore_iof_thr=-1,
),
sampler=dict(
- type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True,
+ type="RandomSampler",
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
),
mask_size=28,
pos_weight=-1,
@@ -143,7 +164,11 @@
ignore_iof_thr=-1,
),
sampler=dict(
- type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True,
+ type="RandomSampler",
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
),
mask_size=28,
pos_weight=-1,
@@ -159,7 +184,11 @@
ignore_iof_thr=-1,
),
sampler=dict(
- type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True,
+ type="RandomSampler",
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
),
mask_size=28,
pos_weight=-1,
@@ -176,6 +205,11 @@
nms=dict(type="nms", iou_threshold=0.7),
min_bbox_size=0,
),
- rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5,),
+ rcnn=dict(
+ score_thr=0.05,
+ nms=dict(type="nms", iou_threshold=0.5),
+ max_per_img=100,
+ mask_thr_binary=0.5,
+ ),
),
)
diff --git a/tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn_v280.py b/tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn_v280.py
index f2ac9c894..70f4afd21 100644
--- a/tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn_v280.py
+++ b/tests/data/models/mmdet_cascade_mask_rcnn/cascade_mask_rcnn_r50_fpn_v280.py
@@ -17,9 +17,16 @@
type="RPNHead",
in_channels=256,
feat_channels=256,
- anchor_generator=dict(type="AnchorGenerator", scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64],),
+ anchor_generator=dict(
+ type="AnchorGenerator",
+ scales=[8],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64],
+ ),
bbox_coder=dict(
- type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[1.0, 1.0, 1.0, 1.0],
+ type="DeltaXYWHBBoxCoder",
+ target_means=[0.0, 0.0, 0.0, 0.0],
+ target_stds=[1.0, 1.0, 1.0, 1.0],
),
loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type="SmoothL1Loss", beta=1.0 / 9.0, loss_weight=1.0),
@@ -42,7 +49,9 @@
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
- type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.1, 0.1, 0.2, 0.2],
+ type="DeltaXYWHBBoxCoder",
+ target_means=[0.0, 0.0, 0.0, 0.0],
+ target_stds=[0.1, 0.1, 0.2, 0.2],
),
reg_class_agnostic=True,
loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0),
@@ -55,7 +64,9 @@
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
- type="DeltaXYWHBBoxCoder", target_means=[0.0, 0.0, 0.0, 0.0], target_stds=[0.05, 0.05, 0.1, 0.1],
+ type="DeltaXYWHBBoxCoder",
+ target_means=[0.0, 0.0, 0.0, 0.0],
+ target_stds=[0.05, 0.05, 0.1, 0.1],
),
reg_class_agnostic=True,
loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0),
@@ -104,13 +115,24 @@
match_low_quality=True,
ignore_iof_thr=-1,
),
- sampler=dict(type="RandomSampler", num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=False,),
+ sampler=dict(
+ type="RandomSampler",
+ num=256,
+ pos_fraction=0.5,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=False,
+ ),
allowed_border=0,
pos_weight=-1,
debug=False,
),
rpn_proposal=dict(
- nms_across_levels=False, nms_pre=2000, nms_post=2000, max_num=2000, nms_thr=0.7, min_bbox_size=0,
+ nms_across_levels=False,
+ nms_pre=2000,
+ nms_post=2000,
+ max_num=2000,
+ nms_thr=0.7,
+ min_bbox_size=0,
),
rcnn=[
dict(
@@ -122,7 +144,13 @@
match_low_quality=False,
ignore_iof_thr=-1,
),
- sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True,),
+ sampler=dict(
+ type="RandomSampler",
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ ),
mask_size=28,
pos_weight=-1,
debug=False,
@@ -136,7 +164,13 @@
match_low_quality=False,
ignore_iof_thr=-1,
),
- sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True,),
+ sampler=dict(
+ type="RandomSampler",
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ ),
mask_size=28,
pos_weight=-1,
debug=False,
@@ -150,7 +184,13 @@
match_low_quality=False,
ignore_iof_thr=-1,
),
- sampler=dict(type="RandomSampler", num=512, pos_fraction=0.25, neg_pos_ub=-1, add_gt_as_proposals=True,),
+ sampler=dict(
+ type="RandomSampler",
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ ),
mask_size=28,
pos_weight=-1,
debug=False,
@@ -158,6 +198,18 @@
],
)
test_cfg = dict(
- rpn=dict(nms_across_levels=False, nms_pre=1000, nms_post=1000, max_num=1000, nms_thr=0.7, min_bbox_size=0,),
- rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100, mask_thr_binary=0.5,),
+ rpn=dict(
+ nms_across_levels=False,
+ nms_pre=1000,
+ nms_post=1000,
+ max_num=1000,
+ nms_thr=0.7,
+ min_bbox_size=0,
+ ),
+ rcnn=dict(
+ score_thr=0.05,
+ nms=dict(type="nms", iou_threshold=0.5),
+ max_per_img=100,
+ mask_thr_binary=0.5,
+ ),
)
diff --git a/tests/test_cocoutils.py b/tests/test_cocoutils.py
index 740e6e589..9c70167e7 100644
--- a/tests/test_cocoutils.py
+++ b/tests/test_cocoutils.py
@@ -19,7 +19,11 @@ def test_coco_categories(self):
coco_category1 = CocoCategory(id=category_id, name=category_name, supercategory=supercategory)
coco_category2 = CocoCategory(id=category_id, name=category_name)
coco_category3 = CocoCategory.from_coco_category(
- {"id": category_id, "name": category_name, "supercategory": supercategory,}
+ {
+ "id": category_id,
+ "name": category_name,
+ "supercategory": supercategory,
+ }
)

self.assertEqual(coco_category1.id, category_id)
@@ -45,7 +49,9 @@ def test_coco_annotation(self):
category_id = 3
category_name = "car"
coco_annotation = CocoAnnotation.from_coco_segmentation(
- segmentation=coco_segmentation, category_id=category_id, category_name=category_name,
+ segmentation=coco_segmentation,
+ category_id=category_id,
+ category_name=category_name,
)

self.assertAlmostEqual(coco_annotation.area, 41177, 1)
@@ -57,7 +63,9 @@ def test_coco_annotation(self):
coco_bbox = [1, 1, 100, 100]
category_id = 3
coco_annotation = CocoAnnotation.from_coco_bbox(
- bbox=coco_bbox, category_id=category_id, category_name=category_name,
+ bbox=coco_bbox,
+ category_id=category_id,
+ category_name=category_name,
)

self.assertEqual(coco_annotation.area, 10000)
@@ -105,7 +113,9 @@ def test_coco_image(self):
category_id = 3
category_name = "car"
coco_annotation_1 = CocoAnnotation.from_coco_segmentation(
- segmentation=coco_segmentation, category_id=category_id, category_name=category_name,
+ segmentation=coco_segmentation,
+ category_id=category_id,
+ category_name=category_name,
)

coco_image.add_annotation(coco_annotation_1)
@@ -259,25 +269,32 @@ def test_coco(self):
self.assertEqual(coco1.images[2].annotations[1].category_name, "human")
self.assertEqual(coco2.images[2].annotations[1].category_name, "human")
self.assertEqual(
- coco1.images[1].annotations[1].segmentation, [[501, 451, 622, 451, 622, 543, 501, 543]],
+ coco1.images[1].annotations[1].segmentation,
+ [[501, 451, 622, 451, 622, 543, 501, 543]],
)
self.assertEqual(
- coco2.images[1].annotations[1].segmentation, [[501, 451, 622, 451, 622, 543, 501, 543]],
+ coco2.images[1].annotations[1].segmentation,
+ [[501, 451, 622, 451, 622, 543, 501, 543]],
)
self.assertEqual(
- coco1.category_mapping, category_mapping,
+ coco1.category_mapping,
+ category_mapping,
)
self.assertEqual(
- coco2.category_mapping, category_mapping,
+ coco2.category_mapping,
+ category_mapping,
)
self.assertEqual(
- coco1.stats, coco2.stats,
+ coco1.stats,
+ coco2.stats,
)
self.assertEqual(
- coco1.stats["num_images"], len(coco1.images),
+ coco1.stats["num_images"],
+ len(coco1.images),
)
self.assertEqual(
- coco1.stats["num_annotations"], len(coco1.json["annotations"]),
+ coco1.stats["num_annotations"],
+ len(coco1.json["annotations"]),
)

def test_split_coco_as_train_val(self):
@@ -293,7 +310,8 @@ def test_split_coco_as_train_val(self):
self.assertEqual(result["train_coco"].image_dir, image_dir)
self.assertEqual(result["train_coco"].stats["num_images"], len(result["train_coco"].images))
self.assertEqual(
- result["train_coco"].stats["num_annotations"], len(result["train_coco"].json["annotations"]),
+ result["train_coco"].stats["num_annotations"],
+ len(result["train_coco"].json["annotations"]),
)

self.assertEqual(len(result["val_coco"].json["images"]), 1)
@@ -302,7 +320,8 @@ def test_split_coco_as_train_val(self):
self.assertEqual(result["val_coco"].image_dir, image_dir)
self.assertEqual(result["val_coco"].stats["num_images"], len(result["val_coco"].images))
self.assertEqual(
- result["val_coco"].stats["num_annotations"], len(result["val_coco"].json["annotations"]),
+ result["val_coco"].stats["num_annotations"],
+ len(result["val_coco"].json["annotations"]),
)

def test_coco2yolo(self):
@@ -324,7 +343,8 @@ def test_update_categories(self):
self.assertEqual(len(source_coco_dict["images"]), 1)
self.assertEqual(len(source_coco_dict["categories"]), 1)
self.assertEqual(
- source_coco_dict["categories"], [{"id": 1, "name": "car", "supercategory": "car"}],
+ source_coco_dict["categories"],
+ [{"id": 1, "name": "car", "supercategory": "car"}],
)
self.assertEqual(source_coco_dict["annotations"][1]["category_id"], 1)
@@ -356,7 +376,8 @@ def test_coco_update_categories(self):
self.assertEqual(len(coco.json["images"]), 1)
self.assertEqual(len(coco.json["categories"]), 1)
self.assertEqual(
- coco.json["categories"], [{"id": 1, "name": "car", "supercategory": "car"}],
+ coco.json["categories"],
+ [{"id": 1, "name": "car", "supercategory": "car"}],
)
self.assertEqual(coco.json["annotations"][1]["category_id"], 1)
self.assertEqual(coco.image_dir, image_dir)
@@ -460,19 +481,24 @@ def test_merge_from_list(self):
self.assertEqual(len(merged_coco_dict["annotations"]), 22)
self.assertEqual(len(merged_coco_dict["categories"]), 2)
self.assertEqual(
- merged_coco_dict["annotations"][12]["bbox"], coco_dict3["annotations"][0]["bbox"],
+ merged_coco_dict["annotations"][12]["bbox"],
+ coco_dict3["annotations"][0]["bbox"],
)
self.assertEqual(
- merged_coco_dict["annotations"][12]["id"], 13,
+ merged_coco_dict["annotations"][12]["id"],
+ 13,
)
self.assertEqual(
- merged_coco_dict["annotations"][12]["image_id"], 3,
+ merged_coco_dict["annotations"][12]["image_id"],
+ 3,
)
self.assertEqual(
- merged_coco_dict["annotations"][9]["category_id"], 1,
+ merged_coco_dict["annotations"][9]["category_id"],
+ 1,
)
self.assertEqual(
- merged_coco_dict["annotations"][9]["image_id"], 2,
+ merged_coco_dict["annotations"][9]["image_id"],
+ 2,
)

def test_coco_merge(self):
@@ -494,22 +520,28 @@ def test_coco_merge(self):

self.assertEqual(len(coco1.images), 3)
self.assertEqual(
- coco1.json["annotations"][12]["id"], 13,
+ coco1.json["annotations"][12]["id"],
+ 13,
)
self.assertEqual(
- coco1.json["annotations"][12]["image_id"], 3,
+ coco1.json["annotations"][12]["image_id"],
+ 3,
)
self.assertEqual(
- coco1.json["annotations"][9]["category_id"], 1,
+ coco1.json["annotations"][9]["category_id"],
+ 1,
)
self.assertEqual(
- coco1.json["annotations"][9]["image_id"], 2,
+ coco1.json["annotations"][9]["image_id"],
+ 2,
)
self.assertEqual(
- coco1.image_dir, image_dir,
+ coco1.image_dir,
+ image_dir,
)
self.assertEqual(
- coco2.image_dir, image_dir,
+ coco2.image_dir,
+ image_dir,
)
self.assertEqual(coco2.stats["num_images"], len(coco2.images))
self.assertEqual(coco2.stats["num_annotations"], len(coco2.json["annotations"]))
@@ -523,26 +555,33 @@ def test_get_subsampled_coco(self):
coco = Coco.from_coco_dict_or_path(coco_path, image_dir=image_dir)
subsampled_coco = coco.get_subsampled_coco(subsample_ratio=5)
self.assertEqual(
- len(coco.json["images"]), 50,
+ len(coco.json["images"]),
+ 50,
)
self.assertEqual(
- len(subsampled_coco.json["images"]), 10,
+ len(subsampled_coco.json["images"]),
+ 10,
)
self.assertEqual(
- len(coco.images[5].annotations), len(subsampled_coco.images[1].annotations),
+ len(coco.images[5].annotations),
+ len(subsampled_coco.images[1].annotations),
)
self.assertEqual(
- len(coco.images[5].annotations), len(subsampled_coco.images[1].annotations),
+ len(coco.images[5].annotations),
+ len(subsampled_coco.images[1].annotations),
)
self.assertEqual(
- coco.image_dir, image_dir,
+ coco.image_dir,
+ image_dir,
)
self.assertEqual(
- subsampled_coco.image_dir, image_dir,
+ subsampled_coco.image_dir,
+ image_dir,
)
self.assertEqual(subsampled_coco.stats["num_images"], len(subsampled_coco.images))
self.assertEqual(
- subsampled_coco.stats["num_annotations"], len(subsampled_coco.json["annotations"]),
+ subsampled_coco.stats["num_annotations"],
+ len(subsampled_coco.json["annotations"]),
)

def test_get_area_filtered_coco(self):
@@ -555,23 +594,29 @@ def test_get_area_filtered_coco(self):
coco = Coco.from_coco_dict_or_path(coco_path, image_dir=image_dir)
area_filtered_coco = coco.get_area_filtered_coco(min=min_area, max=max_area)
self.assertEqual(
- len(coco.json["images"]), 50,
+ len(coco.json["images"]),
+ 50,
)
self.assertEqual(
- len(area_filtered_coco.json["images"]), 15,
+ len(area_filtered_coco.json["images"]),
+ 15,
)
self.assertGreater(
- area_filtered_coco.stats["min_annotation_area"], min_area,
+ area_filtered_coco.stats["min_annotation_area"],
+ min_area,
)
self.assertLess(
- area_filtered_coco.stats["max_annotation_area"], max_area,
+ area_filtered_coco.stats["max_annotation_area"],
+ max_area,
)
self.assertEqual(
- area_filtered_coco.image_dir, image_dir,
+ area_filtered_coco.image_dir,
+ image_dir,
)
self.assertEqual(area_filtered_coco.stats["num_images"], len(area_filtered_coco.images))
self.assertEqual(
- area_filtered_coco.stats["num_annotations"], len(area_filtered_coco.json["annotations"]),
+ area_filtered_coco.stats["num_annotations"],
+ len(area_filtered_coco.json["annotations"]),
)

intervals_per_category = {
@@ -581,25 +626,35 @@ def test_get_area_filtered_coco(self):

area_filtered_coco = coco.get_area_filtered_coco(intervals_per_category=intervals_per_category)
self.assertEqual(
- len(coco.json["images"]), 50,
+ len(coco.json["images"]),
+ 50,
)
self.assertEqual(
- len(area_filtered_coco.json["images"]), 22,
+ len(area_filtered_coco.json["images"]),
+ 22,
)
self.assertGreater(
area_filtered_coco.stats["min_annotation_area"],
- min(intervals_per_category["human"]["min"], intervals_per_category["vehicle"]["min"],),
+ min(
+ intervals_per_category["human"]["min"],
+ intervals_per_category["vehicle"]["min"],
+ ),
)
self.assertLess(
area_filtered_coco.stats["max_annotation_area"],
- max(intervals_per_category["human"]["max"], intervals_per_category["vehicle"]["max"],),
+ max(
+ intervals_per_category["human"]["max"],
+ intervals_per_category["vehicle"]["max"],
+ ),
)
self.assertEqual(
- area_filtered_coco.image_dir, image_dir,
+ area_filtered_coco.image_dir,
+ image_dir,
)
self.assertEqual(area_filtered_coco.stats["num_images"], len(area_filtered_coco.images))
self.assertEqual(
- area_filtered_coco.stats["num_annotations"], len(area_filtered_coco.json["annotations"]),
+ area_filtered_coco.stats["num_annotations"],
+ len(area_filtered_coco.json["annotations"]),
)

intervals_per_category = {
@@ -609,25 +664,35 @@ def test_get_area_filtered_coco(self):

area_filtered_coco = coco.get_area_filtered_coco(intervals_per_category=intervals_per_category)
self.assertEqual(
- len(coco.json["images"]), 50,
+ len(coco.json["images"]),
+ 50,
)
self.assertEqual(
- len(area_filtered_coco.json["images"]), 22,
+ len(area_filtered_coco.json["images"]),
+ 22,
)
self.assertGreater(
area_filtered_coco.stats["min_annotation_area"],
- min(intervals_per_category["human"]["min"], intervals_per_category["vehicle"]["min"],),
+ min(
+ intervals_per_category["human"]["min"],
+ intervals_per_category["vehicle"]["min"],
+ ),
)
self.assertLess(
area_filtered_coco.stats["max_annotation_area"],
- max(intervals_per_category["human"]["max"], intervals_per_category["vehicle"]["max"],),
+ max(
+ intervals_per_category["human"]["max"],
+ intervals_per_category["vehicle"]["max"],
+ ),
)
self.assertEqual(
- area_filtered_coco.image_dir, image_dir,
+ area_filtered_coco.image_dir,
+ image_dir,
)
self.assertEqual(area_filtered_coco.stats["num_images"], len(area_filtered_coco.images))
self.assertEqual(
- area_filtered_coco.stats["num_annotations"], len(area_filtered_coco.json["annotations"]),
+ area_filtered_coco.stats["num_annotations"],
+ len(area_filtered_coco.json["annotations"]),
)

def test_cocovid(self):
diff --git a/tests/test_filter.py b/tests/test_filter.py
index 895301a8f..0dd498c7a 100644
--- a/tests/test_filter.py
+++ b/tests/test_filter.py
@@ -8,6 +8,7 @@

import cv2
import pytest
+
from sahi.annotation import BoundingBox
from sahi.postprocess.legacy.match import PredictionList, PredictionMatcher
from sahi.postprocess.legacy.merge import PredictionMerger, ScoreMergingPolicy
@@ -65,7 +66,9 @@ def _perturb(box: BoundingBox):
return BoundingBox(box=[minx, miny, maxx, maxy], shift_amount=box.shift_amount)


-def perturb_boxes(preds: List[ObjectPrediction],) -> List[ObjectPrediction]:
+def perturb_boxes(
+ preds: List[ObjectPrediction],
+) -> List[ObjectPrediction]:
preds = deepcopy(preds)
for i in range(len(preds)):
if i % 2 == 0:
diff --git a/tests/test_mmdetectionmodel.py b/tests/test_mmdetectionmodel.py
index d3e5287f9..bbc2041be 100644
--- a/tests/test_mmdetectionmodel.py
+++ b/tests/test_mmdetectionmodel.py
@@ -4,13 +4,9 @@
import unittest

import numpy as np

-from sahi.utils.cv import read_image
-from sahi.utils.mmdet import (
- MmdetTestConstants,
- download_mmdet_cascade_mask_rcnn_model,
- download_mmdet_retinanet_model,
-)
+from sahi.utils.cv import read_image
+from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_cascade_mask_rcnn_model, download_mmdet_retinanet_model


class TestMmdetDetectionModel(unittest.TestCase):
@@ -135,17 +131,20 @@ def test_convert_original_predictions_with_mask_output(self):
self.assertEqual(object_prediction_list[0].category.id, 0)
self.assertEqual(object_prediction_list[0].category.name, "person")
self.assertEqual(
- object_prediction_list[0].bbox.to_coco_bbox(), [337, 124, 8, 14],
+ object_prediction_list[0].bbox.to_coco_bbox(),
+ [337, 124, 8, 14],
)
self.assertEqual(object_prediction_list[1].category.id, 2)
self.assertEqual(object_prediction_list[1].category.name, "car")
self.assertEqual(
- object_prediction_list[1].bbox.to_coco_bbox(), [657, 204, 13, 10],
+ object_prediction_list[1].bbox.to_coco_bbox(),
+ [657, 204, 13, 10],
)
self.assertEqual(object_prediction_list[5].category.id, 2)
self.assertEqual(object_prediction_list[5].category.name, "car")
self.assertEqual(
- object_prediction_list[2].bbox.to_coco_bbox(), [760, 232, 20, 15],
+ object_prediction_list[2].bbox.to_coco_bbox(),
+ [760, 232, 20, 15],
)

def test_convert_original_predictions_without_mask_output(self):
@@ -179,15 +178,19 @@ def test_convert_original_predictions_without_mask_output(self):
self.assertEqual(object_prediction_list[0].category.id, 2)
self.assertEqual(object_prediction_list[0].category.name, "car")
self.assertEqual(
- object_prediction_list[0].bbox.to_coco_bbox(), [448, 309, 47, 32],
+ object_prediction_list[0].bbox.to_coco_bbox(),
+ [448, 309, 47, 32],
)
self.assertEqual(object_prediction_list[5].category.id, 2)
self.assertEqual(object_prediction_list[5].category.name, "car")
self.assertEqual(
- object_prediction_list[5].bbox.to_coco_bbox(), [523, 225, 22, 17],
+ object_prediction_list[5].bbox.to_coco_bbox(),
+ [523, 225, 22, 17],
)

- def test_create_original_predictions_from_object_prediction_list_with_mask_output(self,):
+ def test_create_original_predictions_from_object_prediction_list_with_mask_output(
+ self,
+ ):
from sahi.model import MmdetDetectionModel

# init model
@@ -230,7 +233,9 @@ def test_create_original_predictions_from_object_prediction_list_with_mask_outpu
self.assertEqual(len(original_predictions_1[0][1]), len(original_predictions_1[0][1])) # 0
self.assertEqual(original_predictions_1[0][1].shape, original_predictions_1[0][1].shape)  # (0, 5)

- def test_create_original_predictions_from_object_prediction_list_without_mask_output(self,):
+ def test_create_original_predictions_from_object_prediction_list_without_mask_output(
+ self,
+ ):
from sahi.model import MmdetDetectionModel

# init model
diff --git a/tests/test_predict.py b/tests/test_predict.py
index bc675fff0..bc13c70cc 100644
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -6,6 +6,7 @@
import unittest

import numpy as np
+
from sahi.utils.cv import read_image
@@ -24,10 +25,7 @@ def test_object_prediction(self):
def test_get_prediction_mmdet(self):
from sahi.model import MmdetDetectionModel
from sahi.predict import get_prediction
- from sahi.utils.mmdet import (
- MmdetTestConstants,
- download_mmdet_cascade_mask_rcnn_model,
- )
+ from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_cascade_mask_rcnn_model
# init model
download_mmdet_cascade_mask_rcnn_model()
@@ -47,7 +45,10 @@ def test_get_prediction_mmdet(self):

# get full sized prediction
prediction_result = get_prediction(
- image=image, detection_model=mmdet_detection_model, shift_amount=[0, 0], full_shape=None,
+ image=image,
+ detection_model=mmdet_detection_model,
+ shift_amount=[0, 0],
+ full_shape=None,
)
object_prediction_list = prediction_result.object_prediction_list
@@ -72,10 +73,7 @@ def test_get_prediction_mmdet(self):
def test_get_prediction_yolov5(self):
from sahi.model import Yolov5DetectionModel
from sahi.predict import get_prediction
- from sahi.utils.yolov5 import (
- Yolov5TestConstants,
- download_yolov5s6_model,
- )
+ from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
# init model
download_yolov5s6_model()
@@ -120,10 +118,7 @@ def test_get_prediction_yolov5(self):
def test_get_sliced_prediction_mmdet(self):
from sahi.model import MmdetDetectionModel
from sahi.predict import get_sliced_prediction
- from sahi.utils.mmdet import (
- MmdetTestConstants,
- download_mmdet_cascade_mask_rcnn_model,
- )
+ from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_cascade_mask_rcnn_model
# init model
download_mmdet_cascade_mask_rcnn_model()
@@ -187,11 +182,7 @@ def test_get_sliced_prediction_mmdet(self):
def test_get_sliced_prediction_yolov5(self):
from sahi.model import Yolov5DetectionModel
from sahi.predict import get_sliced_prediction
-
- from sahi.utils.yolov5 import (
- Yolov5TestConstants,
- download_yolov5s6_model,
- )
+ from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
# init model
download_yolov5s6_model()
@@ -253,14 +244,8 @@ def test_get_sliced_prediction_yolov5(self):
def test_coco_json_prediction(self):
from sahi.predict import predict
- from sahi.utils.yolov5 import (
- Yolov5TestConstants,
- download_yolov5s6_model,
- )
- from sahi.utils.mmdet import (
- MmdetTestConstants,
- download_mmdet_cascade_mask_rcnn_model,
- )
+ from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_cascade_mask_rcnn_model
+ from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
# init model
download_mmdet_cascade_mask_rcnn_model()
diff --git a/tests/test_shapelyutils.py b/tests/test_shapelyutils.py
index f05654f8c..2bb095e11 100644
--- a/tests/test_shapelyutils.py
+++ b/tests/test_shapelyutils.py
@@ -3,12 +3,7 @@

import unittest

-from sahi.utils.shapely import (
- MultiPolygon,
- ShapelyAnnotation,
- get_shapely_box,
- get_shapely_multipolygon,
-)
+from sahi.utils.shapely import MultiPolygon, ShapelyAnnotation, get_shapely_box, get_shapely_multipolygon


class TestShapelyUtils(unittest.TestCase):
@@ -25,7 +20,8 @@ def test_get_shapely_multipolygon(self):

shapely_multipolygon = get_shapely_multipolygon(coco_segmentation)
self.assertListEqual(
- shapely_multipolygon[0].exterior.coords.xy[0].tolist(), [1.0, 325, 250, 5, 1],
+ shapely_multipolygon[0].exterior.coords.xy[0].tolist(),
+ [1.0, 325, 250, 5, 1],
)
self.assertEqual(shapely_multipolygon.area, 41177.5)
self.assertTupleEqual(shapely_multipolygon.bounds, (1, 1, 325, 200))
@@ -39,27 +35,41 @@ def test_shapely_annotation(self):
# test conversion methods
coco_segmentation = shapely_annotation.to_coco_segmentation()
self.assertEqual(
- coco_segmentation, [[1, 1, 325, 125, 250, 200, 5, 200]],
+ coco_segmentation,
+ [[1, 1, 325, 125, 250, 200, 5, 200]],
)
opencv_contours = shapely_annotation.to_opencv_contours()
self.assertEqual(
- opencv_contours, [[[[1, 1]], [[325, 125]], [[250, 200]], [[5, 200]], [[1, 1]],]],
+ opencv_contours,
+ [
+ [
+ [[1, 1]],
+ [[325, 125]],
+ [[250, 200]],
+ [[5, 200]],
+ [[1, 1]],
+ ]
+ ],
)
coco_bbox = shapely_annotation.to_coco_bbox()
self.assertEqual(
- coco_bbox, [1, 1, 324, 199],
+ coco_bbox,
+ [1, 1, 324, 199],
)
voc_bbox = shapely_annotation.to_voc_bbox()
self.assertEqual(
- voc_bbox, [1, 1, 325, 200],
+ voc_bbox,
+ [1, 1, 325, 200],
)

# test properties
self.assertEqual(
- shapely_annotation.area, int(shapely_multipolygon.area),
+ shapely_annotation.area,
+ int(shapely_multipolygon.area),
)
self.assertEqual(
- shapely_annotation.multipolygon, shapely_multipolygon,
+ shapely_annotation.multipolygon,
+ shapely_multipolygon,
)

# init shapely_annotation from coco bbox
@@ -70,27 +80,41 @@ def test_shapely_annotation(self):
# test conversion methods
coco_segmentation = shapely_annotation.to_coco_segmentation()
self.assertEqual(
- coco_segmentation, [[101, 1, 101, 101, 1, 101, 1, 1]],
+ coco_segmentation,
+ [[101, 1, 101, 101, 1, 101, 1, 1]],
)
opencv_contours = shapely_annotation.to_opencv_contours()
self.assertEqual(
- opencv_contours, [[[[101, 1]], [[101, 101]], [[1, 101]], [[1, 1]], [[101, 1]],]],
+ opencv_contours,
+ [
+ [
+ [[101, 1]],
+ [[101, 101]],
+ [[1, 101]],
+ [[1, 1]],
+ [[101, 1]],
+ ]
+ ],
)
coco_bbox = shapely_annotation.to_coco_bbox()
self.assertEqual(
- coco_bbox, [1, 1, 100, 100],
+ coco_bbox,
+ [1, 1, 100, 100],
)
voc_bbox = shapely_annotation.to_voc_bbox()
self.assertEqual(
- voc_bbox, [1, 1, 101, 101],
+ voc_bbox,
+ [1, 1, 101, 101],
)

# test properties
self.assertEqual(
- shapely_annotation.area, MultiPolygon([shapely_polygon]).area,
+ shapely_annotation.area,
+ MultiPolygon([shapely_polygon]).area,
)
self.assertEqual(
- shapely_annotation.multipolygon, MultiPolygon([shapely_polygon]),
+ shapely_annotation.multipolygon,
+ MultiPolygon([shapely_polygon]),
)

def test_get_intersection(self):
@@ -109,7 +133,21 @@ def test_get_intersection(self):
self.assertEqual(int(test_list[i][j]), int(true_list[i][j]))
self.assertEqual(
- intersection_shapely_annotation.to_coco_segmentation(), [[256, 97, 0, 0, 4, 199, 249, 199, 256, 192,]],
+ intersection_shapely_annotation.to_coco_segmentation(),
+ [
+ [
+ 256,
+ 97,
+ 0,
+ 0,
+ 4,
+ 199,
+ 249,
+ 199,
+ 256,
+ 192,
+ ]
+ ],
)
self.assertEqual(intersection_shapely_annotation.to_coco_bbox(), [0, 0, 256, 199])
diff --git a/tests/test_slicing.py b/tests/test_slicing.py
index 8a7e6ccd0..02f9b840f 100644
--- a/tests/test_slicing.py
+++ b/tests/test_slicing.py
@@ -5,6 +5,7 @@

import numpy as np
from PIL import Image
+
from sahi.slicing import slice_coco, slice_image
from sahi.utils.coco import Coco
from sahi.utils.cv import read_image
@@ -38,7 +39,8 @@ def test_slice_image(self):
self.assertEqual(slice_image_result.coco_images[0].annotations, [])
self.assertEqual(slice_image_result.coco_images[15].annotations[1].area, 7296)
self.assertEqual(
- slice_image_result.coco_images[15].annotations[1].bbox, [17, 186, 48, 152],
+ slice_image_result.coco_images[15].annotations[1].bbox,
+ [17, 186, 48, 152],
)

image_cv = read_image(image_path)
@@ -61,7 +63,8 @@ def test_slice_image(self):
self.assertEqual(slice_image_result.coco_images[0].annotations, [])
self.assertEqual(slice_image_result.coco_images[15].annotations[1].area, 7296)
self.assertEqual(
- slice_image_result.coco_images[15].annotations[1].bbox, [17, 186, 48, 152],
+ slice_image_result.coco_images[15].annotations[1].bbox,
+ [17, 186, 48, 152],
)

image_pil = Image.open(image_path)
@@ -84,7 +87,8 @@ def test_slice_image(self):
self.assertEqual(slice_image_result.coco_images[0].annotations, [])
self.assertEqual(slice_image_result.coco_images[15].annotations[1].area, 7296)
self.assertEqual(
- slice_image_result.coco_images[15].annotations[1].bbox, [17, 186, 48, 152],
+ slice_image_result.coco_images[15].annotations[1].bbox,
+ [17, 186, 48, 152],
)

def test_slice_coco(self):
@@ -119,7 +123,8 @@ def test_slice_coco(self):
self.assertEqual(coco_dict["annotations"][2]["category_id"], 1)
self.assertEqual(coco_dict["annotations"][2]["area"], 12483)
self.assertEqual(
- coco_dict["annotations"][2]["bbox"], [340, 204, 73, 171],
+ coco_dict["annotations"][2]["bbox"],
+ [340, 204, 73, 171],
)

shutil.rmtree(output_dir)
@@ -153,7 +158,8 @@ def test_slice_coco(self):
self.assertEqual(coco_dict["annotations"][2]["category_id"], 1)
self.assertEqual(coco_dict["annotations"][2]["area"], 12483)
self.assertEqual(
- coco_dict["annotations"][2]["bbox"], [340, 204, 73, 171],
+ coco_dict["annotations"][2]["bbox"],
+ [340, 204, 73, 171],
)

shutil.rmtree(output_dir)
diff --git a/tests/test_yolov5model.py b/tests/test_yolov5model.py
index aa111d5c7..33020c8dc 100644
--- a/tests/test_yolov5model.py
+++ b/tests/test_yolov5model.py
@@ -4,12 +4,9 @@
import unittest

import numpy as np

-from sahi.utils.cv import read_image
-from sahi.utils.yolov5 import (
- Yolov5TestConstants,
- download_yolov5s6_model,
-)
+from sahi.utils.cv import read_image
+from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model


class TestYolov5DetectionModel(unittest.TestCase):
@@ -103,10 +100,13 @@ def test_convert_original_predictions(self):
self.assertEqual(object_prediction_list[5].category.id, 2)
self.assertEqual(object_prediction_list[5].category.name, "car")
self.assertEqual(
- object_prediction_list[5].bbox.to_coco_bbox(), [617, 195, 24, 23],
+ object_prediction_list[5].bbox.to_coco_bbox(),
+ [617, 195, 24, 23],
)

- def test_create_original_predictions_from_object_prediction_list(self,):
+ def test_create_original_predictions_from_object_prediction_list(
+ self,
+ ):
pass
# TODO: implement object_prediction_list to yolov5 format conversion