diff --git a/src/httpserver_extension.cpp b/src/httpserver_extension.cpp index 65a9aa0..8e970e6 100644 --- a/src/httpserver_extension.cpp +++ b/src/httpserver_extension.cpp @@ -40,6 +40,13 @@ struct HttpServerState { static HttpServerState global_state; +struct HttpServerException: public std::exception { + int status; + std::string message; + + HttpServerException(int status, const std::string& message) : message(message), status(status) {} +}; + std::string GetColumnType(MaterializedQueryResult &result, idx_t column) { if (result.RowCount() == 0) { return "String"; @@ -152,16 +159,16 @@ std::string base64_decode(const std::string &in) { return out; } -// Auth Check -bool IsAuthenticated(const duckdb_httplib_openssl::Request& req) { +// Check authentication +void CheckAuthentication(const duckdb_httplib_openssl::Request& req) { if (global_state.auth_token.empty()) { - return true; // No authentication required if no token is set + return; // No authentication required if no token is set } // Check for X-API-Key header auto api_key = req.get_header_value("X-API-Key"); if (!api_key.empty() && api_key == global_state.auth_token) { - return true; + return; } // Check for Basic Auth @@ -169,11 +176,11 @@ bool IsAuthenticated(const duckdb_httplib_openssl::Request& req) { if (!auth.empty() && auth.compare(0, 6, "Basic ") == 0) { std::string decoded_auth = base64_decode(auth.substr(6)); if (decoded_auth == global_state.auth_token) { - return true; + return; } } - return false; + throw HttpServerException(401, "Unauthorized"); } // Convert the query result to NDJSON (JSONEachRow) format @@ -217,49 +224,131 @@ static std::string ConvertResultToNDJSON(MaterializedQueryResult &result) { return ndjson_output; } -// Handle both GET and POST requests -void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) { - std::string query; +BoundParameterData ExtractQueryParameter(const std::string& key, yyjson_val* parameterVal) { + if (!yyjson_is_obj(parameterVal)) { + throw HttpServerException(400, "The parameter `" + key + "` parameter must be an object"); + } - // Check authentication - if (!IsAuthenticated(req)) { - res.status = 401; - res.set_content("Unauthorized", "text/plain"); - return; + auto typeVal = yyjson_obj_get(parameterVal, "type"); + if (!typeVal) { + throw HttpServerException(400, "The parameter `" + key + "` does not have a `type` field"); + } + if (!yyjson_is_str(typeVal)) { + throw HttpServerException(400, "The field `type` for the parameter `" + key + "` must be a string"); + } + auto type = std::string(yyjson_get_str(typeVal)); + + auto valueVal = yyjson_obj_get(parameterVal, "value"); + if (!valueVal) { + throw HttpServerException(400, "The parameter `" + key + "` does not have a `value` field"); } - // CORS allow - res.set_header("Access-Control-Allow-Origin", "*"); - res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT"); - res.set_header("Access-Control-Allow-Headers", "*"); - res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Max-Age", "86400"); + if (type == "TEXT") { + if (!yyjson_is_str(valueVal)) { + throw HttpServerException(400, "The field `value` for the parameter `" + key + "` must be a string"); + } - // Handle preflight OPTIONS request - if (req.method == "OPTIONS") { - res.status = 204; // No content - return; + return BoundParameterData(Value(yyjson_get_str(valueVal))); + } + else if (type == "BOOLEAN") { + if (!yyjson_is_bool(valueVal)) { + throw HttpServerException(400, "The field `value` for the parameter `" + key + "` must be a boolean"); + } + + return BoundParameterData(Value(bool(yyjson_get_bool(valueVal)))); + } + + throw HttpServerException(400, "Unsupported type " + type + " the parameter `" + key + "`"); +} + +case_insensitive_map_t ExtractQueryParameters(yyjson_doc* parametersDoc) { + if (!parametersDoc) { + throw HttpServerException(400, "Unable to parse the `parameters` parameter"); + } + + auto parametersRoot = yyjson_doc_get_root(parametersDoc); + if (!yyjson_is_obj(parametersRoot)) { + throw HttpServerException(400, "The `parameters` parameter must be an object"); + } + + case_insensitive_map_t named_values; + + size_t idx, max; + yyjson_val *parameterKeyVal, *parameterVal; + yyjson_obj_foreach(parametersRoot, idx, max, parameterKeyVal, parameterVal) { + auto parameterKeyString = std::string(yyjson_get_str(parameterKeyVal)); + + named_values[parameterKeyString] = ExtractQueryParameter(parameterKeyString, parameterVal); + } + + return named_values; +} + +case_insensitive_map_t ExtractQueryParametersWrapper(const duckdb_httplib_openssl::Request& req) { + yyjson_doc *parametersDoc = nullptr; + + try { + auto parametersJson = req.get_param_value("parameters"); + auto parametersJsonCStr = parametersJson.c_str(); + parametersDoc = yyjson_read(parametersJsonCStr, strlen(parametersJsonCStr), 0); + return ExtractQueryParameters(parametersDoc); + } + catch (const Exception& exception) { + yyjson_doc_free(parametersDoc); + + throw exception; + } +} + +// Execute query (optionally using a prepared statement) +std::unique_ptr ExecuteQuery( + const duckdb_httplib_openssl::Request& req, + const std::string& query +) { + Connection con(*global_state.db_instance); + std::unique_ptr result; + + if (req.has_param("parameters")) { + auto prepared_stmt = con.Prepare(query); + if (prepared_stmt->HasError()) { + throw HttpServerException(500, prepared_stmt->GetError()); + } + + auto named_values = ExtractQueryParametersWrapper(req); + + auto prepared_stmt_result = prepared_stmt->Execute(named_values); + D_ASSERT(prepared_stmt_result->type == QueryResultType::STREAM_RESULT); + result = unique_ptr_cast(std::move(prepared_stmt_result))->Materialize(); + } else { + result = con.Query(query); } + if (result->HasError()) { + throw HttpServerException(500, result->GetError()); + } + + return result; +} + +std::string ExtractQuery(const duckdb_httplib_openssl::Request& req) { // Check if the query is in the URL parameters if (req.has_param("query")) { - query = req.get_param_value("query"); + return req.get_param_value("query"); } else if (req.has_param("q")) { - query = req.get_param_value("q"); + return req.get_param_value("q"); } + // If not in URL, and it's a POST request, check the body else if (req.method == "POST" && !req.body.empty()) { - query = req.body; - } - // If no query found, return an error - else { - res.status = 200; - res.set_content(reinterpret_cast(playgroundContent), "text/html"); - return; + return req.body; } - // Set default format to JSONCompact + // std::optional is not available for this project + return ""; +} + +std::string ExtractFormat(const duckdb_httplib_openssl::Request& req) { std::string format = "JSONEachRow"; // Check for format in URL parameter or header @@ -271,24 +360,45 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli format = req.get_header_value("format"); } + return format; +} + +// Handle both GET and POST requests +void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) { try { + CheckAuthentication(req); + + // CORS allow + res.set_header("Access-Control-Allow-Origin", "*"); + res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT"); + res.set_header("Access-Control-Allow-Headers", "*"); + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Max-Age", "86400"); + + // Handle preflight OPTIONS request + if (req.method == "OPTIONS") { + res.status = 204; // No content + return; + } + + auto query = ExtractQuery(req); + auto format = ExtractFormat(req); + + if (query == "") { + res.status = 200; + res.set_content(reinterpret_cast(playgroundContent), sizeof(playgroundContent), "text/html"); + return; + } + if (!global_state.db_instance) { throw IOException("Database instance not initialized"); } - Connection con(*global_state.db_instance); auto start = std::chrono::system_clock::now(); - auto result = con.Query(query); + auto result = ExecuteQuery(req, query); auto end = std::chrono::system_clock::now(); auto elapsed = std::chrono::duration_cast(end - start); - if (result->HasError()) { - res.status = 500; - res.set_content(result->GetError(), "text/plain"); - return; - } - - ReqStats stats{ static_cast(elapsed.count()) / 1000, 0, @@ -308,7 +418,12 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli res.set_content(json_output, "application/x-ndjson"); } - } catch (const Exception& ex) { + } + catch (const HttpServerException& ex) { + res.status = ex.status; + res.set_content(ex.message, "text/plain"); + } + catch (const Exception& ex) { res.status = 500; std::string error_message = "Code: 59, e.displayText() = DB::Exception: " + std::string(ex.what()); res.set_content(error_message, "text/plain"); @@ -325,9 +440,9 @@ void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port, string_t global_state.is_running = true; global_state.auth_token = auth.GetString(); - // Custom basepath, defaults to root / + // Custom basepath, defaults to root / const char* base_path_env = std::getenv("DUCKDB_HTTPSERVER_BASEPATH"); - std::string base_path = "/"; + std::string base_path = "/"; if (base_path_env && base_path_env[0] == '/' && strlen(base_path_env) > 1) { base_path = std::string(base_path_env); diff --git a/test/sql/auth.test b/test/sql/auth.test new file mode 100644 index 0000000..7f863a0 --- /dev/null +++ b/test/sql/auth.test @@ -0,0 +1,129 @@ +# name: test/sql/auth.test +# description: test httpserver extension +# group: [httpserver] + +################################################################ +# Setup +################################################################ + +require httpserver + +statement ok +INSTALL http_client FROM community; + +statement ok +LOAD http_client; + +statement ok +INSTALL json; + +statement ok +LOAD json; + +################################################################ +# No auth test +################################################################ + +query I +SELECT httpserve_start('127.0.0.1', 4000, ''); +---- +HTTP server started on 127.0.0.1:4000 + +query TTT +WITH response AS (SELECT http_post('http://127.0.0.1:4000/?q=SELECT 123', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"123":"123"}\n + +query I +SELECT httpserve_stop(); +---- +HTTP server stopped + +################################################################ +# Basic auth test +################################################################ + +query I +SELECT httpserve_start('127.0.0.1', 4000, 'bob:pwd'); +---- +HTTP server started on 127.0.0.1:4000 + +query TTT +WITH response AS (SELECT http_post('http://127.0.0.1:4000/?q=SELECT 123', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +401 Unauthorized Unauthorized + +query TTT +WITH response AS ( + SELECT http_post( + 'http://127.0.0.1:4000/?q=SELECT 123', + MAP { + 'Authorization': CONCAT('Basic ', TO_BASE64('bob:pwd'::BLOB)), + }, + MAP {} + ) response +) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"123":"123"}\n + +query I +SELECT httpserve_stop(); +---- +HTTP server stopped + +################################################################ +# Token test +################################################################ + +query I +SELECT httpserve_start('127.0.0.1', 4000, 'my-api-key'); +---- +HTTP server started on 127.0.0.1:4000 + +query TTT +WITH response AS (SELECT http_post('http://127.0.0.1:4000/?q=SELECT 123', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +401 Unauthorized Unauthorized + +query TTT +WITH response AS ( + SELECT http_post( + 'http://127.0.0.1:4000/?q=SELECT 123', + MAP { + 'X-API-Key': 'my-api-key', + }, + MAP {} + ) response +) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"123":"123"}\n + +query I +SELECT httpserve_stop(); +---- +HTTP server stopped diff --git a/test/sql/basics.test b/test/sql/basics.test new file mode 100644 index 0000000..897dc7c --- /dev/null +++ b/test/sql/basics.test @@ -0,0 +1,69 @@ +# name: test/sql/basics.test +# description: test httpserver extension +# group: [httpserver] + +# Before we load the extension, this will fail +statement error +SELECT httpserve_start('127.0.0.1', 4000, ''); +---- +Catalog Error: Scalar Function with name httpserve_start does not exist! + +# Require statement will ensure this test is run with this extension loaded +require httpserver + +statement ok +INSTALL http_client FROM community; + +statement ok +LOAD http_client; + +statement ok +INSTALL json; + +statement ok +LOAD json; + +# The HTTP server is not available yet +query TTT +WITH response AS (SELECT http_post('http://127.0.0.1:4000/abc', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +-1 HTTP POST request failed. Connection error. (empty) + +# Start the HTTP server +query I +SELECT httpserve_start('127.0.0.1', 4000, ''); +---- +HTTP server started on 127.0.0.1:4000 + +# Simple request +query TTT +WITH response AS (SELECT http_post('http://127.0.0.1:4000/?q=SELECT ''World'' AS Hello', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"Hello":"World"}\n + +# Stop the HTTP server +query I +SELECT httpserve_stop(); +---- +HTTP server stopped + +# The HTTP server is not available anymore +query TTT +WITH response AS (SELECT http_post('http://127.0.0.1:4000/?q=SELECT ''World'' AS Hello', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +-1 HTTP POST request failed. Connection error. (empty) diff --git a/test/sql/quack.test b/test/sql/quack.test deleted file mode 100644 index 519a354..0000000 --- a/test/sql/quack.test +++ /dev/null @@ -1,23 +0,0 @@ -# name: test/sql/quack.test -# description: test quack extension -# group: [quack] - -# Before we load the extension, this will fail -statement error -SELECT quack('Sam'); ----- -Catalog Error: Scalar Function with name quack does not exist! - -# Require statement will ensure this test is run with this extension loaded -require quack - -# Confirm the extension works -query I -SELECT quack('Sam'); ----- -Quack Sam 🐥 - -query I -SELECT quack_openssl_version('Michael') ILIKE 'Quack Michael, my linked OpenSSL version is OpenSSL%'; ----- -true diff --git a/test/sql/simple-get.test b/test/sql/simple-get.test new file mode 100644 index 0000000..3cbaf59 --- /dev/null +++ b/test/sql/simple-get.test @@ -0,0 +1,52 @@ +# name: test/sql/simple-get.test +# description: test httpserver extension +# group: [httpserver] + +################################################################ +# Setup +################################################################ + +require httpserver + +statement ok +INSTALL http_client FROM community; + +statement ok +LOAD http_client; + +statement ok +INSTALL json; + +statement ok +LOAD json; + +query I +SELECT httpserve_start('127.0.0.1', 4000, ''); +---- +HTTP server started on 127.0.0.1:4000 + +################################################################ +# Tests +################################################################ + +# SQL request in `q` parameter +query TTT +WITH response AS (SELECT http_post('http://127.0.0.1:4000/?q=SELECT 123', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"123":"123"}\n + +# SQL request in `query` parameter +query TTT +WITH response AS (SELECT http_post('http://127.0.0.1:4000/?query=SELECT 123', MAP {}, MAP {}) response) +SELECT + response->>'status', + response->>'reason', + regexp_replace(response->>'body', '[\r\n]+', '\\n') +FROM response; +---- +200 OK {"123":"123"}\n