Skip to content

Commit

Permalink
std::atexit on parent
Browse files Browse the repository at this point in the history
Exit if the parent thread exists
  • Loading branch information
lmangani authored Oct 13, 2024
1 parent e0e16f2 commit d78d875
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions src/httpserver_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
#include "duckdb/main/extension_util.hpp"
#include "duckdb/common/atomic.hpp"
#include "duckdb/common/exception/http_exception.hpp"
#include "duckdb/common/allocator.hpp"

#define CPPHTTPLIB_OPENSSL_SUPPORT
#include "httplib.hpp"

#include <thread>
#include <memory>
#include <cstdlib>

namespace duckdb {

Expand All @@ -21,7 +23,8 @@ struct HttpServerState {
std::unique_ptr<std::thread> server_thread;
std::atomic<bool> is_running;
DatabaseInstance* db_instance;

unique_ptr<Allocator> allocator;

HttpServerState() : is_running(false), db_instance(nullptr) {}
};

Expand All @@ -32,16 +35,16 @@ static void HandleQuery(const string& query, duckdb_httplib_openssl::Response& r
if (!global_state.db_instance) {
throw IOException("Database instance not initialized");
}

Connection con(*global_state.db_instance);
auto result = con.Query(query);

if (result->HasError()) {
res.status = 400;
res.set_content(result->GetError(), "text/plain");
return;
}

res.set_content(result->ToString(), "text/plain");
} catch (const Exception& ex) {
res.status = 400;
Expand All @@ -53,19 +56,22 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port) {
if (global_state.is_running) {
throw IOException("HTTP server is already running");
}

global_state.db_instance = &db;
global_state.server.reset(new duckdb_httplib_openssl::Server());
global_state.server = make_uniq<duckdb_httplib_openssl::Server>();
global_state.is_running = true;


// Create a new allocator for the server thread
global_state.allocator = make_uniq<Allocator>();

// Handle GET requests
global_state.server->Get("/query", [](const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
if (!req.has_param("q")) {
res.status = 400;
res.set_content("Missing query parameter 'q'", "text/plain");
return;
}

auto query = req.get_param_value("q");
HandleQuery(query, res);
});
Expand All @@ -86,12 +92,12 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port) {
});

string host_str = host.GetString();
global_state.server_thread.reset(new std::thread([host_str, port]() {
global_state.server_thread = make_uniq<std::thread>([host_str, port]() {
if (!global_state.server->listen(host_str.c_str(), port)) {
global_state.is_running = false;
throw IOException("Failed to start HTTP server on " + host_str + ":" + std::to_string(port));
}
}));
});
}

void HttpServerStop() {
Expand All @@ -104,17 +110,24 @@ void HttpServerStop() {
global_state.server_thread.reset();
global_state.db_instance = nullptr;
global_state.is_running = false;

// Reset the allocator
global_state.allocator.reset();
}
}

static void HttpServerCleanup() {
HttpServerStop();
}

static void LoadInternal(DatabaseInstance &instance) {
auto httpserve_start = ScalarFunction("httpserve_start",
{LogicalType::VARCHAR, LogicalType::INTEGER},
LogicalType::VARCHAR,
[&](DataChunk &args, ExpressionState &state, Vector &result) {
auto &host_vector = args.data[0];
auto &port_vector = args.data[1];

UnaryExecutor::Execute<string_t, string_t>(
host_vector, result, args.size(),
[&](string_t host) {
Expand All @@ -134,6 +147,10 @@ static void LoadInternal(DatabaseInstance &instance) {

ExtensionUtil::RegisterFunction(instance, httpserve_start);
ExtensionUtil::RegisterFunction(instance, httpserve_stop);

// Register the cleanup function to be called at exit
std::atexit(HttpServerCleanup);

}

void HttpserverExtension::Load(DuckDB &db) {
Expand Down

0 comments on commit d78d875

Please sign in to comment.