Commit 6d67616

feat: ORT GenAI Stateful Compilation changes
1 parent e354009 commit 6d67616

13 files changed: +796 -37 lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 11 additions & 2 deletions
@@ -105,7 +105,8 @@ BackendManager::BackendManager(SessionContext& session_context,
     subgraph_context_.has_dynamic_input_shape = true;
     LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims";
     if ((session_context_.device_type.find("CPU") != std::string::npos ||
-         session_context_.device_type.find("GPU") != std::string::npos) &&
+         session_context_.device_type.find("GPU") != std::string::npos ||
+         (session_context_.device_type.find("NPU") != std::string::npos && session_context_.enable_causallm)) &&
         !session_context_.disable_dynamic_shapes) {
       LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. "
                          << "Creating backend Dynamic Shapes";
@@ -492,7 +493,9 @@ void BackendManager::Compute(OrtKernelContext* context) {
   if (subgraph_context_.has_dynamic_input_shape &&
       !session_context_.disable_dynamic_shapes &&
       (session_context_.device_type.find("CPU") != std::string::npos ||
-       session_context_.device_type.find("GPU") != std::string::npos)) {
+       session_context_.device_type.find("GPU") != std::string::npos ||
+       (session_context_.device_type.find("NPU") != std::string::npos &&
+        session_context_.enable_causallm))) {
     concrete_backend_->Infer(context);
   } else if (subgraph_context_.has_dynamic_input_shape) {
     std::vector<std::vector<int64_t>> tensor_shapes = GetInputTensorShapes(ctx);
@@ -565,5 +568,11 @@ void BackendManager::ShutdownBackendManager() {
   concrete_backend_.reset();
 }
 
+void BackendManager::RewindKVCache(size_t index) {
+  if (concrete_backend_) {
+    concrete_backend_->RewindKVCache(index);
+  }
+}
+
 }  // namespace openvino_ep
 }  // namespace onnxruntime

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ class BackendManager {
   SessionContext& GetSessionContext();
   Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph);
   ov::CompiledModel& GetOVCompiledModel();
+  void RewindKVCache(size_t index);
 
  private:
   std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 53 additions & 7 deletions
@@ -15,6 +15,7 @@
 #include "core/providers/openvino/backends/basic_backend.h"
 #include "core/providers/openvino/onnx_ctx_model_helper.h"
 #include "core/providers/openvino/backend_manager.h"
+#include "core/providers/openvino/ov_stateful_patch_utils.h"
 
 namespace onnxruntime {
 
@@ -29,6 +30,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
                            ptr_stream_t& model_stream)
     : session_context_{session_context}, subgraph_context_{subgraph_context}, shared_context_{shared_context} {
   std::string& hw_target = session_context_.device_type;
+  bool enable_causallm = session_context_.enable_causallm;
 
   if (ValidateSubgraph(const_outputs_map_))
     return;
@@ -43,7 +45,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   // Setting OpenCL queue throttling for GPU
   EnableGPUThrottling(device_config);
 
-  // Enable streams; default=1 unless ovverriden by user config
+  // Enable streams; default=1 unless overridden by user configuration
   EnableStreams();
 
   // Set the inference_num_threads property of the CPU
@@ -95,7 +97,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   } else if (!session_context_.has_external_weights &&
              !subgraph_context_.has_dynamic_input_shape &&
              !session_context_.so_context_enable &&
-             auto_unified_compile) {
+             !enable_causallm && auto_unified_compile) {
     // Unified OV compile_model is efficient when ov model caching is enabled
     // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
    // Inputs with static dimensions
@@ -115,7 +117,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
     }
     auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
     exe_network_ = OVCore::Get()->CompileModel(
-        ov_model, hw_target, device_config, subgraph_context_.subgraph_name);
+        ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
   }
 #endif
   LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
@@ -200,6 +202,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
   if (!session_context_.load_config.empty()) {
     const std::map<std::string, ov::AnyMap>& target_config = session_context_.load_config;
 
+    if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) {
+      if (target_config.find("NPU") != target_config.end()) {
+        auto npu_genai_config = target_config.at("NPU");
+        CausalLMConfig().ApplyConfig(npu_genai_config, device_config);
+      } else {
+        LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found.";
+      }
+    }
+
     if (session_context_.device_type.find("NPU") != std::string::npos) {
       auto npuw_config = target_config.at("NPU");
 
@@ -265,7 +276,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
     auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options,
                                      const std::vector<ov::PropertyName>& supported_properties) {
       for (const auto& [key, value] : config_options) {
-        if (key.find("NPUW") != std::string::npos) {
+        if ((key.find("NPUW") != std::string::npos) ||
+            ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) {
           continue;
         }
         if (is_supported_and_mutable(key, supported_properties)) {
@@ -358,6 +370,13 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
     device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads));
 }
 
+void BasicBackend::RewindKVCache(size_t index) {
+  OVInferRequestPtr infer_request;
+  infer_request = inferRequestsQueue_->getIdleRequest();
+  infer_request->RewindKVCache(index);
+  inferRequestsQueue_->putIdleRequest(std::move(infer_request));
+}
+
 // Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on
 // an Infer Request indexed by infer_req_idx
 void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) {
@@ -376,14 +395,22 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
       }
       index++;
     }
+
+    // For Stateful Model Compilation, the ONNX model includes KV cache (past/present) tensors.
+    // However, these tensors are internally converted to a stateful representation, which removes them.
+    // To prevent runtime exceptions, we simply continue processing here.
+    if (input_name.empty() || input_name == "beam_idx") continue;
+
     ORT_ENFORCE(!input_name.empty(), log_tag,
                 "Input names mismatch between OpenVINO and ONNX. ", onnx_input_name,
                 " doesn't exist in the list of OpenVINO input tensor names");
     size_t batch_slice_idx = 0;
     if (subgraph_context_.has_dynamic_input_shape &&
         !session_context_.disable_dynamic_shapes &&
         (session_context_.device_type.find("CPU") != std::string::npos ||
-         session_context_.device_type.find("GPU") != std::string::npos)) {
+         session_context_.device_type.find("GPU") != std::string::npos ||
+         (session_context_.device_type.find("NPU") != std::string::npos &&
+          session_context_.enable_causallm))) {
       auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name));
       auto tensor_info = tensor.GetTensorTypeAndShapeInfo();
       auto tensor_shape = tensor_info.GetShape();
@@ -445,7 +472,8 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
       }
     }  // Loop subgraph original input names
 
-    if (session_context_.device_type.find("NPU") != std::string::npos) {
+    // For Stateful Compilation i.e. enable_causallm as True, we use the dynamic shapes path for NPU plugin as well.
+    if (session_context_.device_type.find("NPU") != std::string::npos && !session_context_.enable_causallm) {
       // Set the output blob as remote blob
       auto graph_output_info = exe_network_.Get().outputs();
       auto output_idx = 0;
@@ -640,7 +668,9 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe
                 "list of OpenVINO output tensor names");
   }
   if ((session_context_.device_type.find("CPU") != std::string::npos ||
-       session_context_.device_type.find("GPU") != std::string::npos)) {
+       session_context_.device_type.find("GPU") != std::string::npos ||
+       (session_context_.device_type.find("NPU") != std::string::npos &&
+        session_context_.enable_causallm))) {
     try {
       graph_output_blob = infer_request->GetTensor(output_name);
     } catch (const char* msg) {
@@ -719,25 +749,41 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
       try {
         StartRemoteAsyncInference(context, infer_request);
       } catch (std::string const& msg) {
+        // If the inference fails (exception from ov::InferRequest::infer()),
+        // we need to put the infer_request back into the pool to avoid deadlocks
+        // and to allow the next inference request to proceed.
+        inferRequestsQueue_->putIdleRequest(std::move(infer_request));
         ORT_THROW(msg);
       }
     } else {
       try {
         StartAsyncInference(context, infer_request);
       } catch (std::string const& msg) {
+        // If the inference fails (exception from ov::InferRequest::infer()),
+        // we need to put the infer_request back into the pool to avoid deadlocks
+        // and to allow the next inference request to proceed.
+        inferRequestsQueue_->putIdleRequest(std::move(infer_request));
        ORT_THROW(msg);
      }
    }
 #else
    try {
      StartAsyncInference(context, infer_request);
    } catch (const std::runtime_error& e) {
+      // If the inference fails (exception from ov::InferRequest::infer()),
+      // we need to put the infer_request back into the pool to avoid deadlocks
+      // and to allow the next inference request to proceed.
+      inferRequestsQueue_->putIdleRequest(std::move(infer_request));
      ORT_THROW(log_tag + " Exception at StartAsyncInference: " + e.what());
    }
 #endif
    try {
      CompleteAsyncInference(context, infer_request);
    } catch (const std::runtime_error& e) {
+      // If the inference fails (exception from ov::InferRequest::infer()),
+      // we need to put the infer_request back into the pool to avoid deadlocks
+      // and to allow the next inference request to proceed.
+      inferRequestsQueue_->putIdleRequest(std::move(infer_request));
      ORT_THROW(log_tag + " Exception at CompleteAsyncInference: " + e.what());
    }

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 2 additions & 1 deletion
@@ -41,6 +41,7 @@ class BasicBackend : public IBackend {
   ov::CompiledModel& GetOVCompiledModel() override {
     return exe_network_.Get();
   }
+  void RewindKVCache(size_t index) override;
 
  private:
   void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&);
@@ -78,7 +79,7 @@ class InferRequestsQueue {
   InferRequestsQueue(OVExeNetwork& net, size_t nireq, std::function<void(OVInferRequestPtr)> initializer) {
     OVInferRequestPtr infer_request;
     for (size_t id = 0; id < nireq; id++) {
-      infer_request = std::make_shared<OVInferRequest>(net.CreateInferRequest());
+      infer_request = net.CreateInferRequest();
       initializer(infer_request);
       infer_requests_.push_back(infer_request);
     }

onnxruntime/core/providers/openvino/contexts.h

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ struct ProviderInfo {
   bool disable_dynamic_shapes{false};      // [disable_dynamic_shapes]: Rewrite dynamic shaped models to
                                            // static shape at runtime and execute.
   bool enable_qdq_optimizer{false};        // Enables QDQ pruning for efficient inference latency with NPU
+  bool enable_causallm{false};             // Enables Causal LM Compilation for ORT GenAI OVEP Pass
   bool so_context_enable{false};           // ORT session option
   bool so_disable_cpu_ep_fallback{false};  // ORT session option
   bool so_context_embed_mode{false};       // ORT session option

onnxruntime/core/providers/openvino/ibackend.h

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ class IBackend {
   virtual void Infer(OrtKernelContext* context) = 0;
   virtual ov::CompiledModel& GetOVCompiledModel() = 0;
   virtual ~IBackend() = default;
+  virtual void RewindKVCache(size_t index) {};
 };
 using ptr_stream_t = std::unique_ptr<std::istream>;
 class BackendFactory {

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

Lines changed: 19 additions & 0 deletions
@@ -244,6 +244,25 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const ch
         ov_compiled_model.set_property(ov::workload_type(workload_type));
       }
     }
+  } else if (key == "kvcache_rewind") {
+    // Convert kvcache_rewind value to int64_t
+    int64_t index;
+    try {
+      index = std::stoll(value);
+    } catch (const std::exception& e) {
+      LOGS_DEFAULT(WARNING) << "Conversion for kvcache_rewind string value to int64_t index failed."
+                            << "Exception:" + std::string(e.what());
+      return Status::OK();
+    }
+
+    // Trigger KVCache Rewind for target Backend
+    for (auto& backend : backend_managers_) {
+      if (index >= 0) {
+        backend.RewindKVCache(static_cast<size_t>(index));
+      } else {
+        LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0:\t" << index;
+      }
+    }
   } else {
     // Handle unknown options
     LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;
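
Note: the new "kvcache_rewind" key is consumed through the EP dynamic-options path above. A minimal caller-side sketch follows; it assumes the SetEpDynamicOptions session wrapper available in recent ONNX Runtime C++ bindings, which is not part of this commit:

// Hypothetical usage sketch: rewind the OpenVINO EP KV cache to a given token position.
// Assumes Ort::Session::SetEpDynamicOptions(keys, values, count) exists in the build.
#include <onnxruntime_cxx_api.h>
#include <string>

void RewindKvCacheTo(Ort::Session& session, int64_t position) {
  const std::string value = std::to_string(position);
  const char* keys[] = {"kvcache_rewind"};
  const char* values[] = {value.c_str()};
  // The provider parses the value with std::stoll; negative indices are logged and ignored,
  // non-negative indices are forwarded to BackendManager::RewindKVCache.
  session.SetEpDynamicOptions(keys, values, 1);
}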

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

Lines changed: 8 additions & 1 deletion
@@ -343,13 +343,20 @@ static void ParseProviderInfo(const ProviderOptions& provider_options,
 
     pi.enable_qdq_optimizer = ParseBooleanOption(provider_options, "enable_qdq_optimizer");
 
+    pi.enable_causallm = ParseBooleanOption(provider_options, "enable_causallm");
+
     pi.disable_dynamic_shapes = ParseBooleanOption(provider_options, "disable_dynamic_shapes");
   } catch (std::string msg) {
     ORT_THROW(msg);
   }
   // Always true for NPU plugin or when passed .
   if (pi.device_type.find("NPU") != std::string::npos) {
-    pi.disable_dynamic_shapes = true;
+    // For Stateful Compilation i.e. enable_causallm as True, we use the dynamic shapes path.
+    if (pi.enable_causallm) {
+      pi.disable_dynamic_shapes = false;
+    } else {
+      pi.disable_dynamic_shapes = true;
+    }
   }
 }

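For context, a hedged session-setup sketch showing where the new enable_causallm provider option lands. AppendExecutionProvider_OpenVINO_V2, the load_config option, and the file name npu_genai_config.json are assumptions drawn from the existing OpenVINO EP bindings rather than part of this commit:

// Hypothetical sketch: create a session on NPU with stateful (causal LM) compilation enabled.
#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ovep_genai");
  Ort::SessionOptions session_options;
  std::unordered_map<std::string, std::string> ov_options{
      {"device_type", "NPU"},
      // Parsed by ParseBooleanOption above; on NPU it also re-enables the
      // dynamic-shapes path (disable_dynamic_shapes = false).
      {"enable_causallm", "true"},
      // Optional: a JSON config whose "NPU" section is applied through
      // CausalLMConfig().ApplyConfig in basic_backend.cc.
      // {"load_config", "npu_genai_config.json"},
  };
  session_options.AppendExecutionProvider_OpenVINO_V2(ov_options);
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}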