diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc
index 2a842b8a1eca8..0b17500bc39e9 100644
--- a/onnxruntime/core/providers/openvino/backend_manager.cc
+++ b/onnxruntime/core/providers/openvino/backend_manager.cc
@@ -104,7 +104,8 @@ BackendManager::BackendManager(SessionContext& session_context,
     subgraph_context_.has_dynamic_input_shape = true;
     LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims";
     if ((session_context_.device_type.find("CPU") != std::string::npos ||
-         session_context_.device_type.find("GPU") != std::string::npos) &&
+         session_context_.device_type.find("GPU") != std::string::npos ||
+         session_context_.device_type.find("NPU") != std::string::npos) &&
         !session_context_.disable_dynamic_shapes) {
       LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. "
                          << "Creating backend Dynamic Shapes";
@@ -473,7 +474,8 @@ void BackendManager::Compute(OrtKernelContext* context) {
   if (subgraph_context_.has_dynamic_input_shape &&
       !session_context_.disable_dynamic_shapes &&
       (session_context_.device_type.find("CPU") != std::string::npos ||
-       session_context_.device_type.find("GPU") != std::string::npos)) {
+       session_context_.device_type.find("GPU") != std::string::npos ||
+       session_context_.device_type.find("NPU") != std::string::npos)) {
     concrete_backend_->Infer(context);
   } else if (subgraph_context_.has_dynamic_input_shape) {
     std::vector<std::vector<int64_t>> tensor_shapes = GetInputTensorShapes(ctx);
diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
index 3ac4d22f5453c..da29f3f1e3553 100644
--- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc
+++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc
@@ -91,7 +91,8 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr
       exe_network_ = OVCore::Get()->ImportModel(*model_stream,
                                                 hw_target,
                                                 device_config,
-                                                subgraph_context_.subgraph_name);
+                                                session_context.onnx_model_path_name.string());
+      model_stream.reset();  // Delete stream after it is no longer needed
     } else if (!session_context_.has_external_weights &&
                !subgraph_context_.has_dynamic_input_shape &&
@@ -167,7 +168,6 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
   if (session_context_.precision.find("ACCURACY") != std::string::npos &&
       session_context_.device_type.find("GPU") != std::string::npos) {
     if (session_context_.OpenVINO_Version.at(0) >= 2024) {
-      device_config.emplace(ov::hint::inference_precision(ov::element::undefined));
       device_config.emplace(ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY));
     } else {
       if (!subgraph_context_.model_precision.empty())
@@ -365,6 +365,13 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
       }
       index++;
     }
+
+    // for the stateful PoC, the ONNX model will have KV cache (past/present) tensors, but
+    // we internally converted it to stateful, which removed these. So, we just continue here
+    // to avoid runtime exception.
+    //if (input_name.empty()) continue;
+    if (input_name.empty() || input_name == "beam_idx") continue;
+
     ORT_ENFORCE(!input_name.empty(), log_tag,
                 "Input names mismatch between OpenVINO and ONNX. ", onnx_input_name,
", onnx_input_name, " doesn't exist in the list of OpenVINO input tensor names"); @@ -372,7 +379,8 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque if (subgraph_context_.has_dynamic_input_shape && !session_context_.disable_dynamic_shapes && (session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) { + session_context_.device_type.find("GPU") != std::string::npos || + session_context_.device_type.find("NPU") != std::string::npos)) { auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); auto tensor_shape = tensor_info.GetShape(); @@ -434,7 +442,10 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque } } // Loop subgraph original input names - if (session_context_.device_type.find("NPU") != std::string::npos) { + // For stateful PoC added '&& false' here to disable it, as we forced it through + // same dynamic shape path above as we do for CPU & GPU. + if (session_context_.device_type.find("NPU") != std::string::npos && + !subgraph_context_.has_dynamic_input_shape && false) { // Set the output blob as remote blob auto graph_output_info = exe_network_.Get().outputs(); auto output_idx = 0; @@ -629,7 +640,8 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe "list of OpenVINO output tensor names"); } if ((session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) { + session_context_.device_type.find("GPU") != std::string::npos || + session_context_.device_type.find("NPU") != std::string::npos)) { try { graph_output_blob = infer_request->GetTensor(output_name); } catch (const char* msg) { diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 95e039f8b6d5f..e08b68be25dfa 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -14,6 +14,7 @@ namespace onnxruntime { namespace openvino_ep { + void ParseConfigOptions(ProviderInfo& pi, const ConfigOptions& config_options) { pi.so_disable_cpu_ep_fallback = config_options.GetConfigOrDefault(kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; pi.so_context_enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; @@ -330,7 +331,9 @@ struct OpenVINO_Provider : Provider { // Always true for NPU plugin or when passed . if (pi.device_type.find("NPU") != std::string::npos) { - pi.disable_dynamic_shapes = true; + // For Stateful PoC, we want control to pass through dynamic shape paths, + // so just force this to false right now. 
+        pi.disable_dynamic_shapes = false;
       }
       // Append values to config to support weight-as-inputs conversion for shared contexts
diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc
index 37f9e1c4e9201..27fe8c7c68335 100644
--- a/onnxruntime/core/providers/openvino/ov_interface.cc
+++ b/onnxruntime/core/providers/openvino/ov_interface.cc
@@ -8,6 +8,9 @@
 #include "core/providers/shared_library/provider_api.h"
 #include "core/providers/openvino/backend_utils.h"
+// for make stateful utility function(s)
+#include "core/providers/openvino/ov_stateful_patch_utils.h"
+
 using Exception = ov::Exception;
 namespace onnxruntime {
@@ -77,7 +80,52 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_netwo
                                   const std::string& name) {
   ov::CompiledModel obj;
   try {
-    obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+    if (true) {
+      ov::AnyMap config;
+
+      // Create a clone of ie_cnn_network, since it's a const ov::Model, and we need to patch it..
+      // Note! With this default path, the model runs but produces garbage (for NPUW). For CPU it's fine.
+      auto mutable_model = ie_cnn_network->clone();
+
+      // uncomment to override ov::Model with one produced by OV's ONNX front-end.
+      // For some reason, this makes it work -- even though model.onnx is the same model read by ORT GenAI.
+      // auto mutable_model = core.read_model("C:\\Users\\LNL\\Workspace\\ORT\\deepseek_r1_distill_qwen_1.5B_int4_ort_qdq\\model.onnx");
+
+      std::cout << "stateless model" << std::endl;
+      logBasicModelInfo(mutable_model);
+
+      std::cout << "making stateful..." << std::endl;
+      patch_stateful_decoder(mutable_model);
+
+      std::cout << "after stateful transition:" << std::endl;
+      logBasicModelInfo(mutable_model);
+
+      // This patches the model so that it only produces the logits required for sampling.
+      // Actually either way that happens within NPUW::LLMCompiledModel creation, but this is
+      // here mostly to align this behavior for other devices (CPU, GPU).
+      apply_slice_before_matmul_transformation(mutable_model);
+
+      auto kv_pos = get_kv_axes_pos(mutable_model);
+      std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl;
+      std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl;
+
+      if (hw_target.find("NPU") != std::string::npos) {
+        KVDesc kv_desc;
+        kv_desc.max_prompt_len = pop_int_and_cast(device_config, "MAX_PROMPT_LEN").value_or(1024u);
+        kv_desc.min_response_len = pop_int_and_cast(device_config, "MIN_RESPONSE_LEN").value_or(128u);
+
+        std::cout << "kv_desc.max_prompt_len = " << kv_desc.max_prompt_len << std::endl;
+        std::cout << "kv_desc.min_response_len = " << kv_desc.min_response_len << std::endl;
+
+        update_npu_config(config, mutable_model, kv_pos, kv_desc);
+      }
+
+      std::cout << "calling compile on stateful model..." << std::endl;
+      obj = core.compile_model(mutable_model, hw_target, config);
+      std::cout << "done calling compile on stateful model..." << std::endl;
+    } else {
+      obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+    }
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
@@ -115,7 +163,83 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
                                  std::string name) {
   try {
     ov::CompiledModel obj;
-    obj = core.import_model(model_stream, hw_target, device_config);
+
+    //Check if it's XML
+    std::streampos originalPos = model_stream.tellg();
+    std::string header(5, '\0');  // Allocate space for "<?xml"
+    model_stream.read(&header[0], header.size());
+    model_stream.seekg(originalPos);  // restore the original stream position
+
+    if (header == "<?xml") {
+      // Derive the *.bin weights file name from the *.onnx model name
+      std::string bin_file;
+      if (name.size() >= 5 && name.substr(name.size() - 5) == ".onnx") {
+        bin_file = name;
+        bin_file.replace(name.size() - 5, 5, ".bin");
+      } else {
+        throw std::runtime_error("Invalid model name. Make sure *.onnx, *.xml, and *.bin carry the same name.");
+      }
+
+      // Read the model XML into a string
+      std::stringstream xml_stream;
+      xml_stream << model_stream.rdbuf();
+      std::string xml_content = xml_stream.str();
+
+      // Read model.bin into a vector
+      std::ifstream bin_stream;
+      bin_stream.open(bin_file, std::ios::binary);
+      if (!bin_stream.is_open()) {
+        throw std::runtime_error("Failed to open " + bin_file);
+      }
+
+      bin_stream.seekg(0, std::ios::end);
+      std::streamsize size = bin_stream.tellg();
+      bin_stream.seekg(0, std::ios::beg);
+      std::vector<uint8_t> bin_data(size);
+      if (!bin_stream.read(reinterpret_cast<char*>(bin_data.data()), size)) {
+        throw std::runtime_error("Failed to read binary data from " + bin_file);
+      }
+
+      // Create an ov::Tensor for weights
+      ov::Tensor weights_tensor(ov::element::u8, {bin_data.size()}, bin_data.data());
+
+      // Load the model explicitly with XML content and weights
+      std::shared_ptr<ov::Model> model = core.read_model(xml_content, weights_tensor);
+
+      ov::AnyMap config = device_config;
+
+      std::cout << "already a stateful model since it came from EPCtx:" << std::endl;
+      logBasicModelInfo(model);
+
+      auto kv_pos = get_kv_axes_pos(model);
+      std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl;
+      std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl;
+
+      if (hw_target.find("NPU") != std::string::npos) {
+        KVDesc kv_desc;
+        kv_desc.max_prompt_len = pop_int_and_cast(config, "MAX_PROMPT_LEN").value_or(1024u);
+        kv_desc.min_response_len = pop_int_and_cast(config, "MIN_RESPONSE_LEN").value_or(128u);
+
+        std::cout << "kv_desc.max_prompt_len = " << kv_desc.max_prompt_len << std::endl;
+        std::cout << "kv_desc.min_response_len = " << kv_desc.min_response_len << std::endl;
+
+        update_npu_config(config, model, kv_pos, kv_desc);
+      } else {
+        apply_slice_before_matmul_transformation(model);
+      }
+
+      std::cout << "calling compile on stateful model for " << hw_target << " ... " << std::endl;
+      obj = core.compile_model(model, hw_target, config);
+      std::cout << "done calling compile on stateful model..." << std::endl;
+    }
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
@@ -128,6 +252,9 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
   }
 }
+
+
+
 void OVCore::SetCache(const std::string& cache_dir_path) {
   core.set_property(ov::cache_dir(cache_dir_path));
 }
@@ -211,6 +338,16 @@ std::string OVInferRequest::GetInputTensorName(uint32_t index) {
 void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) {
   try {
     ovInfReq.set_tensor(name, *(blob.get()));
+
+    if (name == "input_ids") {
+      // Since we can't seem to set at ORT GenAI layer right now, we just set it here
+      // as a workaround.
+      // TODO: Fix this.
+      ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
+      std::fill_n(beam_idx.data<int32_t>(), 1, 0);
+      ovInfReq.set_tensor("beam_idx", beam_idx);
+    }
+
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name + e.what());
   } catch (...) {
diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h
new file mode 100644
index 0000000000000..be8b783bfeb5b
--- /dev/null
+++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h
@@ -0,0 +1,338 @@
+// Copyright (C) Intel Corporation
+// Licensed under the MIT License
+
+#pragma once
+
+#include <cstdint>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <optional>
+#include <stdexcept>
+#include <string>
+#include <tuple>
+#include <vector>
+
+// for make stateful utility function(s)
+#include "openvino/pass/manager.hpp"
+#include "openvino/pass/make_stateful.hpp"
+#include "openvino/opsets/opset13.hpp"
+
+static inline void logBasicModelInfo(const std::shared_ptr<ov::Model>& model) {
+  std::cout << "Model name: " << model->get_friendly_name() << std::endl;
+
+  // Dump information about model inputs/outputs
+  auto inputs = model->inputs();
+  auto outputs = model->outputs();
+
+  std::cout << "\tInputs: " << std::endl;
+  for (const ov::Output<ov::Node>& input : inputs) {
+    const std::string name = input.get_any_name();
+    const ov::element::Type type = input.get_element_type();
+    const ov::PartialShape shape = input.get_partial_shape();
+    const ov::Layout layout = ov::layout::get_layout(input);
+
+    std::cout << "\t\t" << name << ", " << type << ", " << shape << ", " << layout.to_string() << std::endl;
+  }
+
+  std::cout << "\tOutputs: " << std::endl;
+  for (const ov::Output<ov::Node>& output : outputs) {
+    const std::string name = output.get_any_name();
+    const ov::element::Type type = output.get_element_type();
+    const ov::PartialShape shape = output.get_partial_shape();
+    const ov::Layout layout = ov::layout::get_layout(output);
+
+    std::cout << "\t\t" << name << ", " << type << ", " << shape << ", " << layout.to_string() << std::endl;
+  }
+
+  return;
+}
+
+static inline bool model_has_input_output_names(std::shared_ptr<ov::Model> model, const std::string& name_to_match) {
+  for (const ov::Output<ov::Node>& input : model->inputs()) {
+    auto& names = input.get_names();
+
+    for (auto& name : names) {
+      if (name == name_to_match) {
+        return true;
+      }
+    }
+  }
+
+  for (const ov::Output<ov::Node>& output : model->outputs()) {
+    auto& names = output.get_names();
+    for (auto& name : names) {
+      if (name == name_to_match) {
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
+static void fuse_cache_reorder(std::shared_ptr<ov::Model> ov_model,
+                               std::vector<std::string>& not_kv_inputs,
+                               const std::vector<std::string>& key_value_input_names,
+                               int gather_dim) {
+  if (model_has_input_output_names(ov_model, "beam_idx")) {
+    throw std::runtime_error("Model already has fused cache");
+  }
+
+  std::string main_input_name = "inputs_embeds";
+  if (model_has_input_output_names(ov_model, "input_ids")) {
+    main_input_name = "input_ids";
+  }
+
+  auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
+
+  auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({input_batch}));
+  beam_idx->set_friendly_name("beam_idx");
+  beam_idx->output(0).get_tensor().add_names({"beam_idx"});
+  ov_model->add_parameters({beam_idx});
+  not_kv_inputs.push_back(beam_idx->get_friendly_name());
+
+  // Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
+  for (const auto& input_name : key_value_input_names) {
+    auto parameter_output_port = ov_model->input(input_name);
+    auto consumers = parameter_output_port.get_target_inputs();
+
+    auto gather_op =
+        std::make_shared<ov::opset13::Gather>(parameter_output_port,
+                                              beam_idx,
+                                              ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim}));
+
+    // Replace the source output for all consumers of the input tensor
+    for (auto& consumer : consumers) {
+      consumer.replace_source_output(gather_op->output(0));
+    }
+  }
+
+  // Validate the modified model
+  ov_model->validate_nodes_and_infer_types();
+}
+
+static void make_stateful(std::shared_ptr<ov::Model>& ov_model,
+                          const std::vector<std::string>& key_value_input_names,
+                          const std::vector<std::string>& key_value_output_names) {
+  std::map<std::string, std::string> input_output_map;
+
+  // Create mapping for KV-cache inputs and outputs
+  for (size_t i = 0; i < key_value_input_names.size(); ++i) {
+    input_output_map[key_value_input_names[i]] = key_value_output_names[i];
+  }
+
+  // Apply the transformation to make the model stateful
+  ov::pass::Manager manager;
+  manager.register_pass<ov::pass::MakeStateful>(input_output_map);
+  manager.run_passes(ov_model);
+}
+
+// Converted to C++ from here:
+// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281
+static void patch_stateful_decoder(std::shared_ptr<ov::Model> model) {
+  std::vector<std::string> key_value_input_names;
+  std::vector<std::string> not_kv_inputs;
+  for (const ov::Output<ov::Node>& input : model->inputs()) {
+    auto& names = input.get_names();
+
+    bool found = false;
+    for (auto& name : names) {
+      if (name.find("key_values") != std::string::npos) {
+        key_value_input_names.push_back(name);
+        found = true;
+        break;
+      }
+    }
+
+    if (!found) {
+      not_kv_inputs.push_back(input.get_any_name());
+    }
+  }
+
+  std::vector<std::string> key_value_output_names;
+  for (const ov::Output<ov::Node>& output : model->outputs()) {
+    auto& names = output.get_names();
+    for (auto& name : names) {
+      if (name.find("present") != std::string::npos) {
+        key_value_output_names.push_back(name);
+        break;
+      }
+    }
+  }
+
+  if (key_value_input_names.empty() || key_value_output_names.empty()) {
+    std::cout << "no key_value_input_names or key_value_output_names found" << std::endl;
+    return;
+  }
+
+  // By default, batch is the 0-th but chatglm uses 1-st dimension as batch
+  // TODO: Deduce from a model via ordinal reshape (?) and topology
+  // batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0
+  auto batch_dim = 0;
+
+  fuse_cache_reorder(model, not_kv_inputs, key_value_input_names, batch_dim);
+
+  make_stateful(model, key_value_input_names, key_value_output_names);
+}
+
+// Some other utility functions copied from OpenVINO GenAI
+static bool has_op_with_type(const std::shared_ptr<const ov::Model>& function, const std::string& type_name) {
+  for (const auto& op : function->get_ops()) {
+    if (op->get_type_name() == type_name) {
+      return true;
+    }
+  }
+  return false;
+}
+
+static std::tuple<std::shared_ptr<ov::opset13::MatMul>, int64_t> find_llm_matmul(const std::shared_ptr<ov::Model>& model) {
+  auto last_node = model->output(0).get_node()->input_value(0).get_node_shared_ptr();
+  std::shared_ptr<ov::opset13::MatMul> matmul = ov::as_type_ptr<ov::opset13::MatMul>(last_node);
+
+  // in case of PA all tokens are moved to batch dimension and we have to slice / gather accordingly
+  const bool pa_based_model = has_op_with_type(model, "PagedAttentionExtension");
+  int64_t slice_gather_dim = pa_based_model ? 0 : 1;
+
+  // There are several patterns for matmul we are looking for:
+  // Matmul -> Result
+  // Matmul -> Add -> Result
+  // Matmul -> Transpose -> Result
+  // MatMul -> Divide -> Tanh -> Multiply -> Result
+  if (!matmul) {
+    if (auto add = ov::as_type_ptr<ov::opset13::Add>(last_node)) {
+      matmul = ov::as_type_ptr<ov::opset13::MatMul>(add->input_value(0).get_node_shared_ptr());
+    } else if (auto transpose = ov::as_type_ptr<ov::opset13::Transpose>(last_node)) {
+      matmul = ov::as_type_ptr<ov::opset13::MatMul>(transpose->input_value(0).get_node_shared_ptr());
+      auto order = ov::as_type_ptr<ov::opset13::Constant>(transpose->input_value(1).get_node_shared_ptr())->get_axis_vector_val();
+      slice_gather_dim = order[slice_gather_dim];
+    } else if (auto multiply = ov::as_type_ptr<ov::opset13::Multiply>(last_node)) {
+      if (auto tanh = ov::as_type_ptr<ov::opset13::Tanh>(multiply->input_value(0).get_node_shared_ptr())) {
+        if (auto divide = ov::as_type_ptr<ov::opset13::Divide>(tanh->input_value(0).get_node_shared_ptr())) {
+          matmul = ov::as_type_ptr<ov::opset13::MatMul>(divide->input_value(0).get_node_shared_ptr());
+        }
+      }
+    }
+  }
+  return std::make_tuple(matmul, slice_gather_dim);
+}
+
+static void apply_slice_before_matmul_transformation(std::shared_ptr<ov::Model> model) {
+  std::shared_ptr<ov::opset13::MatMul> matmul = nullptr;
+  int64_t slice_gather_dim = -1;
+  std::tie(matmul, slice_gather_dim) = find_llm_matmul(model);
+
+  if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) {
+    auto start = std::make_shared<ov::opset13::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
+    auto stop = std::make_shared<ov::opset13::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-2});
+    auto step = std::make_shared<ov::opset13::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
+    auto axis = std::make_shared<ov::opset13::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{slice_gather_dim});
+    auto slice = std::make_shared<ov::opset13::Slice>(matmul->input_value(0), start, stop, step, axis);
+    matmul->input(0).replace_source_output(slice);
+  }
+}
+
+static void update_config(ov::AnyMap& config, const std::pair<std::string, ov::Any>& pair) {
+  if (config.count(pair.first) == 0) {
+    config.insert(pair);
+  }
+}
+
+static std::optional<ov::Any> pop_option(ov::AnyMap& config, const std::string& option_name) {
+  if (auto it = config.find(option_name); it != config.end()) {
+    std::optional<ov::Any> found = std::make_optional(it->second);
+    config.erase(it);
+    return found;
+  }
+  return std::nullopt;
+}
+
+static void rename_key(ov::AnyMap& config, const std::string& old_key, const std::string& new_key) {
+  if (config.count(old_key) != 0) {
+    auto opt_value = pop_option(config, old_key);
+    config[new_key] = opt_value.value();
+  }
+}
+
+struct KVAxesPosition {
+  size_t batch;
+  size_t seq_len;
+};
+
+static KVAxesPosition get_kv_axes_pos(std::shared_ptr<ov::Model> model) {
+  // sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size],
+  // therefore usually seq_length_axis = 2 and batch = 0
+  KVAxesPosition kv_pos{0u, 2u};
+
+  // "ReadValue" node is KV cache representation in stateful model
+  std::string kv_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name);
+
+  for (const auto& op : model->get_ops()) {
+    // check input size, as in LoRA adapters case it could be 0
+    if (op->get_type_name() != kv_node_type_name || op->get_input_size() < 1) {
+      continue;
+    }
+
+    // Shape example: [-1,4,0,64]
+    auto shape = op->get_input_partial_shape(0);
+
+    for (int64_t i = 0; i < shape.rank().get_length(); i++) {
+      // Find axis = 0. This would be sequence length axis.
+      if (shape[i] == 0) {
+        kv_pos.seq_len = i;
+      } else if (shape[i].is_dynamic()) {
+        // Dynamic axis is a batch
+        kv_pos.batch = i;
+      }
+    }
+    break;
+  }
+
+  return kv_pos;
+}
+
+struct KVDesc {
+  uint32_t max_prompt_len;
+  uint32_t min_response_len;
+};
+
+static void update_npu_config(ov::AnyMap& config,
+                              const std::shared_ptr<ov::Model>& model,
+                              const KVAxesPosition& kv_pos,
+                              const KVDesc& kv_desc) {
+  update_config(config, {"NPU_USE_NPUW", "YES"});
+  update_config(config, {"NPUW_LLM", "YES"});
+
+  update_config(config, {"NPUW_LLM_BATCH_DIM", kv_pos.batch});
+  update_config(config, {"NPUW_LLM_SEQ_LEN_DIM", kv_pos.seq_len});
+
+  update_config(config, {"NPUW_LLM_MAX_PROMPT_LEN", kv_desc.max_prompt_len});
+  update_config(config, {"NPUW_LLM_MIN_RESPONSE_LEN", kv_desc.min_response_len});
+
+  rename_key(config, "++PREFILL_CONFIG", "++NPUW_LLM_PREFILL_CONFIG");
+  rename_key(config, "++GENERATE_CONFIG", "++NPUW_LLM_GENERATE_CONFIG");
+  rename_key(config, "PREFILL_CONFIG", "NPUW_LLM_PREFILL_CONFIG");
+  rename_key(config, "PREFILL_HINT", "NPUW_LLM_PREFILL_HINT");
+  rename_key(config, "GENERATE_CONFIG", "NPUW_LLM_GENERATE_CONFIG");
+  rename_key(config, "GENERATE_HINT", "NPUW_LLM_GENERATE_HINT");
+}
+
+static std::optional<ov::Any> pop_option_new(ov::AnyMap& config, const std::string& option_name) {
+  if (auto it = config.find(option_name); it != config.end()) {
+    std::optional<ov::Any> found = std::make_optional(it->second);
+    config.erase(it);
+    return found;
+  }
+  return std::nullopt;
+}
+
+static std::optional<uint32_t> pop_int_and_cast(ov::AnyMap& config, const std::string& key) {
+  auto anyopt = pop_option_new(config, key);
+  if (anyopt.has_value()) {
+    const auto any = anyopt.value();
+    int64_t value;
+    // NB: Integer value coming from python has int64_t datatype
+    if (any.is<int64_t>()) {
+      value = any.as<int64_t>();
+    } else if (any.is<int>()) {
+      value = any.as<int>();
+    } else {
+      OPENVINO_THROW("Failed to extract " + key + ". Type mismatch: expected types: int or int64_t");
+    }
+    if (value < 0) {
+      OPENVINO_THROW(key + " cannot be negative!");
+    }
+    return std::make_optional(static_cast<uint32_t>(value));
+  }
+  return std::nullopt;
+}
diff --git a/onnxruntime/test/testdata/custom_op_openvino_wrapper_library/openvino_wrapper.cc b/onnxruntime/test/testdata/custom_op_openvino_wrapper_library/openvino_wrapper.cc
index 27d5c59439243..d4ce3320e13ca 100644
--- a/onnxruntime/test/testdata/custom_op_openvino_wrapper_library/openvino_wrapper.cc
+++ b/onnxruntime/test/testdata/custom_op_openvino_wrapper_library/openvino_wrapper.cc
@@ -35,7 +35,7 @@ static ov::element::Type ConvertONNXToOVType(ONNXTensorElementDataType onnx_type
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
       return ov::element::bf16;
     default:
-      return ov::element::undefined;
+      return ov::element::dynamic;
   }
 }