Skip to content

Commit

Permalink
Fix RT-DETR exported onnx model (ultralytics#3317)
Browse files Browse the repository at this point in the history
Co-authored-by: Glenn Jocher <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 24, 2023
1 parent 2f58b58 commit 51d8cfa
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ultralytics/nn/modules/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device=
anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
anchors = torch.log(anchors / (1 - anchors))
anchors = torch.where(valid_mask, anchors, torch.inf)
anchors = anchors.masked_fill(~valid_mask, float('inf'))
return anchors, valid_mask

def _get_encoder_input(self, x):
Expand All @@ -294,7 +294,7 @@ def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
bs = len(feats)
# prepare input for decoder
anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
features = self.enc_output(torch.where(valid_mask, feats, 0)) # bs, h*w, 256
features = self.enc_output(valid_mask * feats) # bs, h*w, 256

enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
# dynamic anchors + static content
Expand Down

0 comments on commit 51d8cfa

Please sign in to comment.