diff --git a/fastmot/detector.py b/fastmot/detector.py index a2d98712..b5b44621 100644 --- a/fastmot/detector.py +++ b/fastmot/detector.py @@ -86,7 +86,7 @@ def __init__(self, size, self.max_area = max_area class_ids = [] if class_ids is None else list(class_ids) - self.label_mask = np.zeros(len(models.LABEL_MAP), dtype=np.bool_) + self.label_mask = np.zeros(self.model.NUM_CLASSES, dtype=np.bool_) self.label_mask[class_ids] = True self.batch_size = int(np.prod(self.tiling_grid)) diff --git a/fastmot/models/ssd.py b/fastmot/models/ssd.py index a53454bd..ef9ea9f1 100644 --- a/fastmot/models/ssd.py +++ b/fastmot/models/ssd.py @@ -12,6 +12,7 @@ class SSD: PLUGIN_PATH = None ENGINE_PATH = None MODEL_PATH = None + NUM_CLASSES = None INPUT_SHAPE = None OUTPUT_NAME = None @@ -70,6 +71,7 @@ def build_engine(cls, trt_logger, batch_size, calib_dataset=Path.home() / 'VOCde class SSDMobileNetV1(SSD): ENGINE_PATH = Path(__file__).parent / 'ssd_mobilenet_v1_coco.trt' MODEL_PATH = Path(__file__).parent / 'ssd_mobilenet_v1_coco.pb' + NUM_CLASSES = 90 INPUT_SHAPE = (3, 300, 300) OUTPUT_NAME = 'NMS' NMS_THRESH = 0.5 @@ -167,6 +169,7 @@ def add_plugin(cls, graph): class SSDMobileNetV2(SSD): ENGINE_PATH = Path(__file__).parent / 'ssd_mobilenet_v2_coco.trt' MODEL_PATH = Path(__file__).parent / 'ssd_mobilenet_v2_coco.pb' + NUM_CLASSES = 90 INPUT_SHAPE = (3, 300, 300) OUTPUT_NAME = 'NMS' NMS_THRESH = 0.5 @@ -263,6 +266,7 @@ def add_plugin(cls, graph): class SSDInceptionV2(SSD): ENGINE_PATH = Path(__file__).parent / 'ssd_inception_v2_coco.trt' MODEL_PATH = Path(__file__).parent / 'ssd_inception_v2_coco.pb' + NUM_CLASSES = 90 INPUT_SHAPE = (3, 300, 300) OUTPUT_NAME = 'NMS' NMS_THRESH = 0.5