Skip to content

Commit

Permalink
Formatting codebase with black (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
sinanonur authored Jul 28, 2021
1 parent 4df3ec6 commit b523d0a
Show file tree
Hide file tree
Showing 41 changed files with 320 additions and 996 deletions.
72 changes: 13 additions & 59 deletions sahi/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,7 @@ def __repr__(self):
class Mask:
@classmethod
def from_float_mask(
cls,
mask,
full_shape=None,
mask_threshold: float = 0.5,
shift_amount: list = [0, 0],
cls, mask, full_shape=None, mask_threshold: float = 0.5, shift_amount: list = [0, 0],
):
"""
Args:
Expand All @@ -130,18 +126,11 @@ def from_float_mask(
Size of the full image after shifting, should be in the form of [height, width]
"""
bool_mask = mask > mask_threshold
return cls(
bool_mask=bool_mask,
shift_amount=shift_amount,
full_shape=full_shape,
)
return cls(bool_mask=bool_mask, shift_amount=shift_amount, full_shape=full_shape,)

@classmethod
def from_coco_segmentation(
cls,
segmentation,
full_shape=None,
shift_amount: list = [0, 0],
cls, segmentation, full_shape=None, shift_amount: list = [0, 0],
):
"""
Init Mask from coco segmentation representation.
Expand All @@ -163,17 +152,10 @@ def from_coco_segmentation(
assert full_shape is not None, "full_shape must be provided"

bool_mask = get_bool_mask_from_coco_segmentation(segmentation, height=full_shape[0], width=full_shape[1])
return cls(
bool_mask=bool_mask,
shift_amount=shift_amount,
full_shape=full_shape,
)
return cls(bool_mask=bool_mask, shift_amount=shift_amount, full_shape=full_shape,)

def __init__(
self,
bool_mask=None,
full_shape=None,
shift_amount: list = [0, 0],
self, bool_mask=None, full_shape=None, shift_amount: list = [0, 0],
):
"""
Args:
Expand Down Expand Up @@ -234,14 +216,7 @@ def get_shifted_mask(self):
# Confirm full_shape is specified
assert (self.full_shape_height is not None) and (self.full_shape_width is not None), "full_shape is None"
# init full mask
mask_fullsized = np.full(
(
self.full_shape_height,
self.full_shape_width,
),
0,
dtype="float32",
)
mask_fullsized = np.full((self.full_shape_height, self.full_shape_width,), 0, dtype="float32",)

# arrange starting ending indexes
starting_pixel = [self.shift_x, self.shift_y]
Expand All @@ -255,11 +230,7 @@ def get_shifted_mask(self):
: ending_pixel[1] - starting_pixel[1], : ending_pixel[0] - starting_pixel[0]
]

return Mask(
mask_fullsized,
shift_amount=[0, 0],
full_shape=self.full_shape,
)
return Mask(mask_fullsized, shift_amount=[0, 0], full_shape=self.full_shape,)

def to_coco_segmentation(self):
"""
Expand Down Expand Up @@ -470,10 +441,7 @@ def from_shapely_annotation(

@classmethod
def from_imantics_annotation(
cls,
annotation,
shift_amount: Optional[List[int]] = [0, 0],
full_shape: Optional[List[int]] = None,
cls, annotation, shift_amount: Optional[List[int]] = [0, 0], full_shape: Optional[List[int]] = None,
):
"""
Creates ObjectAnnotation from imantics.annotation.Annotation
Expand Down Expand Up @@ -527,18 +495,11 @@ def __init__(
self.mask = None
self.bbox = BoundingBox(bbox, shift_amount)
else:
self.mask = Mask(
bool_mask=bool_mask,
shift_amount=shift_amount,
full_shape=full_shape,
)
self.mask = Mask(bool_mask=bool_mask, shift_amount=shift_amount, full_shape=full_shape,)
bbox = get_bbox_from_bool_mask(bool_mask)
self.bbox = BoundingBox(bbox, shift_amount)
category_name = category_name if category_name else str(category_id)
self.category = Category(
id=category_id,
name=category_name,
)
self.category = Category(id=category_id, name=category_name,)

self.merged = None

Expand All @@ -554,9 +515,7 @@ def to_coco_annotation(self):
)
else:
coco_annotation = CocoAnnotation.from_coco_bbox(
bbox=self.bbox.to_coco_bbox(),
category_id=self.category.id,
category_name=self.category.name,
bbox=self.bbox.to_coco_bbox(), category_id=self.category.id, category_name=self.category.name,
)
return coco_annotation

Expand All @@ -573,10 +532,7 @@ def to_coco_prediction(self):
)
else:
coco_prediction = CocoPrediction.from_coco_bbox(
bbox=self.bbox.to_coco_bbox(),
category_id=self.category.id,
category_name=self.category.name,
score=1,
bbox=self.bbox.to_coco_bbox(), category_id=self.category.id, category_name=self.category.name, score=1,
)
return coco_prediction

