Skip to content

Commit

Permalink
CodeCamp2023-555 (#2469)
Browse files Browse the repository at this point in the history
* support condinst from mmdet

* remove

* update

* update

* support batch inference

* add condinst head unit testing

* fix lint error

* remove

* fix bug in postprocess

* remove

* update

---------

Co-authored-by: RunningLeon <[email protected]>
  • Loading branch information
Boomerl and RunningLeon authored Oct 8, 2023
1 parent e74901f commit 4c376d9
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 4 deletions.
2 changes: 1 addition & 1 deletion csrc/mmdeploy/codebase/mmdet/instance_segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class ResizeInstanceMask : public ResizeBBox {
int resize_width = int(mask_width / scale_factor_[1] + 0.5);
// skip resize if scale_factor is 1.0
if (resize_height != mask_height || resize_width != mask_width) {
cv::resize(mask_mat, mask_mat, cv::Size(resize_height, resize_width), cv::INTER_LINEAR);
cv::resize(mask_mat, mask_mat, cv::Size(resize_width, resize_height), cv::INTER_LINEAR);
}
// crop masks
mask_mat = mask_mat(cv::Range(0, img_h), cv::Range(0, img_w)).clone();
Expand Down
5 changes: 5 additions & 0 deletions mmdeploy/codebase/mmdet/deploy/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,11 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
type = 'ResizeInstanceMask' # for instance-seg
# resize and crop mask to origin image
params['is_resize_mask'] = True
if 'mask_thr' in params:
type = 'ResizeInstanceMask' # for instance-seg
# resize and crop mask to origin image
params['mask_thr_binary'] = params['mask_thr']
params['is_resize_mask'] = True

if get_backend(self.deploy_cfg) == Backend.RKNN:
if 'YOLO' in self.model_cfg.model.type or \
Expand Down
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet/deploy/object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def postprocessing_results(self,
masks = batch_masks[i]
img_h, img_w = img_metas[i]['img_shape'][:2]
ori_h, ori_w = img_metas[i]['ori_shape'][:2]
if model_type == 'RTMDet':
if model_type in ['RTMDet', 'CondInst']:
export_postprocess_mask = True
else:
export_postprocess_mask = False
Expand Down
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import base_dense_head # noqa: F401,F403
from . import centernet_head # noqa: F401,F403
from . import condinst_head # noqa: F401,F403
from . import detr_head # noqa: F401,F403
from . import fovea_head # noqa: F401,F403
from . import gfl_head # noqa: F401,F403
Expand Down
202 changes: 202 additions & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch
from mmdet.models.utils import aligned_bilinear
from mmengine.config import ConfigDict
from torch import Tensor

from mmdeploy.codebase.mmdet.deploy import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmcv.ops.nms import multiclass_nms


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.CondInstBboxHead.predict_by_feat')
def condinst_bbox_head__predict_by_feat(
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
param_preds: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True,
):
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg

assert len(cls_scores) == len(bbox_preds)
device = bbox_preds[0].device
cfg = self.test_cfg if cfg is None else cfg
batch_size = bbox_preds[0].shape[0]
featmap_sizes = [cls_score.shape[-2:] for cls_score in cls_scores]

all_level_points_strides = self.prior_generator.grid_priors(
featmap_sizes, device=device, with_stride=True)
all_level_points = [i[:, :2] for i in all_level_points_strides]
all_level_strides = [i[:, 2] for i in all_level_points_strides]

flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
for bbox_pred in bbox_preds
]
flatten_score_factors = [
score_factor.permute(0, 2, 3, 1).reshape(batch_size, -1, 1)
for score_factor in score_factors
]
flatten_param_preds = [
param_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, self.num_params)
for param_pred in param_preds
]
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_score_factors = torch.cat(flatten_score_factors, dim=1).sigmoid()
flatten_param_preds = torch.cat(flatten_param_preds, dim=1)

points = torch.cat(all_level_points)
strides = torch.cat(all_level_strides)
tl_x = points[..., 0] - flatten_bbox_preds[..., 0]
tl_y = points[..., 1] - flatten_bbox_preds[..., 1]
br_x = points[..., 0] + flatten_bbox_preds[..., 2]
br_y = points[..., 1] + flatten_bbox_preds[..., 3]

bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
scores = flatten_cls_scores
score_factors = flatten_score_factors
param_preds = flatten_param_preds
scores = scores * score_factors

# get post processing config
post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

dets, labels, inds = multiclass_nms(
bboxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k,
output_index=True,
)

batch_inds = torch.arange(batch_size, device=bboxes.device).view(-1, 1)
points = points.unsqueeze(0).repeat(batch_size, 1, 1)
strides = strides.unsqueeze(0).repeat(batch_size, 1)
param_preds = param_preds[batch_inds, inds, :]
points = points[batch_inds, inds, :]
strides = strides[batch_inds, inds]
results = dict(
dets=dets,
labels=labels,
param_preds=param_preds,
points=points,
strides=strides)
return results


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.CondInstMaskHead.forward')
def condinst_mask_head__forward(self, x: tuple,
positive_infos: Dict[str, torch.Tensor]):
mask_feats = self.mask_feature_head(x)

param_preds = positive_infos['param_preds']
points = positive_infos['points']
strides = positive_infos['strides']

batch_size = points.shape[0]
num_insts = points.shape[1]
hw = mask_feats.size()[-2:]
mask_feats = mask_feats.unsqueeze(1).repeat(1, num_insts, 1, 1, 1)

points = points.reshape(-1, 1, 2).unsqueeze(0)
locations = self.prior_generator.single_level_grid_priors(
hw, level_idx=0, device=mask_feats.device)
locations = locations.unsqueeze(0).repeat(batch_size, 1,
1).reshape(batch_size, 1, -1, 2)
centers = points.reshape(batch_size, -1, 1, 2)
rel_coordinates = (centers - locations).permute(0, 1, 3, 2).float()
rel_coordinates /= (strides[:, :, None, None] * self.size_of_interest)
rel_coords = rel_coordinates.reshape(batch_size, -1, 2, hw[0], hw[1])
mask_head_inputs = torch.cat([rel_coords, mask_feats], dim=2)

weights, biases = _parse_dynamic_params(self, param_preds)
mask_preds = _dynamic_conv_forward(mask_head_inputs, weights, biases)
mask_preds = mask_preds.reshape(batch_size, num_insts, hw[0], hw[1])
mask_preds = aligned_bilinear(
mask_preds, int(self.mask_feat_stride / self.mask_out_stride))
return (mask_preds, )


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.CondInstMaskHead.predict_by_feat')
def condinst_mask_head__predict_by_feat(self,
mask_preds: Tensor,
results_list: Dict[str, torch.Tensor],
batch_img_metas: List[dict],
rescale: bool = True,
**kwargs):
cfg = self.test_cfg

dets = results_list['dets']
labels = results_list['labels']
img_hw = batch_img_metas[0]['img_shape'][:2]

mask_preds = mask_preds.sigmoid()
mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
mask_preds = mask_preds[:, :, :img_hw[0], :img_hw[1]]
masks = (mask_preds > cfg.mask_thr).float()

return dets, labels, masks


def _parse_dynamic_params(self, params: Tensor):
"""parse the dynamic params for dynamic conv."""
batch_size = params.shape[0]
num_insts = params.shape[1]
params = params.permute(1, 0, 2)
params_splits = list(
torch.split_with_sizes(
params, self.weight_nums + self.bias_nums, dim=2))

weight_splits = params_splits[:self.num_layers]
bias_splits = params_splits[self.num_layers:]

for idx in range(self.num_layers):
if idx < self.num_layers - 1:
weight_splits[idx] = weight_splits[idx].reshape(
batch_size, num_insts, self.in_channels, -1)
else:
weight_splits[idx] = weight_splits[idx].reshape(
batch_size, num_insts, 1, -1)

return weight_splits, bias_splits


def _dynamic_conv_forward(features: Tensor, weights: List[Tensor],
biases: List[Tensor]):
"""dynamic forward, each layer follow a relu."""
n_layers = len(weights)
x = features.flatten(0, 1).flatten(2)
for i, (w, b) in enumerate(zip(weights, biases)):
# replace dynamic conv with bmm
w = w.flatten(0, 1)
b = b.flatten(0, 1).unsqueeze(2)
x = torch.bmm(w, x)
x = x + b
if i < n_layers - 1:
x = x.clamp_(min=0)
return x
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,30 @@
'instance_segmentor_forward',
inputs=['input'],
outputs=['dets', 'labels', 'masks'])
def __forward_impl_instance_seg(self, batch_inputs, data_samples, **kwargs):
def __forward_impl_instance_seg(self,
batch_inputs,
data_samples,
rescale=True,
**kwargs):
"""Rewrite and adding mark for `forward`.
Encapsulate this function for rewriting `forward` of BaseDetector.
1. Add mark for BaseDetector.
2. Support both dynamic and static export to onnx.
"""
x = self.extract_feat(batch_inputs)
mask_outs = self.mask_head.predict(x, data_samples, rescale=False)
if self.with_bbox:
# the bbox branch does not need to be scaled to the original
# image scale, because the mask branch will scale both bbox
# and mask at the same time.
bbox_rescale = rescale if not self.with_mask else False
results_list = self.bbox_head.predict(
x, data_samples, rescale=bbox_rescale)
else:
results_list = None

mask_outs = self.mask_head.predict(
x, data_samples, rescale=rescale, results_list=results_list)
return mask_outs


Expand Down
3 changes: 3 additions & 0 deletions mmdeploy/pytorch/functions/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def tensor__repeat__tensorrt(input: torch.Tensor, *size: Union[torch.Size,

origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:
if isinstance(*size, tuple):
return origin_func(input.unsqueeze(0),
*([1] + list(*size))).squeeze(0)
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
else:
return origin_func(input, *size)
10 changes: 10 additions & 0 deletions tests/regression/mmdet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,13 @@ models:
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32

- name: CondInst
metafile: configs/condinst/metafile.yml
model_configs:
- configs/condinst/condinst_r50_fpn_ms-poly-90k_coco_instance.py
pipelines:
- deploy_config: configs/mmdet/instance-seg/instance-seg_onnxruntime_dynamic.py
backend_test: *default_backend_test
- deploy_config: configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py
backend_test: *default_backend_test

0 comments on commit 4c376d9

Please sign in to comment.