Skip to content

Commit

Permalink
refactor predict api (#170)
Browse files Browse the repository at this point in the history
* refactor predict api

* update notebooks
  • Loading branch information
fcakyon authored Jul 13, 2021
1 parent 9501cd8 commit 14b91ab
Show file tree
Hide file tree
Showing 9 changed files with 372 additions and 289 deletions.
276 changes: 162 additions & 114 deletions demo/inference_for_mmdetection.ipynb

Large diffs are not rendered by default.

158 changes: 98 additions & 60 deletions demo/inference_for_yolov5.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions sahi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(
config_path: Optional[str] = None,
device: Optional[str] = None,
mask_threshold: float = 0.5,
prediction_score_threshold: float = 0.3,
confidence_threshold: float = 0.3,
category_mapping: Optional[Dict] = None,
category_remapping: Optional[Dict] = None,
load_at_init: bool = True,
Expand All @@ -32,8 +32,8 @@ def __init__(
Torch device, "cpu" or "cuda"
mask_threshold: float
Value to threshold mask pixels, should be between 0 and 1
prediction_score_threshold: float
All predictions with score < prediction_score_threshold will be discarded
confidence_threshold: float
All predictions with score < confidence_threshold will be discarded
category_mapping: dict: str to str
Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
category_remapping: dict: str to int
Expand All @@ -46,7 +46,7 @@ def __init__(
self.model = None
self.device = device
self.mask_threshold = mask_threshold
self.prediction_score_threshold = prediction_score_threshold
self.confidence_threshold = confidence_threshold
self.category_mapping = category_mapping
self.category_remapping = category_remapping
self._original_predictions = None
Expand Down
113 changes: 66 additions & 47 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
save_pickle,
)

# Lookup table from the short ``model_type`` key accepted by the predict
# functions to the name of the DetectionModel class that is subsequently
# resolved via ``import_class``.
MODEL_TYPE_TO_MODEL_CLASS_NAME = {
    "mmdet": "MmdetDetectionModel",
    "yolov5": "Yolov5DetectionModel",
}


def get_prediction(
image,
Expand Down Expand Up @@ -82,7 +87,7 @@ def get_prediction(
filtered_object_prediction_list = [
object_prediction
for object_prediction in object_prediction_list
if object_prediction.score.value > detection_model.prediction_score_threshold
if object_prediction.score.value > detection_model.confidence_threshold
]
# postprocess matching predictions
if postprocess is not None:
Expand Down Expand Up @@ -272,8 +277,13 @@ def get_sliced_prediction(


def predict(
model_name: str = "MmdetDetectionModel",
model_parameters: Dict = None,
model_type: str = "mmdet",
model_path: str = None,
model_config_path: str = None,
model_confidence_threshold: float = 0.25,
model_device: str = None,
model_category_mapping: dict = None,
model_category_remapping: dict = None,
source: str = None,
no_standard_prediction: bool = False,
no_sliced_prediction: bool = False,
Expand Down Expand Up @@ -301,19 +311,20 @@ def predict(
Performs prediction for all present images in given folder.
Args:
model_name: str
Name of the implemented DetectionModel in model.py file.
model_parameter: a dict with fields:
model_path: str
Path for the instance segmentation model weight
config_path: str
Path for the mmdetection instance segmentation model config file
prediction_score_threshold: float
All predictions with score < prediction_score_threshold will be discarded.
device: str
Torch device, "cpu" or "cuda"
category_remapping: dict: str to int
Remap category ids after performing inference
model_type: str
'mmdet' for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'.
model_path: str
Path for the model weight
model_config_path: str
Path for the detection model config file
model_confidence_threshold: float
All predictions with score < model_confidence_threshold will be discarded.
model_device: str
Torch device, "cpu" or "cuda"
model_category_mapping: dict
Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
model_category_remapping: dict: str to int
Remap category ids after performing inference
source: str
Folder directory that contains images or path of the image to be predicted.
no_standard_prediction: bool
Expand Down Expand Up @@ -398,14 +409,15 @@ def predict(

# init model instance
time_start = time.time()
DetectionModel = import_class(model_name)
model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]
DetectionModel = import_class(model_class_name)
detection_model = DetectionModel(
model_path=model_parameters["model_path"],
config_path=model_parameters.get("config_path", None),
prediction_score_threshold=model_parameters.get("prediction_score_threshold", 0.25),
device=model_parameters.get("device", None),
category_mapping=model_parameters.get("category_mapping", None),
category_remapping=model_parameters.get("category_remapping", None),
model_path=model_path,
config_path=model_config_path,
confidence_threshold=model_confidence_threshold,
device=model_device,
category_mapping=model_category_mapping,
category_remapping=model_category_remapping,
load_at_init=False,
)
detection_model.load_model()
Expand Down Expand Up @@ -568,8 +580,13 @@ def predict(


def predict_fiftyone(
model_name: str = "MmdetDetectionModel",
model_parameters: Dict = None,
model_type: str = "mmdet",
model_path: str = None,
model_config_path: str = None,
model_confidence_threshold: float = 0.25,
model_device: str = None,
model_category_mapping: dict = None,
model_category_remapping: dict = None,
coco_json_path: str = None,
coco_image_dir: str = None,
no_standard_prediction: bool = False,
Expand All @@ -588,19 +605,20 @@ def predict_fiftyone(
Performs prediction for all present images in given folder.
Args:
model_name: str
Name of the implemented DetectionModel in model.py file.
model_parameter: a dict with fields:
model_path: str
Path for the instance segmentation model weight
config_path: str
Path for the mmdetection instance segmentation model config file
prediction_score_threshold: float
All predictions with score < prediction_score_threshold will be discarded.
device: str
Torch device, "cpu" or "cuda"
category_remapping: dict: str to int
Remap category ids after performing inference
model_type: str
'mmdet' for 'MmdetDetectionModel', 'yolov5' for 'Yolov5DetectionModel'.
model_path: str
Path for the model weight
model_config_path: str
Path for the detection model config file
model_confidence_threshold: float
All predictions with score < model_confidence_threshold will be discarded.
model_device: str
Torch device, "cpu" or "cuda"
model_category_mapping: dict
Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
model_category_remapping: dict: str to int
Remap category ids after performing inference
coco_json_path: str
If coco file path is provided, detection results will be exported in coco json format.
coco_image_dir: str
Expand Down Expand Up @@ -651,14 +669,15 @@ def predict_fiftyone(

# init model instance
time_start = time.time()
DetectionModel = import_class(model_name)
model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]
DetectionModel = import_class(model_class_name)
detection_model = DetectionModel(
model_path=model_parameters["model_path"],
config_path=model_parameters.get("config_path", None),
prediction_score_threshold=model_parameters.get("prediction_score_threshold", 0.25),
device=model_parameters.get("device", None),
category_mapping=model_parameters.get("category_mapping", None),
category_remapping=model_parameters.get("category_remapping", None),
model_path=model_path,
config_path=model_config_path,
confidence_threshold=model_confidence_threshold,
device=model_device,
category_mapping=model_category_mapping,
category_remapping=model_category_remapping,
load_at_init=False,
)
detection_model.load_model()
Expand Down Expand Up @@ -702,7 +721,7 @@ def predict_fiftyone(
durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"]

# Save predictions to dataset
sample[model_name] = fo.Detections(detections=prediction_result.to_fiftyone_detections())
sample[model_type] = fo.Detections(detections=prediction_result.to_fiftyone_detections())
sample.save()

# print prediction duration
Expand All @@ -728,7 +747,7 @@ def predict_fiftyone(
session.dataset = dataset
# Evaluate the predictions
results = dataset.evaluate_detections(
model_name,
model_type,
gt_field="ground_truth",
eval_key="eval",
iou=postprocess_match_threshold,
Expand Down
22 changes: 7 additions & 15 deletions scripts/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,14 @@

opt = parser.parse_args()

model_type_to_model_name = {
"mmdet": "MmdetDetectionModel",
"yolov5": "Yolov5DetectionModel",
}

model_parameters = {
"model_path": opt.model_path,
"config_path": opt.config_path,
"prediction_score_threshold": opt.conf_thresh,
"device": opt.device,
"category_mapping": opt.category_mapping,
"category_remapping": opt.category_remapping,
}
predict(
model_name=model_type_to_model_name[opt.model_type],
model_parameters=model_parameters,
model_type=opt.model_type,
model_path=opt.model_path,
model_config_path=opt.config_path,
model_confidence_threshold=opt.conf_thresh,
model_device=opt.device,
model_category_mapping=opt.category_mapping,
model_category_remapping=opt.category_remapping,
source=opt.source,
project=opt.project,
name=opt.name,
Expand Down
22 changes: 7 additions & 15 deletions scripts/predict_fiftyone.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,14 @@
)
opt = parser.parse_args()

model_type_to_model_name = {
"mmdet": "MmdetDetectionModel",
"yolov5": "Yolov5DetectionModel",
}

model_parameters = {
"model_path": opt.model_path,
"config_path": opt.config_path,
"prediction_score_threshold": opt.conf_thresh,
"device": opt.device,
"category_mapping": opt.category_mapping,
"category_remapping": opt.category_remapping,
}
predict_fiftyone(
model_name=model_type_to_model_name[opt.model_type],
model_parameters=model_parameters,
model_type=opt.model_type,
model_path=opt.model_path,
model_config_path=opt.config_path,
model_confidence_threshold=opt.conf_thresh,
model_device=opt.device,
model_category_mapping=opt.category_mapping,
model_category_remapping=opt.category_remapping,
coco_json_path=opt.coco_json_path,
coco_image_dir=opt.coco_image_dir,
no_standard_prediction=opt.no_standard_pred,
Expand Down
14 changes: 7 additions & 7 deletions tests/test_mmdetectionmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_load_model(self):
mmdet_detection_model = MmdetDetectionModel(
model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
prediction_score_threshold=0.3,
confidence_threshold=0.3,
device=None,
category_remapping=None,
load_at_init=True,
Expand All @@ -39,7 +39,7 @@ def test_perform_inference_with_mask_output(self):
mmdet_detection_model = MmdetDetectionModel(
model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
prediction_score_threshold=0.5,
confidence_threshold=0.5,
device=None,
category_remapping=None,
load_at_init=True,
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_perform_inference_without_mask_output(self):
mmdet_detection_model = MmdetDetectionModel(
model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH,
config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH,
prediction_score_threshold=0.5,
confidence_threshold=0.5,
device=None,
category_remapping=None,
load_at_init=True,
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_convert_original_predictions_with_mask_output(self):
mmdet_detection_model = MmdetDetectionModel(
model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
prediction_score_threshold=0.5,
confidence_threshold=0.5,
device=None,
category_remapping=None,
load_at_init=True,
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_convert_original_predictions_without_mask_output(self):
mmdet_detection_model = MmdetDetectionModel(
model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH,
config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH,
prediction_score_threshold=0.5,
confidence_threshold=0.5,
device=None,
category_remapping=None,
load_at_init=True,
Expand Down Expand Up @@ -203,7 +203,7 @@ def test_create_original_predictions_from_object_prediction_list_with_mask_outpu
mmdet_detection_model = MmdetDetectionModel(
model_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH,
config_path=MmdetTestConstants.MMDET_CASCADEMASKRCNN_CONFIG_PATH,
prediction_score_threshold=0.5,
confidence_threshold=0.5,
device=None,
category_remapping=None,
load_at_init=True,
Expand Down Expand Up @@ -248,7 +248,7 @@ def test_create_original_predictions_from_object_prediction_list_without_mask_ou
mmdet_detection_model = MmdetDetectionModel(
model_path=MmdetTestConstants.MMDET_RETINANET_MODEL_PATH,
config_path=MmdetTestConstants.MMDET_RETINANET_CONFIG_PATH,
prediction_score_threshold=0.5,
confidence_threshold=0.5,
device=None,
category_remapping=None,
load_at_init=True,
Expand Down
Loading

0 comments on commit 14b91ab

Please sign in to comment.