-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path · Detector.py
86 lines (65 loc) · 3.94 KB
/
Detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2 import model_zoo
import torch
import cv2
import numpy as np
class Detector:
    """Run a detectron2 model-zoo model on an image or video and display the result.

    Supported ``model_type`` values (the keys of ``_MODEL_CONFIGS``):

    * ``"OD"``   - object detection (Faster R-CNN R101-FPN, COCO)
    * ``"IS"``   - instance segmentation (Mask R-CNN R50-FPN, COCO)
    * ``"KP"``   - keypoint detection (Keypoint R-CNN R50-FPN, COCO)
    * ``"LVIS"`` - instance segmentation (Mask R-CNN X101-FPN, LVIS v0.5)
    * ``"PS"``   - panoptic segmentation (Panoptic FPN R101, COCO)
    """

    # model_type -> model-zoo config path. The same path is used both to load
    # the config and to look up the matching pretrained checkpoint URL.
    _MODEL_CONFIGS = {
        "OD": "COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml",
        "IS": "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
        "KP": "COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml",
        "LVIS": "LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x.yaml",
        "PS": "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml",
    }

    def __init__(self, model_type: str = "OD", score_thresh: float = 0.7) -> None:
        """Build a predictor for *model_type*.

        Args:
            model_type: One of the keys of ``_MODEL_CONFIGS``.
            score_thresh: Minimum score for a detection to be kept at test time
                (``MODEL.ROI_HEADS.SCORE_THRESH_TEST``).

        Raises:
            ValueError: If *model_type* is not a supported key.
        """
        # Validate before doing any detectron2 work so a typo fails immediately
        # instead of silently building an unweighted model that breaks later
        # when dataset metadata is looked up.
        if model_type not in self._MODEL_CONFIGS:
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected one of "
                f"{sorted(self._MODEL_CONFIGS)}"
            )
        config_path = self._MODEL_CONFIGS[model_type]

        self.cfg = get_cfg()
        self.model_type = model_type
        # Load model config and pretrained weights from the model zoo.
        self.cfg.merge_from_file(model_zoo.get_config_file(config_path))
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_path)
        self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
        self.cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.predictor = DefaultPredictor(self.cfg)

    def _render(self, image: np.ndarray) -> np.ndarray:
        """Run the predictor on one BGR frame and return the annotated BGR frame.

        OpenCV images are BGR while ``Visualizer`` expects RGB, hence the
        channel reversal going in and coming back out.
        """
        metadata = MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0])
        if self.model_type == "PS":
            panoptic_seg, segments_info = self.predictor(image)["panoptic_seg"]
            viz = Visualizer(image[:, :, ::-1], metadata)
            # Current detectron2 names this method ``draw_panoptic_seg``; older
            # releases called it ``draw_panoptic_seg_predictions``.
            draw_panoptic = (
                getattr(viz, "draw_panoptic_seg", None)
                or viz.draw_panoptic_seg_predictions
            )
            output = draw_panoptic(panoptic_seg.to("cpu"), segments_info)
        else:
            predictions = self.predictor(image)
            viz = Visualizer(
                image[:, :, ::-1],
                metadata=metadata,
                instance_mode=ColorMode.SEGMENTATION,
            )
            output = viz.draw_instance_predictions(predictions["instances"].to("cpu"))
        return output.get_image()[:, :, ::-1]

    def onImage(self, imagePath: str) -> None:
        """Annotate the image at *imagePath* and show it until a key is pressed.

        Prints an error and returns if the image cannot be read.
        """
        image = cv2.imread(imagePath)
        # cv2.imread signals a missing/unreadable file by returning None.
        if image is None:
            print("Error opening image file")
            return
        cv2.imshow("Result", self._render(image))
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def onVideo(self, videoPath) -> None:
        """Annotate and display *videoPath* frame by frame; press ``q`` to stop.

        *videoPath* is passed straight to ``cv2.VideoCapture`` (e.g. a file
        path). Prints an error and returns if the source cannot be opened.
        """
        cap = cv2.VideoCapture(videoPath)
        if not cap.isOpened():
            print("Error opening video stream or file")
            return
        # Always free the capture and close the window, including on the
        # early 'q' exit or if inference raises.
        try:
            success, image = cap.read()
            while success:
                cv2.imshow("Result", self._render(image))
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
                success, image = cap.read()
        finally:
            cap.release()
            cv2.destroyAllWindows()