
Commit c63d869 (parent: b197797)

server : update /embeddings and /v1/embeddings endpoints

ggml-ci

In short: /v1/embeddings becomes strictly OpenAI-compatible and rejects pooling type 'none', while /embeddings (and the legacy /embedding) keeps the raw output format, now without the single-input special case.

File tree: 2 files changed, +57 -26 lines

examples/server/server.cpp
Lines changed: 38 additions & 17 deletions
@@ -729,24 +729,30 @@ struct server_task_result_embd : server_task_result {
     int index = 0;
     std::vector<std::vector<float>> embedding;
 
+    // OAI-compat fields
+    bool oaicompat = false;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
-        if (embedding.size() == 1) {
-            // to be OAI compatible
-            return json {
-                {"index", index},
-                {"embedding", embedding[0]},
-            };
-        }
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
 
+    json to_json_non_oaicompat() {
         return json {
             {"index", index},
             {"embedding", embedding},
         };
     }
+
+    json to_json_oaicompat() {
+        return json {
+            {"index", index},
+            {"embedding", embedding[0]},
+        };
+    }
 };
 
 struct server_task_result_rerank : server_task_result {
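
The to_json() split above is the behavioral core of the commit: the OAI-compatible shape returns a single pooled vector per input (embedding[0]), while the raw shape keeps the full list of vectors. A sketch of the two per-result shapes, with hypothetical values:

    # OAI-compat (/v1/embeddings): one pooled vector per input
    oai_item = {"index": 0, "embedding": [0.12, -0.03, 0.54]}

    # non-OAI (/embeddings): a list of vectors, e.g. one per token when the
    # server runs with pooling 'none', so nested one level deeper
    raw_item = {"index": 0, "embedding": [[0.12, -0.03, 0.54],
                                          [0.08, 0.21, -0.40]]}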
@@ -2018,8 +2024,9 @@ struct server_context {
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
-        res->id    = slot.id_task;
-        res->index = slot.index;
+        res->id        = slot.id_task;
+        res->index     = slot.index;
+        res->oaicompat = slot.params.oaicompat;
 
         const int n_embd = llama_n_embd(model);
 
@@ -3667,14 +3674,17 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
         const json body = json::parse(req.body);
-        bool oaicompat = false;
+
+        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
 
         // an input prompt can be a string or a list of tokens (integer)
         json prompt;
         if (body.count("input") != 0) {
-            oaicompat = true;
             prompt = body.at("input");
         } else if (body.count("content") != 0) {
             // with "content", we only support single prompt
@@ -3691,10 +3701,15 @@
         std::vector<server_task> tasks;
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+            server_task task   = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
             task.id            = ctx_server.queue_tasks.get_new_id();
             task.index         = i;
             task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+            // OAI-compat
+            task.params.oaicompat = oaicompat;
+
             tasks.push_back(task);
         }
 
@@ -3722,12 +3737,18 @@
         }
 
         // write JSON response
-        json root = oaicompat
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses.size() == 1 ? responses[0] : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
         res_ok(res, root);
     };
 
+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, false);
+    };
+
+    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, true);
+    };
+
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
@@ -3901,7 +3922,7 @@ int main(int argc, char ** argv) {
     svr->Post("/infill", handle_infill);
     svr->Post("/embedding", handle_embeddings); // legacy
     svr->Post("/embeddings", handle_embeddings);
-    svr->Post("/v1/embeddings", handle_embeddings);
+    svr->Post("/v1/embeddings", handle_embeddings_oai);
     svr->Post("/rerank", handle_rerank);
     svr->Post("/reranking", handle_rerank);
     svr->Post("/v1/rerank", handle_rerank);

examples/server/tests/unit/test_embedding.py
Lines changed: 19 additions & 9 deletions
@@ -16,7 +16,7 @@ def test_embedding_single():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "I believe the meaning of life is",
     })
     assert res.status_code == 200
@@ -32,7 +32,7 @@ def test_embedding_multiple():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "Write a joke about AI from a very long prompt which will not be truncated",
@@ -55,16 +55,26 @@ def test_embedding_pooling_none():
         "input": "hello hello hello",
     })
     assert res.status_code == 200
-    assert len(res.body['data']) == 1
-    assert 'embedding' in res.body['data'][0]
-    assert len(res.body['data'][0]['embedding']) == 3
+    assert 'embedding' in res.body[0]
+    assert len(res.body[0]['embedding']) == 3
+
+
+def test_embedding_pooling_none_oai():
+    global server
+    server.pooling = 'none'
+    server.start()
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": "hello hello hello",
+    })
+    # /v1/embeddings does not support pooling type 'none'
+    assert res.status_code == 400
 
 
 def test_embedding_openai_library_single():
     global server
     server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
     assert len(res.data) == 1
     assert len(res.data[0].embedding) > 1
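
The reworked assertions above follow from the envelope change: with pooling 'none', /embeddings now returns a bare array instead of an OAI-style {"data": [...]} wrapper, so the per-token vectors sit at res.body[0]['embedding']. Shape-wise (vector values are hypothetical):

    # "hello hello hello" tokenizes to 3 tokens here, hence 3 vectors
    body = [{"index": 0, "embedding": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]}]
    assert 'embedding' in body[0]
    assert len(body[0]['embedding']) == 3  # one vector per token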
@@ -74,7 +84,7 @@ def test_embedding_openai_library_multiple():
     global server
     server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input=[
         "I believe the meaning of life is",
         "Write a joke about AI from a very long prompt which will not be truncated",
@@ -90,7 +100,7 @@ def test_embedding_error_prompt_too_long():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "This is a test " * 512,
     })
     assert res.status_code != 200
@@ -100,7 +110,7 @@ def test_same_prompt_give_same_result():
 def test_same_prompt_give_same_result():
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "I believe the meaning of life is",
