Merge pull request #435 from janhq/429-feat-properly-decoupling-cors-and-handleprelight-into-a-seperated-controller

429 feat properly decoupling cors and handleprelight into a seperated controller
tikikun authored Feb 16, 2024
2 parents 627a597 + ab408c9 commit 5fdbdd4
Showing 4 changed files with 71 additions and 40 deletions.
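
In short, this PR moves CORS preflight (OPTIONS) handling out of the llamaCPP controller and into a dedicated prelight controller (see controllers/prelight.h and controllers/prelight.cc below). Because Drogon HttpControllers register their own routes through the METHOD_LIST_BEGIN / ADD_METHOD_TO macros, no other wiring has to change for the new routes to take effect. A minimal sketch of a server entry point, only to illustrate that auto-registration; the listener address and port are assumptions and not taken from this commit:

// Illustrative only, not a file from this commit. Shows that a Drogon
// HttpController such as the new `prelight` class registers its routes
// (the OPTIONS handlers) automatically once the app starts; main() does
// not need to reference it. Address and port below are assumptions.
#include <drogon/drogon.h>

int main() {
  drogon::app()
      .addListener("0.0.0.0", 3928)  // assumed listener, for illustration
      .run();  // OPTIONS /v1/chat/completions etc. now reach prelight::handlePrelight
  return 0;
}
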
75 changes: 37 additions & 38 deletions controllers/llamaCPP.cc
@@ -10,11 +10,7 @@ using json = nlohmann::json;
/**
* The state of the inference task
*/
enum InferenceStatus {
PENDING,
RUNNING,
FINISHED
};
enum InferenceStatus { PENDING, RUNNING, FINISHED };

/**
* There is a need to save state of current ongoing inference status of a
@@ -141,7 +137,9 @@ std::string create_return_json(const std::string &id, const std::string &model,
return Json::writeString(writer, root);
}

llamaCPP::llamaCPP(): queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel, "llamaCPP")) {
llamaCPP::llamaCPP()
: queue(new trantor::ConcurrentTaskQueue(llama.params.n_parallel,
"llamaCPP")) {
// Some default values for now below
log_disable(); // Disable the log to file feature, reduce bloat for
// target
@@ -172,7 +170,7 @@ void llamaCPP::inference(

const auto &jsonBody = req->getJsonObject();
// Check if model is loaded
if(checkModelLoaded(callback)) {
if (checkModelLoaded(callback)) {
// Model is loaded
// Do Inference
inferenceImpl(jsonBody, callback);
@@ -329,8 +327,7 @@ void llamaCPP::inferenceImpl(
auto state = create_inference_state(this);
auto chunked_content_provider =
[state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {

if(state->inferenceStatus == PENDING) {
if (state->inferenceStatus == PENDING) {
state->inferenceStatus = RUNNING;
} else if (state->inferenceStatus == FINISHED) {
return 0;
@@ -341,7 +338,7 @@ void llamaCPP::inferenceImpl(
state->inferenceStatus = FINISHED;
return 0;
}

task_result result = state->instance->llama.next_result(state->task_id);
if (!result.error) {
const std::string to_send = result.result_json["content"];
@@ -367,10 +364,10 @@ void llamaCPP::inferenceImpl(
LOG_INFO << "reached result stop";
state->inferenceStatus = FINISHED;
}

// Make sure nBufferSize is not zero
// Otherwise it stop streaming
if(!nRead) {
if (!nRead) {
state->inferenceStatus = FINISHED;
}

@@ -380,31 +377,33 @@ void llamaCPP::inferenceImpl(
return 0;
};
// Queued task
state->instance->queue->runTaskInQueue([callback, state, data,
chunked_content_provider]() {
state->task_id =
state->instance->llama.request_completion(data, false, false, -1);

// Start streaming response
auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
"chat_completions.txt");
callback(resp);

int retries = 0;

// Since this is an async task, we will wait for the task to be completed
while (state->inferenceStatus != FINISHED && retries < 10) {
// Should wait chunked_content_provider lambda to be called within 3s
if(state->inferenceStatus == PENDING) {
retries += 1;
}
if(state->inferenceStatus != RUNNING)
LOG_INFO << "Wait for task to be released:" << state->task_id;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
// Request completed, release it
state->instance->llama.request_cancel(state->task_id);
});
state->instance->queue->runTaskInQueue(
[callback, state, data, chunked_content_provider]() {
state->task_id =
state->instance->llama.request_completion(data, false, false, -1);

// Start streaming response
auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
"chat_completions.txt");
callback(resp);

int retries = 0;

// Since this is an async task, we will wait for the task to be
// completed
while (state->inferenceStatus != FINISHED && retries < 10) {
// Should wait chunked_content_provider lambda to be called within
// 3s
if (state->inferenceStatus == PENDING) {
retries += 1;
}
if (state->inferenceStatus != RUNNING)
LOG_INFO << "Wait for task to be released:" << state->task_id;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
// Request completed, release it
state->instance->llama.request_cancel(state->task_id);
});
} else {
Json::Value respData;
auto resp = nitro_utils::nitroHttpResponse();
@@ -434,7 +433,7 @@ void llamaCPP::embedding(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
// Check if model is loaded
if(checkModelLoaded(callback)) {
if (checkModelLoaded(callback)) {
// Model is loaded
const auto &jsonBody = req->getJsonObject();
// Run embedding
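
The streaming path in this diff is driven by a small state machine: the chunked_content_provider flips the shared InferenceStatus from PENDING to RUNNING once Drogon starts pulling chunks, and the queued task polls that status with a bounded number of retries before cancelling the completion request. A standalone sketch of that bounded-wait pattern, with illustrative names and timings that are not taken from the repository:

// Standalone illustration of the bounded-wait pattern used in the queued
// task above; not code from this commit. A worker thread stands in for the
// chunked_content_provider callback that Drogon invokes while streaming.
#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

enum class Status { Pending, Running, Finished };

int main() {
  std::atomic<Status> status{Status::Pending};

  std::thread provider([&] {
    std::this_thread::sleep_for(std::chrono::milliseconds(250));
    status = Status::Running;   // streaming has started
    std::this_thread::sleep_for(std::chrono::milliseconds(300));
    status = Status::Finished;  // streaming is done
  });

  int retries = 0;
  while (status != Status::Finished && retries < 10) {
    if (status == Status::Pending)
      ++retries;  // only the time spent waiting for streaming to start counts
    if (status != Status::Running)
      std::cout << "waiting for task, retries=" << retries << '\n';
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
  std::cout << "releasing task\n";  // mirrors llama.request_cancel(task_id)
  provider.join();
  return 0;
}
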
5 changes: 3 additions & 2 deletions controllers/llamaCPP.h
@@ -2526,10 +2526,11 @@ class llamaCPP : public drogon::HttpController<llamaCPP>, public ChatProvider {

// Openai compatible path
ADD_METHOD_TO(llamaCPP::inference, "/v1/chat/completions", Post);
// ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/chat/completions", Options); NOTE: prelight will be added back when browser support is properly planned
// ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/chat/completions", Options);
// NOTE: prelight will be added back when browser support is properly planned

ADD_METHOD_TO(llamaCPP::embedding, "/v1/embeddings", Post);
//ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/embeddings", Options);
// ADD_METHOD_TO(llamaCPP::handlePrelight, "/v1/embeddings", Options);

// PATH_ADD("/llama/chat_completion", Post);
METHOD_LIST_END
13 changes: 13 additions & 0 deletions controllers/prelight.cc
@@ -0,0 +1,13 @@
#include "prelight.h"

void prelight::handlePrelight(
const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback) {
auto resp = drogon::HttpResponse::newHttpResponse();
resp->setStatusCode(drogon::HttpStatusCode::k200OK);
resp->addHeader("Access-Control-Allow-Origin", "*");
resp->addHeader("Access-Control-Allow-Methods", "POST, OPTIONS");
resp->addHeader("Access-Control-Allow-Headers", "*");
callback(resp);
}

18 changes: 18 additions & 0 deletions controllers/prelight.h
@@ -0,0 +1,18 @@
#pragma once

#include <drogon/HttpController.h>

using namespace drogon;

class prelight : public drogon::HttpController<prelight> {
public:
METHOD_LIST_BEGIN
ADD_METHOD_TO(prelight::handlePrelight, "/v1/chat/completions", Options);
ADD_METHOD_TO(prelight::handlePrelight, "/v1/embeddings", Options);
ADD_METHOD_TO(prelight::handlePrelight, "/v1/audio/transcriptions", Options);
ADD_METHOD_TO(prelight::handlePrelight, "/v1/audio/translations", Options);
METHOD_LIST_END

void handlePrelight(const HttpRequestPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback);
};
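
To see the new controller in action, a browser-style preflight can be reproduced with Drogon's HttpClient: send OPTIONS to one of the registered paths and read back the Access-Control-Allow-* headers set in prelight::handlePrelight. This is a sketch only; the base URL, port, and Origin value are assumptions:

// Illustrative client, not part of this commit. Sends an OPTIONS request the
// way a browser preflight would and prints the CORS headers returned by
// prelight::handlePrelight. Base URL, port, and Origin value are assumptions.
#include <drogon/drogon.h>
#include <iostream>

int main() {
  auto client = drogon::HttpClient::newHttpClient("http://127.0.0.1:3928");
  auto req = drogon::HttpRequest::newHttpRequest();
  req->setMethod(drogon::Options);
  req->setPath("/v1/chat/completions");
  req->addHeader("Origin", "http://localhost:3000");
  req->addHeader("Access-Control-Request-Method", "POST");

  client->sendRequest(req, [](drogon::ReqResult result,
                              const drogon::HttpResponsePtr &resp) {
    if (result == drogon::ReqResult::Ok) {
      std::cout << "Allow-Origin:  "
                << resp->getHeader("Access-Control-Allow-Origin") << '\n'
                << "Allow-Methods: "
                << resp->getHeader("Access-Control-Allow-Methods") << '\n'
                << "Allow-Headers: "
                << resp->getHeader("Access-Control-Allow-Headers") << '\n';
    }
    drogon::app().quit();  // stop the event loop once the reply is in
  });
  drogon::app().run();  // run the event loop until quit() is called
  return 0;
}
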
