@@ -729,24 +729,30 @@ struct server_task_result_embd : server_task_result {
729
729
int index = 0 ;
730
730
std::vector<std::vector<float >> embedding;
731
731
732
+ // OAI-compat fields
733
+ bool oaicompat = false ;
734
+
732
735
virtual int get_index () override {
733
736
return index;
734
737
}
735
738
736
739
virtual json to_json () override {
737
- if (embedding.size () == 1 ) {
738
- // to be OAI compatible
739
- return json {
740
- {" index" , index},
741
- {" embedding" , embedding[0 ]},
742
- };
743
- }
740
+ return oaicompat ? to_json_oaicompat () : to_json_non_oaicompat ();
741
+ }
744
742
743
+ json to_json_non_oaicompat () {
745
744
return json {
746
745
{" index" , index},
747
746
{" embedding" , embedding},
748
747
};
749
748
}
749
+
750
+ json to_json_oaicompat () {
751
+ return json {
752
+ {" index" , index},
753
+ {" embedding" , embedding[0 ]},
754
+ };
755
+ }
750
756
};
751
757
752
758
struct server_task_result_rerank : server_task_result {
@@ -2018,8 +2024,9 @@ struct server_context {
2018
2024
2019
2025
void send_embedding (const server_slot & slot, const llama_batch & batch) {
2020
2026
auto res = std::make_unique<server_task_result_embd>();
2021
- res->id = slot.id_task ;
2022
- res->index = slot.index ;
2027
+ res->id = slot.id_task ;
2028
+ res->index = slot.index ;
2029
+ res->oaicompat = slot.params .oaicompat ;
2023
2030
2024
2031
const int n_embd = llama_n_embd (model);
2025
2032
@@ -3667,14 +3674,17 @@ int main(int argc, char ** argv) {
3667
3674
res_ok (res, data);
3668
3675
};
3669
3676
3670
- const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3677
+ const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat ) {
3671
3678
const json body = json::parse (req.body );
3672
- bool oaicompat = false ;
3679
+
3680
+ if (oaicompat && llama_pooling_type (ctx_server.ctx ) == LLAMA_POOLING_TYPE_NONE) {
3681
+ res_error (res, format_error_response (" Pooling type 'none' is not OAI compatible. Please use a different pooling type" , ERROR_TYPE_INVALID_REQUEST));
3682
+ return ;
3683
+ }
3673
3684
3674
3685
// an input prompt can be a string or a list of tokens (integer)
3675
3686
json prompt;
3676
3687
if (body.count (" input" ) != 0 ) {
3677
- oaicompat = true ;
3678
3688
prompt = body.at (" input" );
3679
3689
} else if (body.count (" content" ) != 0 ) {
3680
3690
// with "content", we only support single prompt
@@ -3691,10 +3701,15 @@ int main(int argc, char ** argv) {
3691
3701
std::vector<server_task> tasks;
3692
3702
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts (ctx_server.ctx , prompt, /* add_special */ false , true );
3693
3703
for (size_t i = 0 ; i < tokenized_prompts.size (); i++) {
3694
- server_task task = server_task (SERVER_TASK_TYPE_EMBEDDING);
3704
+ server_task task = server_task (SERVER_TASK_TYPE_EMBEDDING);
3705
+
3695
3706
task.id = ctx_server.queue_tasks .get_new_id ();
3696
3707
task.index = i;
3697
3708
task.prompt_tokens = std::move (tokenized_prompts[i]);
3709
+
3710
+ // OAI-compat
3711
+ task.params .oaicompat = oaicompat;;
3712
+
3698
3713
tasks.push_back (task);
3699
3714
}
3700
3715
@@ -3722,12 +3737,18 @@ int main(int argc, char ** argv) {
3722
3737
}
3723
3738
3724
3739
// write JSON response
3725
- json root = oaicompat
3726
- ? format_embeddings_response_oaicompat (body, responses)
3727
- : responses.size () == 1 ? responses[0 ] : json (responses);
3740
+ json root = oaicompat ? format_embeddings_response_oaicompat (body, responses) : json (responses);
3728
3741
res_ok (res, root);
3729
3742
};
3730
3743
3744
+ const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3745
+ handle_embeddings_impl (req, res, false );
3746
+ };
3747
+
3748
+ const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3749
+ handle_embeddings_impl (req, res, true );
3750
+ };
3751
+
3731
3752
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3732
3753
if (!ctx_server.params_base .reranking || ctx_server.params_base .embedding ) {
3733
3754
res_error (res, format_error_response (" This server does not support reranking. Start it with `--reranking` and without `--embedding`" , ERROR_TYPE_NOT_SUPPORTED));
@@ -3901,7 +3922,7 @@ int main(int argc, char ** argv) {
3901
3922
svr->Post (" /infill" , handle_infill);
3902
3923
svr->Post (" /embedding" , handle_embeddings); // legacy
3903
3924
svr->Post (" /embeddings" , handle_embeddings);
3904
- svr->Post (" /v1/embeddings" , handle_embeddings );
3925
+ svr->Post (" /v1/embeddings" , handle_embeddings_oai );
3905
3926
svr->Post (" /rerank" , handle_rerank);
3906
3927
svr->Post (" /reranking" , handle_rerank);
3907
3928
svr->Post (" /v1/rerank" , handle_rerank);
0 commit comments