diff --git a/CMakeLists.txt b/CMakeLists.txt index 9eeef49..1a0b41f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,7 @@ set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension) project(${TARGET_NAME}) include_directories(src/include duckdb/third_party/httplib duckdb/parquet/include) -set(EXTENSION_SOURCES src/httpserver_extension.cpp) +set(EXTENSION_SOURCES src/httpserver_extension.cpp src/duck_flock.cpp) if(MINGW) set(OPENSSL_USE_STATIC_LIBS TRUE) diff --git a/src/duck_flock.cpp b/src/duck_flock.cpp new file mode 100644 index 0000000..3447c0d --- /dev/null +++ b/src/duck_flock.cpp @@ -0,0 +1,78 @@ +#ifndef DUCK_FLOCK_H +#define DUCK_FLOCK_H +#include "httpserver_extension.hpp" +namespace duckdb { + struct DuckFlockData : FunctionData{ + vector> conn; + vector> results; + unique_ptr Copy() const override { + throw std::runtime_error("not implemented"); + } + bool Equals(const FunctionData &other) const override { + throw std::runtime_error("not implemented"); + }; + }; + + + + unique_ptr DuckFlockBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + auto data = make_uniq(); + auto strQuery = input.inputs[0].GetValue(); + vector flock; + auto &raw_flock = ListValue::GetChildren(input.inputs[1]); + for (auto &duck : raw_flock) { + flock.push_back(duck.ToString()); + auto conn = make_uniq(*context.db); + conn->Query("INSTALL json;LOAD json;INSTALL httpfs;LOAD httpfs;"); + auto req = conn->Prepare("SELECT * FROM read_json($2 || '/?q=' || url_encode($1::VARCHAR))"); + if (req->HasError()) { + throw std::runtime_error("duck_flock: error: " + req->GetError()); + } + data->conn.push_back(std::move(conn)); + data->results.push_back(std::move(req->Execute(strQuery.c_str(), duck.ToString()))); + } + if (data->results[0]->HasError()) { + throw std::runtime_error("duck_flock: error: " + data->results[0]->GetError()); + } + return_types.clear(); + copy(data->results[0]->types.begin(), data->results[0]->types.end(), back_inserter(return_types)); + names.clear(); + copy(data->results[0]->names.begin(), data->results[0]->names.end(), back_inserter(names)); + return std::move(data); + } + + void DuckFlockImplementation(ClientContext &context, duckdb::TableFunctionInput &data_p, + DataChunk &output) { + auto &data = data_p.bind_data->Cast(); + for (const auto &res : data.results) { + ErrorData error_data; + unique_ptr data_chunk = make_uniq(); + if (res->TryFetch(data_chunk, error_data)) { + if (data_chunk != nullptr) { + output.Append(*data_chunk); + return; + } + } + } + } + + TableFunction DuckFlockTableFunction() { + TableFunction f( + "duck_flock", + {LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR)}, + DuckFlockImplementation, + DuckFlockBind, + nullptr, + nullptr + ); + return f; + } + + +} + + + + +#endif \ No newline at end of file diff --git a/src/httpserver_extension.cpp b/src/httpserver_extension.cpp index 88078f7..14ab49f 100644 --- a/src/httpserver_extension.cpp +++ b/src/httpserver_extension.cpp @@ -77,6 +77,8 @@ static HttpServerState global_state; int64_t read_rows; }; + + // Convert the query result to JSON format static std::string ConvertResultToJSON(MaterializedQueryResult &result, ReqStats &req_stats) { auto doc = yyjson_mut_doc_new(nullptr); @@ -469,7 +471,7 @@ static void LoadInternal(DatabaseInstance &instance) { ExtensionUtil::RegisterFunction(instance, httpserve_start); ExtensionUtil::RegisterFunction(instance, httpserve_stop); - + ExtensionUtil::RegisterFunction(instance, DuckFlockTableFunction()); // Register the cleanup function to be called at exit std::atexit(HttpServerCleanup); } diff --git a/src/include/httpserver_extension.hpp b/src/include/httpserver_extension.hpp index 432d1c0..bad6b38 100644 --- a/src/include/httpserver_extension.hpp +++ b/src/include/httpserver_extension.hpp @@ -17,4 +17,6 @@ struct HttpServerState; void HttpServerStart(DatabaseInstance& db, string_t host, int32_t port); void HttpServerStop(); + TableFunction DuckFlockTableFunction(); + } // namespace duckdb