Skip to content

Commit

Permalink
[RKNPU2]support rknpu2 ClasModel #957 (#964)
Browse files Browse the repository at this point in the history
* [RKNPU2]support rknpu2 ClasModel #957

* [RKNPU2]support rknpu2 ClasModel #957

* [RKNPU2]support rknpu2 add Resnet50_vd example  #957

* [RKNPU2]support rknpu2 add Resnet50_vd example  #957

* [RKNPU2]support rknpu2, improve doc  #957

* [RKNPU2]support rknpu2, improve doc  #957

* [RKNPU2]support rknpu2, improve doc  #957

* [RKNPU2]support rknpu2, improve doc  #957

* [RKNPU2]support rknpu2, improve doc  #957

* [RKNPU2]support rknpu2, improve doc  #957

* [RKNPU2]support rknpu2, improve doc  #957
  • Loading branch information
pengwei1024 authored Dec 28, 2022
1 parent 02425bf commit 973c746
Show file tree
Hide file tree
Showing 13 changed files with 390 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/cn/faq/rknpu2/rknpu2.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ ONNX模型不能直接调用RK芯片中的NPU进行运算,需要把ONNX模型
| Segmentation | PP-HumanSegV2Lite | portrait | 133/43 |
| Segmentation | PP-HumanSegV2Lite | human | 133/43 |
| Face Detection | SCRFD | SCRFD-2.5G-kps-640 | 108/42 |
| Classification | ResNet | ResNet50_vd | -/92 |

## RKNPU2 Backend推理使用教程

Expand Down
57 changes: 57 additions & 0 deletions examples/vision/classification/paddleclas/rknpu2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# PaddleClas 模型RKNPU2部署

## 转换模型
下面以 ResNet50_vd为例子,教大家如何转换分类模型到RKNN模型。

```bash
# 安装 paddle2onnx
pip install paddle2onnx

# 下载ResNet50_vd模型文件和测试图片
wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz
tar -xvf ResNet50_vd_infer.tgz

# 静态图转ONNX模型,注意,这里的save_file请和压缩包名对齐
paddle2onnx --model_dir ResNet50_vd_infer \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--save_file ResNet50_vd_infer/ResNet50_vd_infer.onnx \
--enable_dev_version True \
--opset_version 12 \
--enable_onnx_checker True

# 固定shape,注意这里的inputs得对应netron.app展示的 inputs 的 name,有可能是image 或者 x
python -m paddle2onnx.optimize --input_model ResNet50_vd_infer/ResNet50_vd_infer.onnx \
--output_model ResNet50_vd_infer/ResNet50_vd_infer.onnx \
--input_shape_dict "{'inputs':[1,3,224,224]}"
```

### 编写模型导出配置文件
以转化RK3588的RKNN模型为例子,我们需要编辑tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml,来转换ONNX模型到RKNN模型。

默认的 mean=0, std=1是在内存做normalize,如果你需要在NPU上执行normalize操作,请根据你的模型配置normalize参数,例如:
```yaml
model_path: ./ResNet50_vd_infer.onnx
output_folder: ./
target_platform: RK3588
normalize:
mean: [[0.485,0.456,0.406]]
std: [[0.229,0.224,0.225]]
outputs: []
outputs_nodes: []
do_quantization: False
dataset:
```
# ONNX模型转RKNN模型
```shell
python tools/rknpu2/export.py \
--config_path tools/rknpu2/config/ResNet50_vd_infer_rknn.yaml \
--target_platform rk3588
```

## 其他链接
- [Cpp部署](./cpp)
- [Python部署](./python)
- [视觉模型预测结果](../../../../../docs/api/vision_results/)
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
CMAKE_MINIMUM_REQUIRED(VERSION 3.10)
project(rknpu_test)

set(CMAKE_CXX_STANDARD 14)

# 指定下载解压后的fastdeploy库路径
set(FASTDEPLOY_INSTALL_DIR "thirdpartys/fastdeploy-0.0.3")

include(${FASTDEPLOY_INSTALL_DIR}/FastDeployConfig.cmake)
include_directories(${FastDeploy_INCLUDE_DIRS})
add_executable(rknpu_test infer.cc)
target_link_libraries(rknpu_test
${FastDeploy_LIBS}
)


