
Commit be37fd9

Reshape feature implementation
1 parent e93f0b0 commit be37fd9

10 files changed: 208 additions, 8 deletions

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 44 additions & 3 deletions
@@ -5,6 +5,7 @@
 #include <algorithm>
 #include <cassert>
 #include <fstream>
+#include <map>
 #include <regex>
 #include <sstream>
 #include <unordered_map>
@@ -61,12 +62,19 @@ BackendManager::BackendManager(SessionContext& session_context,
     return "";
   }(subgraph);
 
-  // Save the indexes of graph inputs among fused_node's inputDefs
-  // (which also contains initializers).
+  if (!session_context_.shape.empty()) {
+    ValidateInputShapes(session_context_.shape, subgraph.GetInputs());
+  }
+
   for (uint32_t index = 0; const auto& node : subgraph.GetInputs()) {
+    if (subgraph.GetGraph().GetConsumerNodes(node->Name()).empty())
+    {
+      continue;
+    }
     subgraph_context_.input_names.insert({node->Name(), index++});
   }
 
+
   for (uint32_t index = 0; const auto& node : subgraph.GetOutputs()) {
     subgraph_context_.output_names.insert({node->Name(), index++});
   }
@@ -100,7 +108,7 @@ BackendManager::BackendManager(SessionContext& session_context,
     }
   }
 
-  if (ModelHasSymbolicInputDims(subgraph)) {
+  if (ModelHasSymbolicInputDims(subgraph) && session_context_.shape.empty()) {
     subgraph_context_.has_dynamic_input_shape = true;
     LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims";
     if ((session_context_.device_type.find("CPU") != std::string::npos ||
@@ -306,6 +314,39 @@ bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) {
   return has_sym_dims;
 }
 
+void BackendManager::ValidateInputShapes(const shape_t& shape,
+                                         const std::vector<const NodeArg*>& graph_inputs) const {
+  for (const auto& [tensor_name, requested_shape] : shape) {
+    // Find the matching input in the graph
+    const NodeArg* graph_input = nullptr;
+    for (const auto* input : graph_inputs) {
+      if (input->Name() == tensor_name) {
+        graph_input = input;
+        break;
+      }
+    }
+
+    if (!graph_input) {
+      ORT_THROW("Input " + tensor_name + " specified in reshape_input does not exist");
+    }
+
+    const ONNX_NAMESPACE::TensorShapeProto* graph_shape = graph_input->Shape();
+    if (!graph_shape) {
+      ORT_THROW("Graph input " + tensor_name + " has no shape information");
+    }
+
+    // Check that the dimension counts match
+    size_t graph_dim_count = graph_shape->dim_size();
+    size_t requested_dim_count = requested_shape.get_max_shape().size();
+    if (graph_dim_count != requested_dim_count) {
+      ORT_THROW("Dimension mismatch for input " + tensor_name +
+                ": graph expects " + std::to_string(graph_dim_count) +
+                " dimensions but reshape_input specifies " +
+                std::to_string(requested_dim_count) + " dimensions");
+    }
+  }
+}
+
 // Check to see if the graph is QDQ
 static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
   std::unordered_set<std::string> qdq_ops = {"QuantizeLinear", "DequantizeLinear"};
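
Note: ValidateInputShapes only checks that each reshape_input name exists among the graph inputs and that the rank matches; per-dimension bounds are enforced later, at inference time, by ValidateOrtDimsAgainstPartialShape in basic_backend.cc. A minimal sketch of the shape_t map this function consumes (the tensor name "data" is illustrative):

    // Sketch only: a shape_t entry equivalent to reshape_input "data[1,3,60..80]".
    shape_t shape;
    shape["data"] = ov::PartialShape{1, 3, ov::Dimension(60, 80)};
    // get_max_shape() has one entry per dimension, so the rank check above
    // compares 3 against dim_size() of the graph input's TensorShapeProto.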

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,8 @@ class BackendManager {
 
   bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const;
   bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const;
+  void ValidateInputShapes(const shape_t& shape,
+                           const std::vector<const NodeArg*>& graph_inputs) const;
 
   std::shared_ptr<ONNX_NAMESPACE::ModelProto>
   ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_proto);

onnxruntime/core/providers/openvino/backend_utils.cc

Lines changed: 5 additions & 0 deletions
@@ -146,6 +146,11 @@ CreateOVModel(const std::string model,
   try {
     auto ov_model = OVCore::ReadModel(model, session_context.onnx_model_path_name.string());
 
+    if (!session_context.shape.empty()) {
+      LOGS_DEFAULT(INFO) << log_tag << "Reshaping the OV model inputs to the specified shapes";
+      ov_model->reshape(session_context.shape);
+    }
+
     // Check for Constant Folding
     if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) {
       ov::pass::ConstantFolding pass_const_obj;
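
Note: ov::Model::reshape accepts a map from input tensor name to ov::PartialShape, and a bounded ov::Dimension keeps that axis dynamic within the given interval after compilation. A standalone sketch of the same call, assuming a hypothetical model.onnx with an input named "data":

    #include <map>
    #include <string>
    #include <openvino/openvino.hpp>

    int main() {
      ov::Core core;
      auto model = core.read_model("model.onnx");  // hypothetical path
      // Fix batch and channels, bound the spatial dims between 60 and 80.
      std::map<std::string, ov::PartialShape> shapes{
          {"data", ov::PartialShape{1, 3, ov::Dimension(60, 80), ov::Dimension(60, 80)}}};
      model->reshape(shapes);  // same call the EP makes with session_context.shape
      auto compiled = core.compile_model(model, "NPU");
      return 0;
    }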

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 49 additions & 5 deletions
@@ -7,6 +7,7 @@
 #include <sstream>
 #include <fstream>
 #include <utility>
+#include <vector>
 
 #include "core/providers/shared_library/provider_api.h"
 #include "core/providers/openvino/backend_utils.h"
@@ -96,6 +97,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_proto,
   } else if (!session_context_.has_external_weights &&
              !subgraph_context_.has_dynamic_input_shape &&
              !session_context_.so_context_enable &&
+             session_context.shape.empty() &&
              auto_unified_compile) {
     // Unified OV compile_model is efficient when ov model caching is enabled
     // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
@@ -418,9 +420,20 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) {
         (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) {
       ov_tensor_data_t ov_tensor_data;
       const auto& input = ov_input_info.at(input_idx);
-      ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), input.get_shape(),
-                                                               const_cast<void*>(tensor.GetTensorRawData()));
-
+      if (!session_context_.shape.empty()) {
+        ov::PartialShape partial_shape = input.get_partial_shape();
+        const auto& ort_dims = tensor.GetTensorTypeAndShapeInfo().GetShape();
+        ValidateOrtDimsAgainstPartialShape(ort_dims, partial_shape);
+        ov::Shape concrete_shape;
+        for (size_t i = 0; i < ort_dims.size(); ++i) {
+          concrete_shape.push_back(ort_dims[i]);
+        }
+        ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), concrete_shape,
+                                                                 const_cast<void*>(tensor.GetTensorRawData()));
+      } else {
+        ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), input.get_shape(),
+                                                                 const_cast<void*>(tensor.GetTensorRawData()));
+      }
       ov_tensor_data.ort_ptr = tensor.GetTensorRawData();
       ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data;
 
@@ -434,6 +447,10 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) {
     }
   }  // Loop subgraph original input names
 
+  if (!session_context_.shape.empty()) {
+    infer_request->Infer();
+  }
+
   if (session_context_.device_type.find("NPU") != std::string::npos) {
     // Set the output blob as remote blob
     auto graph_output_info = exe_network_.Get().outputs();
@@ -465,8 +482,15 @@
       ov_tensor_data_t ov_tensor_data;
       const auto& output = graph_output_info.at(output_idx);
       ov_tensor_data.ort_ptr = tensor.GetTensorRawData();
-      ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output.get_shape(),
-                                                               const_cast<void*>(tensor.GetTensorRawData()));
+
+      if (!session_context_.shape.empty()) {
+        ov::Tensor output_tensor = infer_request->GetOutputTensor(output_idx);
+        ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output_tensor.get_shape(),
+                                                                 const_cast<void*>(tensor.GetTensorRawData()));
+      } else {
+        ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output.get_shape(),
+                                                                 const_cast<void*>(tensor.GetTensorRawData()));
+      }
       ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data;
 
       try {
@@ -669,6 +693,26 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) {
   }
 }
 
+void BasicBackend::ValidateOrtDimsAgainstPartialShape(const std::vector<int64_t>& ort_dims,
+                                                      const ov::PartialShape& partial_shape) const {
+  // Check that the number of dimensions matches
+  if (static_cast<int64_t>(ort_dims.size()) != partial_shape.rank().get_length()) {
+    ORT_THROW("Mismatch in number of dimensions between ORT tensor and OpenVINO PartialShape.");
+  }
+  // Validate each dimension
+  for (size_t i = 0; i < ort_dims.size(); ++i) {
+    const auto& ov_dim = partial_shape[i];  // OpenVINO dimension at index i
+    int64_t ort_dim = ort_dims[i];          // ORT dimension at index i
+
+    // Check that the ORT dimension lies within the specified range
+    int64_t min_dim = ov_dim.get_min_length();
+    int64_t max_dim = ov_dim.get_max_length();
+    if (ort_dim < min_dim || ort_dim > max_dim) {
+      ORT_THROW("ORT dimension at index " + std::to_string(i) + " is out of range");
+    }
+  }
+}
+
 void BasicBackend::Infer(OrtKernelContext* ctx) {
   // Preliminary thread-safety mechanism:
   // currently allows a maximum of 8 infer requests to execute in parallel at the same time
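
Note: the per-dimension check above leans on ov::Dimension being an interval: get_min_length() and get_max_length() return its ends, and a static dimension is the degenerate case min == max. A minimal sketch of the same containment test in isolation:

    #include <cassert>
    #include <cstdint>
    #include <openvino/core/dimension.hpp>

    int main() {
      ov::Dimension bounded(60, 80);  // from a reshape_input range like 60..80
      ov::Dimension fixed(3);         // a static dimension
      int64_t ort_dim = 72;           // concrete runtime dim from the ORT tensor
      assert(ort_dim >= bounded.get_min_length() && ort_dim <= bounded.get_max_length());
      assert(fixed.get_min_length() == 3 && fixed.get_max_length() == 3);
      return 0;
    }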

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ class BasicBackend : public IBackend {
   void EnableStreams();
   void SetNumThreads(ov::AnyMap& device_config);
   void StartAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);
+  void ValidateOrtDimsAgainstPartialShape(const std::vector<int64_t>& ort_dims,
+                                          const ov::PartialShape& partial_shape) const;
 
 #ifdef IO_BUFFER_ENABLED
   void StartRemoteAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);

onnxruntime/core/providers/openvino/contexts.h

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,7 @@ struct SharedContext {
 };
 
 using config_t = std::map<std::string, ov::AnyMap>;
+using shape_t = std::map<std::string, ov::PartialShape>;
 
 struct ProviderInfo {
   std::string device_type{""};  // [device_type]: Overrides the accelerator hardware type and
@@ -74,6 +75,7 @@ struct ProviderInfo {
   uint32_t num_of_threads{0};   // [num_of_threads]: Overrides the accelerator default value of
                                 // number of threads with this value at runtime.
   config_t load_config{};       // JSON config map to load custom OV parameters.
+  shape_t shape{};              // [reshape_input]: Reshapes model inputs to the given shapes, with optional per-dimension lower and upper bounds.
   fs::path cache_dir{""};       // [cache_dir]: specify the path to
                                 // dump and load the blobs for the model caching/kernel caching
                                 // (GPU) feature. If blob files are already present,

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

Lines changed: 91 additions & 0 deletions
@@ -212,6 +212,97 @@ struct OpenVINO_Provider : Provider {
 
   pi.precision = ParsePrecision(provider_options, pi.device_type, "precision");
 
+  if (provider_options.contains("reshape_input") && pi.device_type == "NPU") {
+    auto parse_input_shapes = [&](const std::string& reshape_input_definition) {
+      std::map<std::string, ov::PartialShape> parsed_shape_map;
+      std::string unparsed_definition = reshape_input_definition;
+
+      while (!unparsed_definition.empty()) {
+        // Find the next shape definition's opening bracket
+        auto shape_start_bracket = unparsed_definition.find_first_of('[');
+        if (shape_start_bracket == std::string::npos) {
+          ORT_THROW("Malformed input: missing opening bracket '[' in: " + unparsed_definition);
+        }
+        // Extract the tensor name
+        std::string tensor_name = unparsed_definition.substr(0, shape_start_bracket);
+        // Remove leading/trailing whitespace
+        tensor_name.erase(0, tensor_name.find_first_not_of(" \t"));
+        tensor_name.erase(tensor_name.find_last_not_of(" \t") + 1);
+
+        if (tensor_name.empty()) {
+          ORT_THROW("Empty tensor name provided in reshape_input parameter");
+        }
+
+        // Find the closing bracket for the current shape definition
+        auto shape_end_bracket = unparsed_definition.find_first_of(']', shape_start_bracket);
+
+        if (shape_end_bracket == std::string::npos) {
+          ORT_THROW("Missing closing bracket ']' for tensor: " + tensor_name);
+        }
+
+        // Extract the shape dimensions string
+        std::string shape_dimension_str = unparsed_definition.substr(shape_start_bracket + 1,
+                                                                     shape_end_bracket - shape_start_bracket - 1);
+        std::vector<ov::Dimension> dimension_values;
+        std::stringstream dimension_stream(shape_dimension_str);
+        std::string dimension_token;
+
+        while (std::getline(dimension_stream, dimension_token, ',')) {
+          // Remove leading/trailing whitespace
+          dimension_token.erase(0, dimension_token.find_first_not_of(" \t"));
+          dimension_token.erase(dimension_token.find_last_not_of(" \t") + 1);
+
+          // Check whether the dimension is a range
+          size_t range_separator_pos = dimension_token.find("..");
+          if (range_separator_pos != std::string::npos) {
+            std::string range_start_str = dimension_token.substr(0, range_separator_pos);
+            std::string range_end_str = dimension_token.substr(range_separator_pos + 2);
+
+            // Remove leading/trailing whitespace
+            range_start_str.erase(0, range_start_str.find_first_not_of(" \t"));
+            range_start_str.erase(range_start_str.find_last_not_of(" \t") + 1);
+            range_end_str.erase(0, range_end_str.find_first_not_of(" \t"));
+            range_end_str.erase(range_end_str.find_last_not_of(" \t") + 1);
+
+            if (range_start_str.empty() || range_end_str.empty() ||
+                !std::all_of(range_start_str.begin(), range_start_str.end(), ::isdigit) ||
+                !std::all_of(range_end_str.begin(), range_end_str.end(), ::isdigit)) {
+              ORT_THROW("Invalid dimension range format: " + dimension_token + " for tensor: " + tensor_name);
+            }
+
+            int range_start = std::stoi(range_start_str);
+            int range_end = std::stoi(range_end_str);
+
+            if (range_start > range_end) {
+              ORT_THROW("Invalid dimension range (start > end) for tensor: " + tensor_name);
+            }
+
+            dimension_values.emplace_back(ov::Dimension(range_start, range_end));
+          } else {
+            // Handle a single dimension value
+            if (dimension_token.empty() ||
                !std::all_of(dimension_token.begin(), dimension_token.end(), ::isdigit)) {
+              ORT_THROW("Invalid dimension value: " + dimension_token + " for tensor: " + tensor_name);
+            }
+            dimension_values.emplace_back(std::stoi(dimension_token));
+          }
+        }
+
+        // Store the parsed shape in the result map
+        parsed_shape_map[tensor_name] = ov::PartialShape(dimension_values);
+        // Update the remaining unparsed string
+        unparsed_definition = unparsed_definition.substr(shape_end_bracket + 1);
+        if (!unparsed_definition.empty() && unparsed_definition.front() == ',') {
+          unparsed_definition = unparsed_definition.substr(1);
+        }
+        // Remove leading whitespace
+        unparsed_definition.erase(0, unparsed_definition.find_first_not_of(" \t"));
+      }
+      return parsed_shape_map;
+    };
+    pi.shape = parse_input_shapes(provider_options.at("reshape_input"));
+  }
+
   if (provider_options.contains("load_config")) {
     auto parse_config = [&](const std::string& config_str) -> std::map<std::string, ov::AnyMap> {
       // If the config string is empty, return an empty map and skip processing
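
Note: per the parser above, the reshape_input value is a comma-separated list of name[dims] entries, where each dimension is either a fixed non-negative integer or a min..max range; whitespace around names and tokens is trimmed, and the option only takes effect when device_type is NPU. Illustrative values (tensor names are hypothetical):

    reshape_input: data[1,3,60..80]
    reshape_input: input_ids[1,1..512], attention_mask[1,1..512]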

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 10 additions & 0 deletions
@@ -211,6 +211,16 @@ OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) {
   }
 }
 
+OVTensor OVInferRequest::GetOutputTensor(size_t output_idx) {
+  try {
+    return ovInfReq.get_output_tensor(output_idx);
+  } catch (const Exception& e) {
+    ORT_THROW(log_tag + " Cannot access output tensor: " + e.what());
+  } catch (...) {
+    ORT_THROW(log_tag + " Cannot access output tensor");
+  }
+}
+
 std::string OVInferRequest::GetInputTensorName(uint32_t index) {
   try {
     const auto& model = ovInfReq.get_compiled_model();
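
Note: the new wrapper forwards to ov::InferRequest::get_output_tensor(size_t), whose shape is concrete once inference has run; that is what the basic_backend.cc change reads to size the ORT-backed output tensor. A minimal sketch, assuming `compiled` is an ov::CompiledModel built from the reshaped model:

    #include <openvino/openvino.hpp>

    ov::Shape QueryFirstOutputShape(ov::CompiledModel& compiled) {
      ov::InferRequest req = compiled.create_infer_request();
      // ... set input tensors here in real use ...
      req.infer();
      ov::Tensor out = req.get_output_tensor(0);  // first model output
      return out.get_shape();                     // concrete after inference
    }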

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@ class OVInferRequest {
   OVTensorPtr GetTensor(const std::string& name);
   std::string GetInputTensorName(uint32_t index);
   void SetTensor(const std::string& name, OVTensorPtr& blob);
+  OVTensor GetOutputTensor(size_t output_idx);
   void StartAsync();
   void Infer();
   void WaitRequest();

onnxruntime/test/perftest/ort_test_session.cc

Lines changed: 2 additions & 0 deletions
@@ -788,6 +788,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
       }
     } else if (key == "device_memory_name") {
       device_memory_name_ = std::move(value);
+    } else if (key == "reshape_input") {
+      ov_options[key] = value;
     } else {
       ORT_THROW(
           "[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO."