Expand All @@ -589,9 +545,7 @@ def to_shapely_annotation(self):
segmentation=self.mask.to_coco_segmentation(),
)
else:
shapely_annotation = ShapelyAnnotation.from_coco_bbox(
bbox=self.bbox.to_coco_bbox(),
)
shapely_annotation = ShapelyAnnotation.from_coco_bbox(bbox=self.bbox.to_coco_bbox(),)
return shapely_annotation

def to_imantics_annotation(self):
Expand Down
35 changes: 9 additions & 26 deletions sahi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ def perform_inference(self, image: np.ndarray, image_size: int = None):
NotImplementedError()

def _create_object_prediction_list_from_original_predictions(
self,
shift_amount: Optional[List[int]] = [0, 0],
full_shape: Optional[List[int]] = None,
self, shift_amount: Optional[List[int]] = [0, 0], full_shape: Optional[List[int]] = None,
):
"""
This function should be implemented in a way that self._original_predictions should
Expand Down Expand Up @@ -119,9 +117,7 @@ def _apply_category_remapping(self):
object_prediction.category.id = new_category_id_int

def convert_original_predictions(
self,
shift_amount: Optional[List[int]] = [0, 0],
full_shape: Optional[List[int]] = None,
self, shift_amount: Optional[List[int]] = [0, 0], full_shape: Optional[List[int]] = None,
):
"""
Converts original predictions of the detection model to a list of
Expand All @@ -134,8 +130,7 @@ def convert_original_predictions(
Size of the full image after shifting, should be in the form of [height, width]
"""
self._create_object_prediction_list_from_original_predictions(
shift_amount=shift_amount,
full_shape=full_shape,
shift_amount=shift_amount, full_shape=full_shape,
)
if self.category_remapping:
self._apply_category_remapping()
Expand All @@ -148,9 +143,7 @@ def object_prediction_list(self):
def original_predictions(self):
return self._original_predictions

def _create_predictions_from_object_prediction_list(
object_prediction_list: List[ObjectPrediction],
):
def _create_predictions_from_object_prediction_list(object_prediction_list: List[ObjectPrediction],):
"""
This function should be implemented in a way that it converts a list of
prediction.ObjectPrediction instance to detection model's original prediction format.
Expand Down Expand Up @@ -179,11 +172,7 @@ def load_model(self):
from mmdet.apis import init_detector

# set model
model = init_detector(
config=self.config_path,
checkpoint=self.model_path,
device=self.device,
)
model = init_detector(config=self.config_path, checkpoint=self.model_path, device=self.device,)
self.model = model

# set category_mapping
Expand Down Expand Up @@ -250,9 +239,7 @@ def category_names(self):
return self.model.CLASSES

def _create_object_prediction_list_from_original_predictions(
self,
shift_amount: Optional[List[int]] = [0, 0],
full_shape: Optional[List[int]] = None,
self, shift_amount: Optional[List[int]] = [0, 0], full_shape: Optional[List[int]] = None,
):
"""
self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
Expand Down Expand Up @@ -307,8 +294,7 @@ def _create_object_prediction_list_from_original_predictions(
self._object_prediction_list = object_prediction_list

def _create_original_predictions_from_object_prediction_list(
self,
object_prediction_list: List[ObjectPrediction],
self, object_prediction_list: List[ObjectPrediction],
):
"""
Converts a list of prediction.ObjectPrediction instance to detection model's original prediction format.
Expand Down Expand Up @@ -426,9 +412,7 @@ def category_names(self):
return self.model.names

def _create_object_prediction_list_from_original_predictions(
self,
shift_amount: Optional[List[int]] = [0, 0],
full_shape: Optional[List[int]] = None,
self, shift_amount: Optional[List[int]] = [0, 0], full_shape: Optional[List[int]] = None,
):
"""
self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
Expand Down Expand Up @@ -472,8 +456,7 @@ def _create_object_prediction_list_from_original_predictions(
self._object_prediction_list = object_prediction_list

def _create_original_predictions_from_object_prediction_list(
self,
object_prediction_list: List[ObjectPrediction],
self, object_prediction_list: List[ObjectPrediction],
):
"""
Converts a list of prediction.ObjectPrediction instance to detection model's original
Expand Down
28 changes: 6 additions & 22 deletions sahi/postprocess/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@ class PostprocessPredictions:
"""Combines predictions using NMS elimination utilizing provided match metric ('IOU' or 'IOS')"""

def __init__(
self,
match_threshold: float = 0.5,
match_metric: str = "IOU",
class_agnostic: bool = True,
self, match_threshold: float = 0.5, match_metric: str = "IOU", class_agnostic: bool = True,
):
self.match_threshold = match_threshold
self.class_agnostic = class_agnostic
Expand Down Expand Up @@ -100,8 +97,7 @@ def __call__(self):

class NMSPostprocess(PostprocessPredictions):
def __call__(
self,
object_predictions: List[ObjectPrediction],
self, object_predictions: List[ObjectPrediction],
):
source_object_predictions: List[ObjectPrediction] = copy.deepcopy(object_predictions)
selected_object_predictions: List[ObjectPrediction] = []
Expand All @@ -124,8 +120,7 @@ def __call__(

class UnionMergePostprocess(PostprocessPredictions):
def __call__(
self,
object_predictions: List[ObjectPrediction],
self, object_predictions: List[ObjectPrediction],
):
source_object_predictions: List[ObjectPrediction] = copy.deepcopy(object_predictions)
selected_object_predictions: List[ObjectPrediction] = []
Expand All @@ -150,11 +145,7 @@ def __call__(
selected_object_predictions.append(selected_object_prediction)
return selected_object_predictions

def _merge_object_prediction_pair(
self,
pred1: ObjectPrediction,
pred2: ObjectPrediction,
) -> ObjectPrediction:
def _merge_object_prediction_pair(self, pred1: ObjectPrediction, pred2: ObjectPrediction,) -> ObjectPrediction:
shift_amount = pred1.bbox.shift_amount
merged_bbox: BoundingBox = self._get_merged_bbox(pred1, pred2)
merged_score: float = self._get_merged_score(pred1, pred2)
Expand Down Expand Up @@ -191,10 +182,7 @@ def _get_merged_bbox(pred1: ObjectPrediction, pred2: ObjectPrediction) -> Boundi
return bbox

@staticmethod
def _get_merged_score(
pred1: ObjectPrediction,
pred2: ObjectPrediction,
) -> float:
def _get_merged_score(pred1: ObjectPrediction, pred2: ObjectPrediction,) -> float:
scores: List[float] = [pred.score.value for pred in (pred1, pred2)]
return max(scores)

Expand All @@ -203,8 +191,4 @@ def _get_merged_mask(pred1: ObjectPrediction, pred2: ObjectPrediction) -> Mask:
mask1 = pred1.mask
mask2 = pred2.mask
union_mask = np.logical_or(mask1.bool_mask, mask2.bool_mask)
return Mask(
bool_mask=union_mask,
full_shape=mask1.full_shape,
shift_amount=mask1.shift_amount,
)
return Mask(bool_mask=union_mask, full_shape=mask1.full_shape, shift_amount=mask1.shift_amount,)
4 changes: 1 addition & 3 deletions sahi/postprocess/legacy/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ class PredictionMatcher:
"""

def __init__(
self,
threshold: float = 0.5,
scorer: Callable[[BoxArray, BoxArray], float] = box_ios,
self, threshold: float = 0.5, scorer: Callable[[BoxArray, BoxArray], float] = box_ios,
):
self._threshold = threshold
self._scorer = scorer
Expand Down
21 changes: 4 additions & 17 deletions sahi/postprocess/legacy/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def merge_batch(
"""
if merge_type not in ["merge", "ensemble"]:
raise ValueError(
'Unknown merge type. Supported types are ["merge", "ensemble"], got type: ',
merge_type,
'Unknown merge type. Supported types are ["merge", "ensemble"], got type: ', merge_type,
)

unions = matcher.find_matched_predictions(predictions, ignore_class_label)
Expand Down Expand Up @@ -116,11 +115,7 @@ def _store_merging_info(count, prediction, merge_type):
else:
prediction.merged = False

def _merge_pair(
self,
pred1: ObjectPrediction,
pred2: ObjectPrediction,
) -> ObjectPrediction:
def _merge_pair(self, pred1: ObjectPrediction, pred2: ObjectPrediction,) -> ObjectPrediction:
box1 = extract_box(pred1)
box2 = extract_box(pred2)
merged_box = list(self._merge_box(box1, box2))
Expand Down Expand Up @@ -159,11 +154,7 @@ def _assert_equal_labels(pred1: ObjectPrediction, pred2: ObjectPrediction):
def _merge_box(self, box1: BoxArray, box2: BoxArray) -> BoxArray:
return self._box_merger(box1, box2)

def _merge_score(
self,
pred1: ObjectPrediction,
pred2: ObjectPrediction,
) -> float:
def _merge_score(self, pred1: ObjectPrediction, pred2: ObjectPrediction,) -> float:
scores = [pred.score.value for pred in (pred1, pred2)]
policy = self._score_merging_method
if policy == ScoreMergingPolicy.SMALLER_SCORE:
Expand All @@ -185,11 +176,7 @@ def _merge_mask(pred1: ObjectPrediction, pred2: ObjectPrediction) -> Mask:
mask1 = pred1.mask
mask2 = pred2.mask
union_mask = np.logical_or(mask1.bool_mask, mask2.bool_mask)
return Mask(
bool_mask=union_mask,
full_shape=mask1.full_shape,
shift_amount=mask1.shift_amount,
)
return Mask(bool_mask=union_mask, full_shape=mask1.full_shape, shift_amount=mask1.shift_amount,)

def _validate_box_merger(self, box_merger: Callable):
if box_merger.__name__ not in self.BOX_MERGERS:
Expand Down
Loading

0 comments on commit b523d0a

Please sign in to comment.