
Commit ca2bc91

Reshape feature implementation
1 parent 60ee27a commit ca2bc91

10 files changed: +208 -8 lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 44 additions & 3 deletions
@@ -5,6 +5,7 @@
 #include <algorithm>
 #include <cassert>
 #include <fstream>
+#include <map>
 #include <regex>
 #include <sstream>
 #include <unordered_map>
@@ -61,12 +62,19 @@ BackendManager::BackendManager(SessionContext& session_context,
     return "";
   }(subgraph);
 
-  // Save the indexes of graph inputs among fused_node's inputDefs
-  // (which also contains initializers).
+  if (!session_context_.shape.empty()) {
+    ValidateInputShapes(session_context_.shape, subgraph.GetInputs());
+  }
+
   for (uint32_t index = 0; const auto& node : subgraph.GetInputs()) {
+    // Skip graph inputs that feed no node in this subgraph
+    if (subgraph.GetGraph().GetConsumerNodes(node->Name()).empty()) {
+      continue;
+    }
     subgraph_context_.input_names.insert({node->Name(), index++});
   }
 
+
   for (uint32_t index = 0; const auto& node : subgraph.GetOutputs()) {
     subgraph_context_.output_names.insert({node->Name(), index++});
   }
@@ -100,7 +108,7 @@ BackendManager::BackendManager(SessionContext& session_context,
     }
   }
 
-  if (ModelHasSymbolicInputDims(subgraph)) {
+  if (ModelHasSymbolicInputDims(subgraph) && session_context_.shape.empty()) {
     subgraph_context_.has_dynamic_input_shape = true;
     LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims";
     if ((session_context_.device_type.find("CPU") != std::string::npos ||
@@ -306,6 +314,39 @@ bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& s
   return has_sym_dims;
 }
 
+void BackendManager::ValidateInputShapes(const std::map<std::string, ov::PartialShape>& shape,
+                                         const std::vector<const NodeArg*>& graph_inputs) const {
+  for (const auto& [tensor_name, requested_shape] : shape) {
+    // Find the matching input in the graph
+    const NodeArg* graph_input = nullptr;
+    for (const auto* input : graph_inputs) {
+      if (input->Name() == tensor_name) {
+        graph_input = input;
+        break;
+      }
+    }
+
+    if (!graph_input) {
+      ORT_THROW("Input " + tensor_name + " specified in reshape_input does not exist");
+    }
+
+    const ONNX_NAMESPACE::TensorShapeProto* graph_shape = graph_input->Shape();
+    if (!graph_shape) {
+      ORT_THROW("Graph input " + tensor_name + " has no shape information");
+    }
+
+    // Check that the dimension counts match
+    size_t graph_dim_count = graph_shape->dim_size();
+    size_t requested_dim_count = requested_shape.get_max_shape().size();
+    if (graph_dim_count != requested_dim_count) {
+      ORT_THROW("Dimension mismatch for input " + tensor_name +
+                ": graph expects " + std::to_string(graph_dim_count) +
+                " dimensions but reshape_input specifies " +
+                std::to_string(requested_dim_count) + " dimensions");
+    }
+  }
+}
+
 // Check to see if the graph is QDQ
 static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
   std::unordered_set<std::string> qdq_ops = {"QuantizeLinear", "DequantizeLinear"};
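
For context, a minimal standalone sketch (not part of this commit) of how an ov::PartialShape built from a reshape_input entry such as "data[1,3,60,80..120]" reports the rank that ValidateInputShapes compares against the ONNX input's dim_size(); the input name and shape values here are illustrative only:

    #include <iostream>
    #include <openvino/core/partial_shape.hpp>

    int main() {
      // Equivalent of parsing "data[1,3,60,80..120]"
      ov::PartialShape requested{1, 3, 60, ov::Dimension(80, 120)};
      // get_max_shape() is {1, 3, 60, 120}; its size() gives the rank (4)
      // that is checked against the graph input's dimension count.
      std::cout << requested.get_max_shape().size() << "\n";
      return 0;
    }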

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,8 @@ class BackendManager {
 
   bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const;
   bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const;
+  void ValidateInputShapes(const shape_t& shape,
+                           const std::vector<const NodeArg*>& graph_inputs) const;
 
   std::shared_ptr<ONNX_NAMESPACE::ModelProto>
   ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_proto);

onnxruntime/core/providers/openvino/backend_utils.cc

Lines changed: 5 additions & 0 deletions
@@ -146,6 +146,11 @@ CreateOVModel(const std::string model,
   try {
     auto ov_model = OVCore::Get()->ReadModel(model, session_context.onnx_model_path_name.string());
 
+    if (!session_context.shape.empty()) {
+      LOGS_DEFAULT(INFO) << log_tag << "Reshaping model inputs to the user-specified shapes/bounds";
+      ov_model->reshape(session_context.shape);
+    }
+
     // Check for Constant Folding
     if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) {
       ov::pass::ConstantFolding pass_const_obj;
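
A hedged sketch of the underlying OpenVINO call (plain OpenVINO API; the model path and the input name "data" are hypothetical): ov::Model::reshape accepts exactly the std::map<std::string, ov::PartialShape> that session_context.shape carries, so the bounds are fixed before compile_model.

    #include <map>
    #include <openvino/openvino.hpp>

    int main() {
      ov::Core core;
      auto model = core.read_model("model.onnx");  // hypothetical path
      std::map<std::string, ov::PartialShape> shapes{
          {"data", ov::PartialShape{1, 3, 60, ov::Dimension(80, 120)}}};  // "data" is illustrative
      model->reshape(shapes);  // fix lower/upper bounds before compiling
      auto compiled = core.compile_model(model, "NPU");
      return 0;
    }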

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 49 additions & 5 deletions
@@ -7,6 +7,7 @@
 #include <sstream>
 #include <fstream>
 #include <utility>
+#include <vector>
 
 #include "core/providers/shared_library/provider_api.h"
 #include "core/providers/openvino/backend_utils.h"
@@ -96,6 +97,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   } else if (!session_context_.has_external_weights &&
              !subgraph_context_.has_dynamic_input_shape &&
              !session_context_.so_context_enable &&
+             session_context.shape.empty() &&
             auto_unified_compile) {
     // Unified OV compile_model is efficient when ov model caching is enabled
     // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
@@ -418,9 +420,20 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
         (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) {
       ov_tensor_data_t ov_tensor_data;
       const auto& input = ov_input_info.at(input_idx);
-      ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), input.get_shape(),
-                                                               const_cast<void*>(tensor.GetTensorRawData()));
-
+      if (!session_context_.shape.empty()) {
+        // Reshaped inputs may be bounded-dynamic: validate the actual ORT dims, then use them
+        ov::PartialShape partial_shape = input.get_partial_shape();
+        const auto& ort_dims = tensor.GetTensorTypeAndShapeInfo().GetShape();
+        ValidateOrtDimsAgainstPartialShape(ort_dims, partial_shape);
+        ov::Shape concrete_shape;
+        for (size_t i = 0; i < ort_dims.size(); ++i) {
+          concrete_shape.push_back(ort_dims[i]);
+        }
+        ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), concrete_shape,
+                                                                 const_cast<void*>(tensor.GetTensorRawData()));
+      } else {
+        ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), input.get_shape(),
+                                                                 const_cast<void*>(tensor.GetTensorRawData()));
+      }
       ov_tensor_data.ort_ptr = tensor.GetTensorRawData();
       ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data;
 
@@ -434,6 +447,10 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
       }
     }  // Loop subgraph original input names
 
+    if (!session_context_.shape.empty()) {
+      infer_request->Infer();
+    }
+
     if (session_context_.device_type.find("NPU") != std::string::npos) {
       // Set the output blob as remote blob
       auto graph_output_info = exe_network_.Get().outputs();
@@ -465,8 +482,15 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
       ov_tensor_data_t ov_tensor_data;
       const auto& output = graph_output_info.at(output_idx);
       ov_tensor_data.ort_ptr = tensor.GetTensorRawData();
-      ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output.get_shape(),
-                                                               const_cast<void*>(tensor.GetTensorRawData()));
+
+      if (!session_context_.shape.empty()) {
+        ov::Tensor output_tensor = infer_request->GetOutputTensor(output_idx);
+        ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output_tensor.get_shape(),
+                                                                 const_cast<void*>(tensor.GetTensorRawData()));
+      } else {
+        ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output.get_shape(),
+                                                                 const_cast<void*>(tensor.GetTensorRawData()));
+      }
       ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data;
 
       try {
@@ -669,6 +693,26 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe
   }
 }
 
+void BasicBackend::ValidateOrtDimsAgainstPartialShape(const std::vector<int64_t>& ort_dims,
+                                                      const ov::PartialShape& partial_shape) const {
+  // Check that the number of dimensions matches
+  if (static_cast<int64_t>(ort_dims.size()) != partial_shape.rank().get_length()) {
+    ORT_THROW("Mismatch in number of dimensions between ORT tensor and OpenVINO PartialShape.");
+  }
+  // Validate each dimension
+  for (size_t i = 0; i < ort_dims.size(); ++i) {
+    const auto& ov_dim = partial_shape[i];  // OpenVINO dimension at index i
+    int64_t ort_dim = ort_dims[i];          // ORT dimension at index i
+
+    // Check that the ORT dimension is within the specified range
+    int64_t min_dim = ov_dim.get_min_length();
+    int64_t max_dim = ov_dim.get_max_length();
+    if (ort_dim < min_dim || ort_dim > max_dim) {
+      ORT_THROW("ORT dimension " + std::to_string(ort_dim) + " at index " + std::to_string(i) + " is out of range");
+    }
+  }
+}
+
 void BasicBackend::Infer(OrtKernelContext* ctx) {
   // Preliminary thread-safety mechanism:
   // currently allows a maximum of 8 infer requests to execute in parallel
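
To illustrate the bounds check, a minimal standalone sketch (not EP code; it assumes only OpenVINO's ov::Dimension interval API, and the shapes are illustrative):

    #include <cstdint>
    #include <stdexcept>
    #include <vector>
    #include <openvino/core/partial_shape.hpp>

    void check_dims(const std::vector<int64_t>& dims, const ov::PartialShape& ps) {
      if (static_cast<int64_t>(dims.size()) != ps.rank().get_length())
        throw std::runtime_error("rank mismatch");
      for (size_t i = 0; i < dims.size(); ++i) {
        // Each bounded dimension exposes its interval via get_min/max_length()
        if (dims[i] < ps[i].get_min_length() || dims[i] > ps[i].get_max_length())
          throw std::runtime_error("dimension out of bounds");
      }
    }

    int main() {
      ov::PartialShape bounded{1, 3, 60, ov::Dimension(80, 120)};
      check_dims({1, 3, 60, 100}, bounded);    // ok: 100 lies in [80, 120]
      // check_dims({1, 3, 60, 130}, bounded); // would throw: 130 > 120
      return 0;
    }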

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ class BasicBackend : public IBackend {
   void EnableStreams();
   void SetNumThreads(ov::AnyMap& device_config);
   void StartAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);
+  void ValidateOrtDimsAgainstPartialShape(const std::vector<int64_t>& ort_dims,
+                                          const ov::PartialShape& partial_shape) const;
 
 #ifdef IO_BUFFER_ENABLED
   void StartRemoteAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);

onnxruntime/core/providers/openvino/contexts.h

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,7 @@ class SharedContext : public WeakSingleton<SharedContext> {
 };
 
 using config_t = std::map<std::string, ov::AnyMap>;
+using shape_t = std::map<std::string, ov::PartialShape>;
 
 struct ProviderInfo {
   std::string device_type{""};  // [device_type]: Overrides the accelerator hardware type and
@@ -79,6 +80,7 @@ struct ProviderInfo {
   uint32_t num_of_threads{0};   // [num_of_threads]: Overrides the accelerator default value of
                                 // number of threads with this value at runtime.
   config_t load_config{};       // JSON config map to load custom OV parameters.
+  shape_t shape{};              // [reshape_input]: Reshapes model inputs to fixed shapes or bounded ranges.
   fs::path cache_dir{""};       // [cache_dir]: specify the path to
                                 // dump and load the blobs for the model caching/kernel caching
                                 // (GPU) feature. If blob files are already present,

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

Lines changed: 91 additions & 0 deletions
@@ -236,6 +236,97 @@ struct OpenVINO_Provider : Provider {
 
     pi.precision = ParsePrecision(provider_options, pi.device_type, "precision");
 
+    if (provider_options.contains("reshape_input") && pi.device_type == "NPU") {
+      auto parse_input_shapes = [&](const std::string& reshape_input_definition) {
+        std::map<std::string, ov::PartialShape> parsed_shape_map;
+        std::string unparsed_definition = reshape_input_definition;
+
+        while (!unparsed_definition.empty()) {
+          // Find the next shape definition bracket
+          auto shape_start_bracket = unparsed_definition.find_first_of('[');
+          if (shape_start_bracket == std::string::npos) {
+            ORT_THROW("Malformed input: missing opening bracket '[' in: " + unparsed_definition);
+          }
+          // Extract the tensor name
+          std::string tensor_name = unparsed_definition.substr(0, shape_start_bracket);
+          // Remove leading/trailing whitespace
+          tensor_name.erase(0, tensor_name.find_first_not_of(" \t"));
+          tensor_name.erase(tensor_name.find_last_not_of(" \t") + 1);
+
+          if (tensor_name.empty()) {
+            ORT_THROW("Empty tensor name provided in reshape_input parameter");
+          }
+
+          // Find the closing bracket for the current shape definition
+          auto shape_end_bracket = unparsed_definition.find_first_of(']', shape_start_bracket);
+
+          if (shape_end_bracket == std::string::npos || shape_end_bracket < shape_start_bracket) {
+            ORT_THROW("Missing closing bracket ']' for tensor: " + tensor_name);
+          }
+
+          // Extract the shape dimensions string
+          std::string shape_dimension_str = unparsed_definition.substr(shape_start_bracket + 1,
+                                                                       shape_end_bracket - shape_start_bracket - 1);
+          std::vector<ov::Dimension> dimension_values;
+          std::stringstream dimension_stream(shape_dimension_str);
+          std::string dimension_token;
+
+          while (std::getline(dimension_stream, dimension_token, ',')) {
+            // Remove leading/trailing whitespace
+            dimension_token.erase(0, dimension_token.find_first_not_of(" \t"));
+            dimension_token.erase(dimension_token.find_last_not_of(" \t") + 1);
+
+            // Check whether the dimension is a range
+            size_t range_separator_pos = dimension_token.find("..");
+            if (range_separator_pos != std::string::npos) {
+              std::string range_start_str = dimension_token.substr(0, range_separator_pos);
+              std::string range_end_str = dimension_token.substr(range_separator_pos + 2);
+
+              // Remove leading/trailing whitespace
+              range_start_str.erase(0, range_start_str.find_first_not_of(" \t"));
+              range_start_str.erase(range_start_str.find_last_not_of(" \t") + 1);
+              range_end_str.erase(0, range_end_str.find_first_not_of(" \t"));
+              range_end_str.erase(range_end_str.find_last_not_of(" \t") + 1);
+
+              if (range_start_str.empty() || range_end_str.empty() ||
+                  !std::all_of(range_start_str.begin(), range_start_str.end(), ::isdigit) ||
+                  !std::all_of(range_end_str.begin(), range_end_str.end(), ::isdigit)) {
+                ORT_THROW("Invalid dimension range format: " + dimension_token + " for tensor: " + tensor_name);
+              }
+
+              int range_start = std::stoi(range_start_str);
+              int range_end = std::stoi(range_end_str);
+
+              if (range_start > range_end) {
+                ORT_THROW("Invalid dimension range (start > end) for tensor: " + tensor_name);
+              }
+
+              dimension_values.emplace_back(ov::Dimension(range_start, range_end));
+            } else {
+              // Handle a single dimension value
+              if (dimension_token.empty() ||
                  !std::all_of(dimension_token.begin(), dimension_token.end(), ::isdigit)) {
+                ORT_THROW("Invalid dimension value: " + dimension_token + " for tensor: " + tensor_name);
+              }
+              dimension_values.emplace_back(std::stoi(dimension_token));
+            }
+          }
+
+          // Store the parsed shape in the result map
+          parsed_shape_map[tensor_name] = ov::PartialShape(dimension_values);
+          // Update the remaining unparsed string
+          unparsed_definition = unparsed_definition.substr(shape_end_bracket + 1);
+          if (!unparsed_definition.empty() && unparsed_definition.front() == ',') {
+            unparsed_definition = unparsed_definition.substr(1);
+          }
+          // Remove leading whitespace
+          unparsed_definition.erase(0, unparsed_definition.find_first_not_of(" \t"));
+        }
+        return parsed_shape_map;
+      };
+      pi.shape = parse_input_shapes(provider_options.at("reshape_input"));
+    }
+
     if (provider_options.contains("load_config")) {
       auto parse_config = [&](const std::string& config_str) -> std::map<std::string, ov::AnyMap> {
         // If the config string is empty, return an empty map and skip processing
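
For reference, the option string follows the grammar the parser above implements: one or more "name[dims]" entries separated by commas, where each dim is either an integer or a "min..max" range. A hedged usage sketch through the ORT C++ API (it assumes this build exposes AppendExecutionProvider_OpenVINO_V2; "data" and model.onnx are placeholders):

    #include <string>
    #include <unordered_map>
    #include <onnxruntime_cxx_api.h>

    int main() {
      Ort::Env env;
      Ort::SessionOptions so;
      std::unordered_map<std::string, std::string> ov_options{
          {"device_type", "NPU"},
          {"reshape_input", "data[1,3,60,80..120]"}};  // last dim bounded to [80, 120]
      so.AppendExecutionProvider_OpenVINO_V2(ov_options);
      Ort::Session session(env, ORT_TSTR("model.onnx"), so);
      return 0;
    }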

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 10 additions & 0 deletions
@@ -197,6 +197,16 @@ OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) {
   }
 }
 
+OVTensor OVInferRequest::GetOutputTensor(size_t output_idx) {
+  try {
+    return ovInfReq.get_output_tensor(output_idx);
+  } catch (const Exception& e) {
+    ORT_THROW(log_tag + " Cannot access output tensor: " + e.what());
+  } catch (...) {
+    ORT_THROW(log_tag + " Cannot access output tensor");
+  }
+}
+
 std::string OVInferRequest::GetInputTensorName(uint32_t index) {
   try {
     const auto& model = ovInfReq.get_compiled_model();
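
Why a dedicated accessor: with bounded-dynamic inputs the concrete output shape is only known after inference, so the backend queries the infer request rather than the static model info. A hedged sketch in plain OpenVINO (hypothetical model, input name, and element type):

    #include <map>
    #include <openvino/openvino.hpp>

    int main() {
      ov::Core core;
      auto model = core.read_model("model.onnx");  // hypothetical
      std::map<std::string, ov::PartialShape> shapes{
          {"data", ov::PartialShape{1, 3, 60, ov::Dimension(80, 120)}}};
      model->reshape(shapes);
      auto request = core.compile_model(model, "NPU").create_infer_request();
      ov::Tensor input(ov::element::f32, ov::Shape{1, 3, 60, 100});  // concrete dims within bounds
      request.set_input_tensor(0, input);
      request.infer();
      // The output shape becomes concrete only after infer():
      ov::Shape out_shape = request.get_output_tensor(0).get_shape();
      return 0;
    }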

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 1 addition & 0 deletions
@@ -115,6 +115,7 @@ class OVInferRequest {
   OVTensorPtr GetTensor(const std::string& name);
   std::string GetInputTensorName(uint32_t index);
   void SetTensor(const std::string& name, OVTensorPtr& blob);
+  OVTensor GetOutputTensor(size_t output_idx);
   void StartAsync();
   void Infer();
   void WaitRequest();

onnxruntime/test/perftest/ort_test_session.cc

Lines changed: 2 additions & 0 deletions
@@ -790,6 +790,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
       }
     } else if (key == "device_memory_name") {
       device_memory_name_ = std::move(value);
+    } else if (key == "reshape_input") {
+      ov_options[key] = value;
     } else {
       ORT_THROW(
           "[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO."
