
Commit 01e2e3d

wip reranking llama.cpp

Signed-off-by: Ettore Di Giacinto <[email protected]>
Parent: 61cc76c

File tree: 2 files changed, +61 -2 lines changed

backend/cpp/llama/grpc-server.cpp (+60 -2)
@@ -217,6 +217,7 @@ struct llama_client_slot
 
     bool infill = false;
     bool embedding = false;
+    bool reranker = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1413,14 +1414,62 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void send_rerank(llama_client_slot &slot, const llama_batch & batch)
+    {
+        task_result res;
+        res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
+        res.error = false;
+        res.stop = true;
+
+        float score = -1e6f; // Default score if we fail to get embeddings
+
+        if (!params.rerank)
+        {
+            LOG_WARNING("reranking disabled", {
+                {"params.rerank", params.rerank},
+            });
+        }
+        else
+        {
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+                    continue;
+                }
+
+                score = embd[0];
+            }
+        }
+
+        // Format result as JSON similar to the embedding function
+        res.result_json = json
+        {
+            {"score", score},
+            {"tokens", slot.n_prompt_tokens}
+        };
+
+        queue_results.send(res);
+    }
+
+    void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id)
     {
         task_server task;
         task.id = task_id;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
+        task.reranking_mode = rerank;
         task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
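Note on the score read in send_rerank above: for models loaded with a ranking head, llama.cpp reports the pooled sequence output through llama_get_embeddings_seq, and for a reranker that output is a single raw relevance logit, which is why the code takes embd[0]. Callers often want a bounded score instead; a minimal sketch, assuming the raw value is a logit (the helper name is hypothetical, not part of this commit):

    #include <cmath>

    // Hypothetical helper, not part of the commit: squash the raw ranking
    // logit from embd[0] into a (0, 1) relevance value with a sigmoid.
    static float rerank_logit_to_probability(float logit) {
        return 1.0f / (1.0f + std::exp(-logit));
    }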

@@ -1552,7 +1601,7 @@ struct llama_server_context
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multiprompt_task.reranking_mode, multitask_id);
         }
     }

@@ -1591,6 +1640,7 @@ struct llama_server_context
 
         slot->infill = task.infill_mode;
         slot->embedding = task.embedding_mode;
+        slot->reranker = task.reranking_mode;
         slot->task_id = task.id;
         slot->multitask_id = task.multitask_id;

@@ -2034,6 +2084,14 @@ struct llama_server_context
                 continue;
             }
 
+            if (slot.reranker)
+            {
+                send_rerank(slot, batch_view);
+                slot.release();
+                slot.i_batch = -1;
+                continue;
+            }
+
             completion_token_output result;
             const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
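One caveat on the dispatch above: send_rerank only considers batch entries whose output flag is set (batch.logits[i]) and whose sequence id matches the slot, so the prompt batch for a reranker slot has to mark at least its final token for output. A minimal sketch of that step, assuming llama.cpp's common_batch_add helper and a prompt_tokens vector like the ones used by the other slot paths (this wiring is not shown in the commit):

    // Enable output on the last prompt token so send_rerank finds a scored
    // entry for this slot; earlier tokens need no output.
    for (size_t k = 0; k < prompt_tokens.size(); ++k) {
        const bool need_output = (k + 1 == prompt_tokens.size());
        common_batch_add(batch, prompt_tokens[k], (llama_pos) k, { slot.id }, need_output);
    }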

backend/cpp/llama/utils.hpp (+1)

@@ -61,6 +61,7 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    bool reranking_mode = false;
     int multitask_id = -1;
 };
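With reranking_mode in place, a rerank request can travel through the same task queue as completions and embeddings. A minimal sketch of the caller side, assuming the server's existing queue_tasks/queue_results plumbing; the result JSON shape matches what send_rerank emits, and everything else here is illustrative:

    // Queue a rerank task using the extended request_completion signature:
    // infill = false, embedding = false, rerank = true, no multitask (-1).
    const int task_id = llama.queue_tasks.get_new_id();
    llama.queue_results.add_waiting_task_id(task_id);
    llama.request_completion(task_id, json{{"prompt", formatted_pair}}, false, false, true, -1);

    // Block until the slot finishes; result_json holds {"score": ..., "tokens": ...}.
    task_result result = llama.queue_results.recv(task_id);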
