@@ -217,6 +217,7 @@ struct llama_client_slot
 
     bool infill = false;
     bool embedding = false;
+    bool reranker = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1413,14 +1414,62 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void send_rerank(llama_client_slot & slot, const llama_batch & batch)
+    {
+        task_result res;
+        res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
+        res.error = false;
+        res.stop = true;
+
+        float score = -1e6f; // Default score if we fail to get embeddings
+
+        if (!params.rerank)
+        {
+            LOG_WARNING("reranking disabled", {
+                {"params.rerank", params.rerank},
+            });
+        }
+        else
+        {
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+                    continue;
+                }
+
+                score = embd[0];
+            }
+        }
+
+        // Format result as JSON similar to the embedding function
+        res.result_json = json
+        {
+            {"score", score},
+            {"tokens", slot.n_prompt_tokens}
+        };
+
+        queue_results.send(res);
+    }
+
+    void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id)
     {
         task_server task;
         task.id = task_id;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
+        task.reranking_mode = rerank;
         task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
 
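A note on the lookup in send_rerank() above: for GGUF reranker models run with rank pooling (LLAMA_POOLING_TYPE_RANK), llama_get_embeddings_seq() returns the pooled classification output of the sequence rather than a full embedding vector, which is why the score is read from embd[0]. A minimal sketch of that lookup in isolation, assuming a context created with embeddings enabled (the helper name is ours, not part of this commit):

// Sketch only: mirrors the per-sequence lookup in send_rerank() above.
// Assumes ctx was created with embeddings enabled and the model uses rank
// pooling, so the pooled output for a sequence is a single float.
static float rerank_score_for_seq(llama_context * ctx, llama_seq_id seq) {
    const float * out = llama_get_embeddings_seq(ctx, seq);
    if (out == NULL) {
        return -1e6f; // same failure default as send_rerank()
    }
    return out[0]; // the rank-pooled relevance score
}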
@@ -1552,7 +1601,7 @@ struct llama_server_context
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multiprompt_task.reranking_mode, multitask_id);
         }
     }
 
@@ -1591,6 +1640,7 @@ struct llama_server_context
 
            slot->infill = task.infill_mode;
            slot->embedding = task.embedding_mode;
+           slot->reranker = task.reranking_mode;
            slot->task_id = task.id;
            slot->multitask_id = task.multitask_id;
 
@@ -2034,6 +2084,14 @@ struct llama_server_context
                    continue;
                }
 
+               if (slot.reranker)
+               {
+                   send_rerank(slot, batch_view);
+                   slot.release();
+                   slot.i_batch = -1;
+                   continue;
+               }
+
                completion_token_output result;
                const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
 
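Taken together, the hunks wire the flag end to end: request_completion() stores it on the task, task assignment copies it to slot->reranker, and the batch loop then short-circuits into send_rerank() instead of sampling a token. A hypothetical call site for the widened signature might look like this, assuming the server's usual global llama_server_context named llama (the JSON shape and argument values are illustrative, not taken from this commit):

// Hypothetical caller sketch (illustrative values): queue a rerank task via
// the extended signature; the result arrives on queue_results as
// {"score": ..., "tokens": ...}, as built in send_rerank() above.
json data = {
    {"prompt", "..."}  // query/document input as the endpoint formats it
};
llama.request_completion(task_id, std::move(data),
                         /*infill=*/false, /*embedding=*/false,
                         /*rerank=*/true, /*multitask_id=*/-1);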