#include "core/providers/openvino/backends/basic_backend.h"
#include "core/providers/openvino/onnx_ctx_model_helper.h"
#include "core/providers/openvino/backend_manager.h"
+ #include "core/providers/openvino/ov_stateful_patch_utils.h"

namespace onnxruntime {
@@ -29,6 +30,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
ptr_stream_t& model_stream)
    : session_context_{session_context}, subgraph_context_{subgraph_context}, shared_context_{shared_context} {
std::string& hw_target = session_context_.device_type;
+ bool enable_causallm = session_context_.enable_causallm;

if (ValidateSubgraph(const_outputs_map_))
  return;
@@ -43,7 +45,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
// Setting OpenCL queue throttling for GPU
EnableGPUThrottling(device_config);

- // Enable streams; default=1 unless ovverriden by user config
+ // Enable streams; default=1 unless overridden by user configuration
EnableStreams();

// Set the inference_num_threads property of the CPU
@@ -95,7 +97,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
} else if (!session_context_.has_external_weights &&
           !subgraph_context_.has_dynamic_input_shape &&
           !session_context_.so_context_enable &&
-           auto_unified_compile) {
+           !enable_causallm && auto_unified_compile) {
// Unified OV compile_model is efficient when ov model caching is enabled
// Unified OV compile_model API is supported with AUTO from version 2024.3 and above
// Inputs with static dimensions
@@ -115,7 +117,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
}
auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
exe_network_ = OVCore::Get()->CompileModel(
-     ov_model, hw_target, device_config, subgraph_context_.subgraph_name);
+     ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
}
#endif
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
@@ -200,6 +202,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
if (!session_context_.load_config.empty()) {
const std::map<std::string, ov::AnyMap>& target_config = session_context_.load_config;

+ if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) {
+   if (target_config.find("NPU") != target_config.end()) {
+     auto npu_genai_config = target_config.at("NPU");
+     CausalLMConfig().ApplyConfig(npu_genai_config, device_config);
+   } else {
+     LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found.";
+   }
+ }
+
if (session_context_.device_type.find("NPU") != std::string::npos) {
auto npuw_config = target_config.at("NPU");
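The NPU GenAI branch above hands the user-supplied "NPU" section of load_config to CausalLMConfig::ApplyConfig, which populates device_config before the generic NPUW handling runs. The helper below is only an illustration of that kind of merge, assuming user-provided entries take precedence; the real key set and precedence rules live in ov_stateful_patch_utils.h and are not shown in this diff.

// Illustrative only: fold a user-supplied ov::AnyMap section into the device
// configuration, letting user-provided entries overwrite existing defaults.
// This approximates what a helper like CausalLMConfig::ApplyConfig might do;
// the actual keys and rules are defined in ov_stateful_patch_utils.
#include <openvino/openvino.hpp>

static void MergeGenAiConfig(const ov::AnyMap& user_npu_config, ov::AnyMap& device_config) {
  for (const auto& [key, value] : user_npu_config) {
    device_config[key] = value;  // overwrite or insert
  }
}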
@@ -265,7 +276,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options,
                                 const std::vector<ov::PropertyName>& supported_properties) {
for (const auto& [key, value] : config_options) {
- if (key.find("NPUW") != std::string::npos) {
+ if ((key.find("NPUW") != std::string::npos) ||
+     ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) {
  continue;
}
if (is_supported_and_mutable(key, supported_properties)) {
@@ -358,6 +370,13 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads));
}

+ void BasicBackend::RewindKVCache(size_t index) {
+   OVInferRequestPtr infer_request;
+   infer_request = inferRequestsQueue_->getIdleRequest();
+   infer_request->RewindKVCache(index);
+   inferRequestsQueue_->putIdleRequest(std::move(infer_request));
+ }
+
// Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on
// an Infer Request indexed by infer_req_idx
void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) {
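The new BasicBackend::RewindKVCache borrows an idle request, rewinds its KV cache to `index` tokens, and returns the request to the pool; the per-request trimming itself lives in OVInferRequest::RewindKVCache, which this hunk does not show. The sketch below illustrates one way such trimming could work through OpenVINO's stateful API (ov::InferRequest::query_state); treating axis 2 as the sequence axis is an assumption for illustration, not taken from the actual implementation.

// Rough sketch (not the EP's OVInferRequest::RewindKVCache): trim every variable
// state of a stateful OpenVINO model down to the first `index` tokens.
// Assumes the sequence-length axis of each KV state tensor is axis 2.
#include <openvino/openvino.hpp>

void RewindStatefulRequest(ov::InferRequest& request, size_t index) {
  for (auto& state : request.query_state()) {
    ov::Tensor full = state.get_state();
    ov::Shape shape = full.get_shape();
    if (shape.size() < 3 || shape[2] <= index) continue;  // nothing to trim
    ov::Coordinate begin(shape.size(), 0);
    ov::Coordinate end(shape);
    end[2] = index;                        // keep only the first `index` positions
    ov::Shape new_shape = shape;
    new_shape[2] = index;
    ov::Tensor roi(full, begin, end);      // view over the retained slice
    ov::Tensor trimmed(full.get_element_type(), new_shape);
    roi.copy_to(trimmed);                  // densify before re-setting the state
    state.set_state(trimmed);
  }
}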
@@ -376,14 +395,22 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
}
index++;
}
+
+ // For Stateful Model Compilation, the ONNX model includes KV cache (past/present) tensors.
+ // However, these tensors are internally converted to a stateful representation, which removes them.
+ // To prevent runtime exceptions, we simply continue processing here.
+ if (input_name.empty() || input_name == "beam_idx") continue;
+
ORT_ENFORCE(!input_name.empty(), log_tag,
            "Input names mismatch between OpenVINO and ONNX. ", onnx_input_name,
            " doesn't exist in the list of OpenVINO input tensor names");
size_t batch_slice_idx = 0;
if (subgraph_context_.has_dynamic_input_shape &&
    !session_context_.disable_dynamic_shapes &&
    (session_context_.device_type.find("CPU") != std::string::npos ||
-     session_context_.device_type.find("GPU") != std::string::npos)) {
+     session_context_.device_type.find("GPU") != std::string::npos ||
+     (session_context_.device_type.find("NPU") != std::string::npos &&
+      session_context_.enable_causallm))) {
auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name));
auto tensor_info = tensor.GetTensorTypeAndShapeInfo();
auto tensor_shape = tensor_info.GetShape();
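The `continue` added above relies on the stateful transformation removing past/present KV tensors from the OpenVINO model's inputs, so the ONNX-to-OV input-name lookup can legitimately come up empty. A hedged sketch of that kind of lookup, with an illustrative helper name that is not part of the EP:

// Illustrative helper (not in the EP): return the OpenVINO input name matching an
// ONNX input, or an empty string when the stateful transformation has removed it
// (e.g. past/present KV tensors now handled as internal variable state).
#include <openvino/openvino.hpp>
#include <string>

static std::string MatchOvInputName(const ov::CompiledModel& compiled, const std::string& onnx_name) {
  for (const auto& input : compiled.inputs()) {
    if (input.get_names().count(onnx_name)) return input.get_any_name();
  }
  return {};  // caller skips this ONNX input, as the diff above does
}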
@@ -445,7 +472,8 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
}
} // Loop subgraph original input names

- if (session_context_.device_type.find("NPU") != std::string::npos) {
+ // For stateful compilation (i.e. when enable_causallm is true), the NPU plugin also uses the dynamic shapes path.
+ if (session_context_.device_type.find("NPU") != std::string::npos && !session_context_.enable_causallm) {
// Set the output blob as remote blob
auto graph_output_info = exe_network_.Get().outputs();
auto output_idx = 0;
@@ -640,7 +668,9 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe
            "list of OpenVINO output tensor names");
}
if ((session_context_.device_type.find("CPU") != std::string::npos ||
-     session_context_.device_type.find("GPU") != std::string::npos)) {
+     session_context_.device_type.find("GPU") != std::string::npos ||
+     (session_context_.device_type.find("NPU") != std::string::npos &&
+      session_context_.enable_causallm))) {
try {
  graph_output_blob = infer_request->GetTensor(output_name);
} catch (const char* msg) {
@@ -719,25 +749,41 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
try {
  StartRemoteAsyncInference(context, infer_request);
} catch (std::string const& msg) {
+ // If the inference fails (exception from ov::InferRequest::infer()),
+ // we need to put the infer_request back into the pool to avoid deadlocks
+ // and to allow the next inference request to proceed.
+ inferRequestsQueue_->putIdleRequest(std::move(infer_request));
  ORT_THROW(msg);
}
} else {
try {
  StartAsyncInference(context, infer_request);
} catch (std::string const& msg) {
+ // If the inference fails (exception from ov::InferRequest::infer()),
+ // we need to put the infer_request back into the pool to avoid deadlocks
+ // and to allow the next inference request to proceed.
+ inferRequestsQueue_->putIdleRequest(std::move(infer_request));
  ORT_THROW(msg);
}
}
#else
try {
  StartAsyncInference(context, infer_request);
} catch (const std::runtime_error& e) {
+ // If the inference fails (exception from ov::InferRequest::infer()),
+ // we need to put the infer_request back into the pool to avoid deadlocks
+ // and to allow the next inference request to proceed.
+ inferRequestsQueue_->putIdleRequest(std::move(infer_request));
  ORT_THROW(log_tag + " Exception at StartAsyncInference: " + e.what());
}
#endif
try {
  CompleteAsyncInference(context, infer_request);
} catch (const std::runtime_error& e) {
+ // If the inference fails (exception from ov::InferRequest::infer()),
+ // we need to put the infer_request back into the pool to avoid deadlocks
+ // and to allow the next inference request to proceed.
+ inferRequestsQueue_->putIdleRequest(std::move(infer_request));
  ORT_THROW(log_tag + " Exception at CompleteAsyncInference: " + e.what());
}
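The same put-the-request-back recovery now appears in each catch block of Infer. A scope guard would centralize it; the sketch below is generic over the queue and request types, and the usage comment assumes the inferRequestsQueue_ member and request type already used in this file.

// Sketch of an RAII alternative to repeating putIdleRequest() in every catch block:
// the guard returns the borrowed request on all exit paths, including exceptions.
#include <utility>

template <typename Queue, typename RequestPtr>
struct IdleRequestGuard {
  Queue& queue;
  RequestPtr request;
  ~IdleRequestGuard() {
    if (request) queue.putIdleRequest(std::move(request));
  }
};

// Possible usage inside BasicBackend::Infer (illustrative, not the current code):
//   IdleRequestGuard<InferRequestsQueue, OVInferRequestPtr> guard{
//       *inferRequestsQueue_, inferRequestsQueue_->getIdleRequest()};
//   StartAsyncInference(context, guard.request);
//   CompleteAsyncInference(context, guard.request);
//   // no explicit putIdleRequest() needed; the guard handles success and failure alike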