diff --git a/src/httpserver_extension.cpp b/src/httpserver_extension.cpp index 6f9d745..9d8d0a0 100644 --- a/src/httpserver_extension.cpp +++ b/src/httpserver_extension.cpp @@ -7,12 +7,14 @@ #include "duckdb/main/extension_util.hpp" #include "duckdb/common/atomic.hpp" #include "duckdb/common/exception/http_exception.hpp" +#include "duckdb/common/allocator.hpp" #define CPPHTTPLIB_OPENSSL_SUPPORT #include "httplib.hpp" #include #include +#include namespace duckdb { @@ -21,7 +23,8 @@ struct HttpServerState { std::unique_ptr server_thread; std::atomic is_running; DatabaseInstance* db_instance; - + unique_ptr allocator; + HttpServerState() : is_running(false), db_instance(nullptr) {} }; @@ -32,16 +35,16 @@ static void HandleQuery(const string& query, duckdb_httplib_openssl::Response& r if (!global_state.db_instance) { throw IOException("Database instance not initialized"); } - + Connection con(*global_state.db_instance); auto result = con.Query(query); - + if (result->HasError()) { res.status = 400; res.set_content(result->GetError(), "text/plain"); return; } - + res.set_content(result->ToString(), "text/plain"); } catch (const Exception& ex) { res.status = 400; @@ -53,11 +56,14 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port) { if (global_state.is_running) { throw IOException("HTTP server is already running"); } - + global_state.db_instance = &db; - global_state.server.reset(new duckdb_httplib_openssl::Server()); + global_state.server = make_uniq(); global_state.is_running = true; - + + // Create a new allocator for the server thread + global_state.allocator = make_uniq(); + // Handle GET requests global_state.server->Get("/query", [](const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) { if (!req.has_param("q")) { @@ -65,7 +71,7 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port) { res.set_content("Missing query parameter 'q'", "text/plain"); return; } - + auto query = req.get_param_value("q"); HandleQuery(query, res); }); @@ -86,12 +92,12 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port) { }); string host_str = host.GetString(); - global_state.server_thread.reset(new std::thread([host_str, port]() { + global_state.server_thread = make_uniq([host_str, port]() { if (!global_state.server->listen(host_str.c_str(), port)) { global_state.is_running = false; throw IOException("Failed to start HTTP server on " + host_str + ":" + std::to_string(port)); } - })); + }); } void HttpServerStop() { @@ -104,9 +110,16 @@ void HttpServerStop() { global_state.server_thread.reset(); global_state.db_instance = nullptr; global_state.is_running = false; + + // Reset the allocator + global_state.allocator.reset(); } } +static void HttpServerCleanup() { + HttpServerStop(); +} + static void LoadInternal(DatabaseInstance &instance) { auto httpserve_start = ScalarFunction("httpserve_start", {LogicalType::VARCHAR, LogicalType::INTEGER}, @@ -114,7 +127,7 @@ static void LoadInternal(DatabaseInstance &instance) { [&](DataChunk &args, ExpressionState &state, Vector &result) { auto &host_vector = args.data[0]; auto &port_vector = args.data[1]; - + UnaryExecutor::Execute( host_vector, result, args.size(), [&](string_t host) { @@ -134,6 +147,10 @@ static void LoadInternal(DatabaseInstance &instance) { ExtensionUtil::RegisterFunction(instance, httpserve_start); ExtensionUtil::RegisterFunction(instance, httpserve_stop); + + // Register the cleanup function to be called at exit + std::atexit(HttpServerCleanup); + } void HttpserverExtension::Load(DuckDB &db) {