Skip to content

Commit

Permalink
Fix w/ Settings
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani authored Oct 26, 2024
1 parent ee27c9f commit d4eec76
Showing 1 changed file with 119 additions and 149 deletions.
268 changes: 119 additions & 149 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,10 @@
#include <sstream>
#include <mutex>
#include <iostream>
#include <yyjson.hpp>

#include "yyjson.hpp"

namespace duckdb {
struct OpenPromptData: FunctionData {
unique_ptr<FunctionData> Copy() const {
throw std::runtime_error("OpenPromptData::Copy");
};
bool Equals(const FunctionData &other) const {
throw std::runtime_error("OpenPromptData::Equals");
};
};

// Helper function to parse URL and setup client

static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(const std::string &url) {
std::string scheme, domain, path;
size_t pos = url.find("://");
Expand All @@ -46,7 +36,6 @@ static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(co
path = "/";
}

// Create client and set a reasonable timeout (e.g., 10 seconds)
duckdb_httplib_openssl::Client client(domain.c_str());
client.set_read_timeout(10, 0); // 10 seconds
client.set_follow_location(true); // Follow redirects
Expand Down Expand Up @@ -98,184 +87,167 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std
throw std::runtime_error(err_message);
}


// Open Prompt
// Global settings
static std::string api_url = "http://localhost:11434/v1/chat/completions";
static std::string api_token; // Store your API token here
static std::string model_name = "qwen2.5:0.5b"; // Default model
static std::mutex settings_mutex;

// Function to set API token
void SetApiToken(DataChunk &args, ExpressionState &state, Vector &result) {
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
[&](string_t token) {
try {
auto _token = token.GetData();
if (token.Empty()) {
throw std::invalid_argument("API token cannot be empty.");
}
ClientConfig::GetConfig(state.GetContext()).SetUserVariable(
"openprompt_api_token",
Value::CreateValue(token.GetString()));
return StringVector::AddString(result, string("token : ") + string(_token, token.GetSize()));
} catch (std::exception &e) {
string_t res(e.what());
res.Finalize();
return res;
}
});
// Settings management
static std::string GetConfigValue(ClientContext &context, const string &var_name, const string &default_value) {
Value value;
auto &config = ClientConfig::GetConfig(context);
if (!config.GetUserVariable(var_name, value) || value.IsNull()) {
return default_value;
}
return value.ToString();
}

// Function to set API URL
void SetApiUrl(DataChunk &args, ExpressionState &state, Vector &result) {
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
[&](string_t token) {
static void SetConfigValue(DataChunk &args, ExpressionState &state, Vector &result,
const string &var_name, const string &value_type) {
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
[&](string_t value) {
try {
auto _token = token.GetData();
if (token.Empty()) {
throw std::invalid_argument("API token cannot be empty.");
if (value == "" || value.GetSize() == 0) {
throw std::invalid_argument(value_type + " cannot be empty.");
}

ClientConfig::GetConfig(state.GetContext()).SetUserVariable(
"openprompt_api_url",
Value::CreateValue(token.GetString()));
return StringVector::AddString(result, string("url : ") + string(_token, token.GetSize()));
var_name,
Value::CreateValue(value.GetString())
);
return StringVector::AddString(result, value_type + " set to: " + value.GetString());
} catch (std::exception &e) {
string_t res(e.what());
res.Finalize();
return res;
return StringVector::AddString(result, "Failed to set " + value_type + ": " + e.what());
}
});
}

// Function to set model name
void SetModelName(DataChunk &args, ExpressionState &state, Vector &result) {
UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
[&](string_t token) {
try {
auto _token = token.GetData();
if (token.Empty()) {
throw std::invalid_argument("API token cannot be empty.");
}
ClientConfig::GetConfig(state.GetContext()).SetUserVariable(
"openprompt_model_name",
Value::CreateValue(token.GetString()));
return StringVector::AddString(result, string("name : ") + string(_token, token.GetSize()));
} catch (std::exception &e) {
string_t res(e.what());
res.Finalize();
return res;
}
});
}
}

// Retrieve the API URL from the stored settings
static std::string GetApiUrl() {
std::lock_guard<std::mutex> guard(settings_mutex);
return api_url.empty() ? "http://localhost:11434/v1/chat/completions" : api_url;
}
static void SetApiToken(DataChunk &args, ExpressionState &state, Vector &result) {
SetConfigValue(args, state, result, "openprompt_api_token", "API token");
}

// Retrieve the API token from the stored settings
static std::string GetApiToken() {
std::lock_guard<std::mutex> guard(settings_mutex);
return api_token;
}
static void SetApiUrl(DataChunk &args, ExpressionState &state, Vector &result) {
SetConfigValue(args, state, result, "openprompt_api_url", "API URL");
}

// Retrieve the model name from the stored settings
static std::string GetModelName() {
std::lock_guard<std::mutex> guard(settings_mutex);
return model_name.empty() ? "qwen2.5:0.5b" : model_name;
}
static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result) {
SetConfigValue(args, state, result, "openprompt_model_name", "Model name");
}

template<typename a> a assert_null(a val) {
if (val == nullptr) {
throw std::runtime_error("Failed to parse the first message content in the API response.");
}
return val;
}
// Open Prompt Function
// Main Function
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
D_ASSERT(args.data.size() >= 1); // At least prompt required

UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
[&](string_t user_prompt) {
auto &conf = ClientConfig::GetConfig(state.GetContext());
Value api_url;
Value api_token;
Value model_name;
conf.GetUserVariable("openprompt_api_url", api_url);
conf.GetUserVariable("openprompt_api_token", api_token);
conf.GetUserVariable("openprompt_model_name", model_name);

// Manually construct the JSON body as a string. TODO use json parser from extension.
auto &context = state.GetContext();

// Get configuration with defaults
std::string api_url = GetConfigValue(context, "openprompt_api_url",
"http://localhost:11434/v1/chat/completions");
std::string api_token = GetConfigValue(context, "openprompt_api_token", "");
std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b");

// Override model if provided as second argument
if (args.data.size() > 1 && !args.data[1].GetValue(0).IsNull()) {
model_name = args.data[1].GetValue(0).ToString();
}

std::string request_body = "{";
request_body += "\"model\":\"" + model_name.ToString() + "\",";
request_body += "\"model\":\"" + model_name + "\",";
request_body += "\"messages\":[";
request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},";
request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}";
request_body += "]}";

try {
// Make the POST request
auto client_and_path = SetupHttpClient(api_url.ToString());
auto client_and_path = SetupHttpClient(api_url);
auto &client = client_and_path.first;
auto &path = client_and_path.second;

// Setup headers
duckdb_httplib_openssl::Headers header_map;
header_map.emplace("Content-Type", "application/json");
if (!api_token.ToString().empty()) {
header_map.emplace("Authorization", "Bearer " + api_token.ToString());
duckdb_httplib_openssl::Headers headers;
headers.emplace("Content-Type", "application/json");
if (!api_token.empty()) {
headers.emplace("Authorization", "Bearer " + api_token);
}

auto res = client.Post(path.c_str(), headers, request_body, "application/json");

if (!res) {
HandleHttpError(res, "POST");
}

if (res->status != 200) {
throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason);
}

// Send the request
auto res = client.Post(path.c_str(), header_map, request_body, "application/json");
if (res && res->status == 200) {
// Extract the first choice's message content from the response
std::string response_body = res->body;
unique_ptr<duckdb_yyjson::yyjson_doc, void(*)(struct duckdb_yyjson::yyjson_doc *)> doc(
nullptr, &duckdb_yyjson::yyjson_doc_free
);
doc.reset(assert_null(
duckdb_yyjson::yyjson_read(response_body.c_str(), response_body.length(), 0)
));
auto root = assert_null(duckdb_yyjson::yyjson_doc_get_root(doc.get()));
auto choices = assert_null(duckdb_yyjson::yyjson_obj_get(root, "choices"));
auto choices_0 = assert_null(duckdb_yyjson::yyjson_arr_get_first(choices));
auto message = assert_null(duckdb_yyjson::yyjson_obj_get(choices_0, "message"));
auto content = assert_null(duckdb_yyjson::yyjson_obj_get(message, "content"));
auto c_content = assert_null(duckdb_yyjson::yyjson_get_str(content));
return StringVector::AddString(result, c_content);
try {
unique_ptr<duckdb_yyjson::yyjson_doc, void(*)(duckdb_yyjson::yyjson_doc *)> doc(
duckdb_yyjson::yyjson_read(res->body.c_str(), res->body.length(), 0),
&duckdb_yyjson::yyjson_doc_free
);

if (!doc) {
throw std::runtime_error("Failed to parse JSON response");
}

auto root = duckdb_yyjson::yyjson_doc_get_root(doc.get());
if (!root) {
throw std::runtime_error("Invalid JSON response: no root object");
}

auto choices = duckdb_yyjson::yyjson_obj_get(root, "choices");
if (!choices || !duckdb_yyjson::yyjson_is_arr(choices)) {
throw std::runtime_error("Invalid response format: missing choices array");
}

auto first_choice = duckdb_yyjson::yyjson_arr_get_first(choices);
if (!first_choice) {
throw std::runtime_error("Empty choices array in response");
}

auto message = duckdb_yyjson::yyjson_obj_get(first_choice, "message");
if (!message) {
throw std::runtime_error("Missing message in response");
}

auto content = duckdb_yyjson::yyjson_obj_get(message, "content");
if (!content) {
throw std::runtime_error("Missing content in response");
}

auto content_str = duckdb_yyjson::yyjson_get_str(content);
if (!content_str) {
throw std::runtime_error("Invalid content in response");
}

return StringVector::AddString(result, content_str);
} catch (std::exception &e) {
throw std::runtime_error("Failed to parse response: " + std::string(e.what()));
}
throw std::runtime_error("HTTP POST error: " + std::to_string(res->status) + " - " + res->reason);
} catch (std::exception &e) {
// In case of any error, return the original input text to avoid disruption
return StringVector::AddString(result, e.what());
// Log error and return error message
return StringVector::AddString(result, "Error: " + std::string(e.what()));
}
});
}


// LoadInternal function
static void LoadInternal(DatabaseInstance &instance) {
// Register open_prompt function with two arguments: prompt and model
ScalarFunctionSet open_prompt("open_prompt");

// Register with both single and two-argument variants
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));

ExtensionUtil::RegisterFunction(instance, open_prompt);

// Other set_* functions remain the same as before
// Register setting functions
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
SetApiToken));

"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiToken));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
SetApiUrl));

"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl));
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName
));
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName));
}


void OpenPromptExtension::Load(DuckDB &db) {
LoadInternal(*db.instance);
}
Expand All @@ -292,7 +264,6 @@ std::string OpenPromptExtension::Version() const {
#endif
}


} // namespace duckdb

extern "C" {
Expand All @@ -309,4 +280,3 @@ DUCKDB_EXTENSION_API const char *open_prompt_version() {
#ifndef DUCKDB_EXTENSION_MAIN
#error DUCKDB_EXTENSION_MAIN not defined
#endif

0 comments on commit d4eec76

Please sign in to comment.