Merge pull request #363 from janhq/358-feat-improvement-over-nitro-queue-system

358 feat improvement over nitro queue system
tikikun authored Jan 18, 2024
2 parents 902dc3f + ebdf420 commit 1861e43
Showing 1 changed file with 39 additions and 30 deletions.
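The diff below reworks how the single queue slot is handled when streaming: instead of busy-waiting inside chatCompletion before the response is created, the completion request is now issued lazily on the first call of the chunked_content_provider, and the provider itself waits for single_queue_is_busy to clear when n_parallel == 1. As a rough illustration of that pattern only, here is a minimal standalone sketch. Engine, StreamState, submit and next_chunk are hypothetical stand-ins for the llama.cpp server pieces (request_completion, next_result) and are not part of this repository, and the sketch blocks before submitting rather than retrying on error the way the real code does.

// Minimal sketch (hypothetical types, not from this repo): lazy task submission
// inside a streaming callback, with a busy-wait while the single slot is taken.
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstring>
#include <iostream>
#include <string>
#include <thread>

struct Engine {                       // stand-in for the llama.cpp server context
  std::atomic<bool> single_queue_is_busy{false};
  int next_id = 0;
  int submit() { return next_id++; }  // stand-in for request_completion(...)
  std::string next_chunk(int /*task_id*/) { return "token "; }  // stand-in for next_result(...)
};

struct StreamState {                  // mirrors inferenceState in the diff
  bool is_stopped = false;
  bool is_streaming = false;
  int task_id = -1;
  Engine *engine;
  explicit StreamState(Engine *e) : engine(e) {}
};

// Called repeatedly by the HTTP layer to pull the next chunk into pBuffer.
std::size_t stream_chunk(StreamState &st, char *pBuffer, std::size_t nBuffSize) {
  if (!st.is_streaming) {
    // Only one task can run at a time: wait for the slot instead of failing.
    while (st.engine->single_queue_is_busy) {
      std::this_thread::sleep_for(std::chrono::milliseconds(500));
    }
    st.task_id = st.engine->submit();          // submitted lazily, on first callback
    st.engine->single_queue_is_busy = true;
    st.is_streaming = true;
  }
  if (pBuffer == nullptr || st.is_stopped) {   // connection closed or generation done
    st.is_streaming = false;
    st.engine->single_queue_is_busy = false;   // release the slot for the next request
    return 0;
  }
  const std::string chunk = st.engine->next_chunk(st.task_id);
  const std::size_t nRead = std::min(chunk.size(), nBuffSize);
  std::memcpy(pBuffer, chunk.data(), nRead);
  return nRead;
}

int main() {
  Engine engine;
  StreamState state(&engine);
  char buf[64];
  for (int i = 0; i < 3; ++i) {
    std::size_t n = stream_chunk(state, buf, sizeof(buf));
    std::cout.write(buf, static_cast<std::streamsize>(n));
  }
  state.is_stopped = true;
  stream_chunk(state, buf, sizeof(buf));       // final call releases the queue slot
  std::cout << "\nslot busy after stop: " << engine.single_queue_is_busy << "\n";
  return 0;
}

In the actual diff the wait happens in the error branch after next_result, with the provider emitting "\n\n" and retrying, but the effect is similar: the streaming response is handed back to the HTTP framework immediately and the queue slot is only held while chunks are being produced.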
controllers/llamaCPP.cc: 69 changes (39 additions & 30 deletions)
@@ -6,16 +6,17 @@
 using namespace inferences;
 using json = nlohmann::json;
 
-struct State {
-  bool isStopped = false;
+struct inferenceState {
+  bool is_stopped = false;
+  bool is_streaming = false;
   int task_id;
   llamaCPP *instance;
 
-  State(int tid, llamaCPP *inst) : task_id(tid), instance(inst) {}
+  inferenceState(llamaCPP *inst) : instance(inst) {}
 };
 
-std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
-  return std::make_shared<State>(task_id, instance);
+std::shared_ptr<inferenceState> create_inference_state(llamaCPP *instance) {
+  return std::make_shared<inferenceState>(instance);
 }
 
 // --------------------------------------------
@@ -295,41 +296,35 @@ void llamaCPP::chatCompletion(
 #endif
   int task_id;
 
-  if (llama.params.n_parallel == 1) {
-    while (true) {
-      if (!single_queue_is_busy) {
-        task_id = llama.request_completion(data, false, false, -1);
-        single_queue_is_busy = true;
-        break;
-      } else {
-        std::this_thread::sleep_for(
-            std::chrono::milliseconds(500)); // Sleep for 500 milliseconds
-      }
-    }
-  } else {
-    task_id = llama.request_completion(data, false, false, -1);
-  }
-
   LOG_INFO << "Resolved request for task_id:" << task_id;
 
   if (is_streamed) {
-    auto state = createState(task_id, this);
-
+    auto state = create_inference_state(this);
+    state->task_id = task_id;
     auto chunked_content_provider =
-        [this, state](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+        [state, data](char *pBuffer, std::size_t nBuffSize) -> std::size_t {
+      if (!state->is_streaming) {
+        state->task_id =
+            state->instance->llama.request_completion(data, false, false, -1);
+        state->instance->single_queue_is_busy = true;
+      }
       if (!pBuffer) {
        LOG_INFO << "Connection closed or buffer is null. Reset context";
        state->instance->llama.request_cancel(state->task_id);
-        single_queue_is_busy = false;
+        state->is_streaming = false;
+        state->instance->single_queue_is_busy = false;
        return 0;
      }
-      if (state->isStopped) {
-        single_queue_is_busy = false;
+      if (state->is_stopped) {
+        state->is_streaming = false;
+        state->instance->single_queue_is_busy = false;
        return 0;
      }
 
      task_result result = state->instance->llama.next_result(state->task_id);
      if (!result.error) {
+        // Update streaming state to being streamed
+        state->is_streaming = true;
        const std::string to_send = result.result_json["content"];
        const std::string str =
            "data: " +
@@ -351,16 +346,30 @@
          std::size_t nRead = std::min(str.size(), nBuffSize);
          memcpy(pBuffer, str.data(), nRead);
          LOG_INFO << "reached result stop";
-          state->isStopped = true;
+          state->is_stopped = true;
          state->instance->llama.request_cancel(state->task_id);
+          state->is_streaming = false;
+          state->instance->single_queue_is_busy = false;
+
          return nRead;
        }
        return nRead;
      } else {
-        single_queue_is_busy = false;
-        return 0;
+        if (state->instance->llama.params.n_parallel == 1) {
+          while (state->instance->single_queue_is_busy) {
+            LOG_INFO << "Waiting for task to be released status:"
+                     << state->instance->single_queue_is_busy;
+            std::this_thread::sleep_for(std::chrono::milliseconds(500)); // Waiting in 500 miliseconds step
+          }
+        }
+        std::string str = "\n\n";
+        std::size_t nRead = str.size();
+        memcpy(pBuffer, str.data(), nRead);
+        LOG_INFO << "Failing retrying now";
+        return nRead;
      }
-      single_queue_is_busy = false;
+      state->is_streaming = false;
+      state->instance->single_queue_is_busy = false;
      return 0;
    };
    auto resp = nitro_utils::nitroStreamResponse(chunked_content_provider,