set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}/build/install)

install(TARGETS rknpu_test DESTINATION ./)

install(DIRECTORY ppclas_model_dir DESTINATION ./)
install(DIRECTORY images DESTINATION ./)

file(GLOB FASTDEPLOY_LIBS ${FASTDEPLOY_INSTALL_DIR}/lib/*)
message("${FASTDEPLOY_LIBS}")
install(PROGRAMS ${FASTDEPLOY_LIBS} DESTINATION lib)

file(GLOB ONNXRUNTIME_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/onnxruntime/lib/*)
install(PROGRAMS ${ONNXRUNTIME_LIBS} DESTINATION lib)

install(DIRECTORY ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/opencv/lib DESTINATION ./)

file(GLOB PADDLETOONNX_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/paddle2onnx/lib/*)
install(PROGRAMS ${PADDLETOONNX_LIBS} DESTINATION lib)

file(GLOB RKNPU2_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/rknpu2_runtime/RK3588/lib/*)
install(PROGRAMS ${RKNPU2_LIBS} DESTINATION lib)
78 changes: 78 additions & 0 deletions examples/vision/classification/paddleclas/rknpu2/cpp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# PaddleClas C++部署示例

本目录下用于展示 ResNet50_vd 模型在RKNPU2上的部署,以下的部署过程以 ResNet50_vd 为例子。

在部署前,需确认以下两个步骤:

1. 软硬件环境满足要求
2. 根据开发环境,下载预编译部署库或者从头编译FastDeploy仓库

以上步骤请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)实现

## 生成基本目录文件

该例程由以下几个部分组成
```text
.
├── CMakeLists.txt
├── build # 编译文件夹
├── images # 存放图片的文件夹
├── infer.cc
├── ppclas_model_dir # 存放模型文件的文件夹
└── thirdpartys # 存放sdk的文件夹
```

首先需要先生成目录结构
```bash
mkdir build
mkdir images
mkdir ppclas_model_dir
mkdir thirdpartys
```

## 编译

### 编译并拷贝SDK到thirdpartys文件夹

请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)仓库编译SDK,编译完成后,将在build目录下生成
fastdeploy-0.0.3目录,请移动它至thirdpartys目录下.

### 拷贝模型文件,以及配置文件至model文件夹
在Paddle动态图模型 -> Paddle静态图模型 -> ONNX模型的过程中,将生成ONNX文件以及对应的yaml配置文件,请将配置文件存放到model文件夹内。
转换为RKNN后的模型文件也需要拷贝至model,转换方案: ([ResNet50_vd RKNN模型](../README.md))。

### 准备测试图片至image文件夹
```bash
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
```

### 编译example

```bash
cd build
cmake ..
make -j8
make install
```

## 运行例程

```bash
cd ./build/install
./rknpu_test ./ppclas_model_dir ./images/ILSVRC2012_val_00000010.jpeg
```

## 运行结果展示
ClassifyResult(
label_ids: 153,
scores: 0.684570,
)

## 注意事项
RKNPU上对模型的输入要求是使用NHWC格式,且图片归一化操作会在转RKNN模型时,内嵌到模型中,因此我们在使用FastDeploy部署时,
DisablePermute(C++)或`disable_permute(Python),在预处理阶段禁用数据格式的转换。

## 其它文档
- [ResNet50_vd Python 部署](../python)
- [模型预测结果说明](../../../../../../docs/api/vision_results/)
- [转换ResNet50_vd RKNN模型文档](../README.md)
58 changes: 58 additions & 0 deletions examples/vision/classification/paddleclas/rknpu2/cpp/infer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision.h"

void RKNPU2Infer(const std::string& model_dir, const std::string& image_file) {
auto model_file = model_dir + "/ResNet50_vd_infer_rk3588.rknn";
auto params_file = "";
auto config_file = model_dir + "/inference_cls.yaml";

auto option = fastdeploy::RuntimeOption();
option.UseRKNPU2();

auto format = fastdeploy::ModelFormat::RKNN;

auto model = fastdeploy::vision::classification::PaddleClasModel(
model_file, params_file, config_file,option,format);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
model.GetPreprocessor().DisablePermute();
fastdeploy::TimeCounter tc;
tc.Start();
auto im = cv::imread(image_file);
fastdeploy::vision::ClassifyResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
// print res
std::cout << res.Str() << std::endl;
tc.End();
tc.PrintInfo("PPClas in RKNPU2");
}

int main(int argc, char* argv[]) {
if (argc < 3) {
std::cout
<< "Usage: rknpu_test path/to/model_dir path/to/image run_option, "
"e.g ./rknpu_test ./ppclas_model_dir ./images/ILSVRC2012_val_00000010.jpeg"
<< std::endl;
return -1;
}
RKNPU2Infer(argv[1], argv[2]);
return 0;
}
35 changes: 35 additions & 0 deletions examples/vision/classification/paddleclas/rknpu2/python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# PaddleClas Python部署示例

在部署前,需确认以下两个步骤

- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/rknpu2.md)

本目录下提供`infer.py`快速完成 ResNet50_vd 在RKNPU上部署的示例。执行如下脚本即可完成

```bash
# 下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/classification/paddleclas/rknpu2/python

# 下载图片
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg

# 推理
python3 infer.py --model_file ./ResNet50_vd_infer/ResNet50_vd_infer_rk3588.rknn --config_file ResNet50_vd_infer/inference_cls.yaml --image ILSVRC2012_val_00000010.jpeg

# 运行完成后返回结果如下所示
ClassifyResult(
label_ids: 153,
scores: 0.684570,
)
```


## 注意事项
RKNPU上对模型的输入要求是使用NHWC格式,且图片归一化操作会在转RKNN模型时,内嵌到模型中,因此我们在使用FastDeploy部署时,
DisablePermute(C++)或`disable_permute(Python),在预处理阶段禁用数据格式的转换。

## 其它文档
- [ResNet50_vd C++部署](../cpp)
- [模型预测结果说明](../../../../../../docs/api/vision_results/)
- [转换ResNet50_vd RKNN模型文档](../README.md)
50 changes: 50 additions & 0 deletions examples/vision/classification/paddleclas/rknpu2/python/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os


def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_file", required=True, help="Path of rknn model.")
parser.add_argument("--config_file", required=True, help="Path of config.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
return parser.parse_args()


if __name__ == "__main__":
args = parse_arguments()

model_file = args.model_file
params_file = ""
config_file = args.config_file
# 配置runtime,加载模型
runtime_option = fd.RuntimeOption()
runtime_option.use_rknpu2()
model = fd.vision.classification.ResNet50vd(
model_file,
params_file,
config_file,
runtime_option=runtime_option,
model_format=fd.ModelFormat.RKNN)
# 禁用通道转换
model.preprocessor.disable_permute()
im = cv2.imread(args.image)
result = model.predict(im, topk=1)
print(result)
3 changes: 2 additions & 1 deletion fastdeploy/vision/classification/ppcls/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
valid_ascend_backends = {Backend::LITE};
valid_kunlunxin_backends = {Backend::LITE};
valid_ipu_backends = {Backend::PDINFER};
} else if (model_format == ModelFormat::ONNX) {
} else {
valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
valid_gpu_backends = {Backend::ORT, Backend::TRT};
valid_rknpu_backends = {Backend::RKNPU2};
}

runtime_option = custom_option;
Expand Down
6 changes: 6 additions & 0 deletions fastdeploy/vision/classification/ppcls/ppcls_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ void BindPaddleClas(pybind11::module& m) {
})
.def("use_gpu", [](vision::classification::PaddleClasPreprocessor& self, int gpu_id = -1) {
self.UseGpu(gpu_id);
})
.def("disable_normalize", [](vision::classification::PaddleClasPreprocessor& self) {
self.DisableNormalize();
})
.def("disable_permute", [](vision::classification::PaddleClasPreprocessor& self) {
self.DisablePermute();
});

pybind11::class_<vision::classification::PaddleClasPostprocessor>(
Expand Down
Loading

0 comments on commit 973c746

Please sign in to comment.