Skip to content

Commit

Permalink
ONNX model inference (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
yarden-yagil-sony authored Mar 5, 2025
1 parent 66d2950 commit 6973a6b
Showing 1 changed file with 53 additions and 24 deletions.
77 changes: 53 additions & 24 deletions tutorials/pytorch/multiclass_nms_custom_layer_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
"!pip install -q torch\n",
"!pip install -q onnx\n",
"!pip install -q model_compression_toolkit\n",
"!pip install -q sony-custom-layers"
"!pip install -q sony-custom-layers\n",
"!pip install -q onnxruntime_extensions\n",
"!pip install -q onnxruntime"
],
"outputs": [],
"execution_count": null
Expand Down Expand Up @@ -78,7 +80,7 @@
"metadata": {},
"source": [
"class ObjectDetector(nn.Module):\n",
" def __init__(self, num_classes=2, max_detections=20):\n",
" def __init__(self, num_classes=2, max_detections=300):\n",
" super().__init__()\n",
" self.max_detections = max_detections\n",
"\n",
Expand All @@ -103,8 +105,7 @@
" bbox = self.bbox_reg(features)\n",
" bbox = bbox.view(batch_size, self.max_detections, 4, H_prime * W_prime).mean(dim=3)\n",
" class_probs = self.class_reg(features).view(batch_size, self.max_detections, -1, H_prime * W_prime)\n",
" class_probs = F.softmax(class_probs.mean(dim=2), dim=2)\n",
"\n",
" class_probs = F.softmax(class_probs.mean(dim=3), dim=2)\n",
" return bbox, class_probs\n",
"\n",
"model = ObjectDetector()\n",
Expand All @@ -128,11 +129,7 @@
{
"cell_type": "code",
"id": "72d25144f573ead3",
"metadata": {
"jupyter": {
"is_executing": true
}
},
"metadata": {},
"source": [
"NUM_ITERS = 20\n",
"BATCH_SIZE = 32\n",
Expand Down Expand Up @@ -175,18 +172,14 @@
{
"cell_type": "code",
"id": "baa386a04a8dd664",
"metadata": {
"jupyter": {
"is_executing": true
}
},
"metadata": {},
"source": [
"class PostProcessWrapper(nn.Module):\n",
" def __init__(self,\n",
" model: nn.Module,\n",
" score_threshold: float = 0.001,\n",
" iou_threshold: float = 0.7,\n",
" max_detections: int = 300):\n",
" max_detections: int = 20):\n",
"\n",
" super(PostProcessWrapper, self).__init__()\n",
" self.model = model\n",
Expand All @@ -208,7 +201,7 @@
"quant_model_with_nms = PostProcessWrapper(model=quant_model,\n",
" score_threshold=0.001,\n",
" iou_threshold=0.7,\n",
" max_detections=300).to(device=device)\n",
" max_detections=20).to(device=device)\n",
"print('Quantized model with NMS is ready')"
],
"outputs": [],
Expand All @@ -227,31 +220,67 @@
{
"cell_type": "code",
"id": "776a6f99bd0a6efe",
"metadata": {
"jupyter": {
"is_executing": true
}
},
"metadata": {},
"source": [
"model_path = './qmodel_with_nms.onnx'\n",
"mct.exporter.pytorch_export_model(model=quant_model_with_nms,\n",
" save_model_path='./qmodel_with_nms.onnx',\n",
" save_model_path=model_path,\n",
" repr_dataset=representative_data_generator)"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"id": "bb7c13e41a012f3",
"source": [
"### Model Inference\n",
"\n",
"In order to run model inference over our saved onnx model, we need to load the necessary custom operations using `load_custom_ops()` and create an onnxruntime inference session with these custom operations.\n"
],
"id": "40c2925dcd4f7901"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import onnxruntime as ort\n",
"from sony_custom_layers.pytorch import load_custom_ops\n",
"import numpy as np\n",
"\n",
"random_input = np.random.rand(*(1, 3, 64, 64)).astype(np.float32)\n",
"\n",
"so = load_custom_ops()\n",
"session = ort.InferenceSession(model_path, sess_options=so)\n",
"input_name = session.get_inputs()[0].name\n",
"output_names = [output.name for output in session.get_outputs()]\n",
"preds = session.run(output_names, {input_name: random_input})\n",
"\n",
"\"\"\"\n",
"One can access prediction items as follows:\n",
"boxes = preds[0]\n",
"scores = preds[1]\n",
"labels = preds[2]\n",
"n_valid = preds[3]\n",
"\"\"\"\n",
"pass"
],
"id": "45408190fb8210fb",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
"\n",
"http://www.apache.org/licenses/LICENSE-2.0\n",
"Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
]
],
"id": "bb7c13e41a012f3"
}
],
"metadata": {
Expand Down

0 comments on commit 6973a6b

Please sign in to comment.