Skip to content

Commit 05c3a44

Browse files
authored
server : fill usage info in embeddings and rerank responses (#10852)
* server : fill usage info in embeddings response * server : fill usage info in reranking response
1 parent 382bc7f commit 05c3a44

File tree

4 files changed

+77
-10
lines changed

4 files changed

+77
-10
lines changed

examples/server/server.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -719,14 +719,17 @@ struct server_task_result_embd : server_task_result {
719719
int index = 0;
720720
std::vector<float> embedding;
721721

722+
int32_t n_tokens;
723+
722724
virtual int get_index() override {
723725
return index;
724726
}
725727

726728
virtual json to_json() override {
727729
return json {
728-
{"index", index},
729-
{"embedding", embedding},
730+
{"index", index},
731+
{"embedding", embedding},
732+
{"tokens_evaluated", n_tokens},
730733
};
731734
}
732735
};
@@ -735,14 +738,17 @@ struct server_task_result_rerank : server_task_result {
735738
int index = 0;
736739
float score = -1e6;
737740

741+
int32_t n_tokens;
742+
738743
virtual int get_index() override {
739744
return index;
740745
}
741746

742747
virtual json to_json() override {
743748
return json {
744-
{"index", index},
745-
{"score", score},
749+
{"index", index},
750+
{"score", score},
751+
{"tokens_evaluated", n_tokens},
746752
};
747753
}
748754
};
@@ -1995,6 +2001,7 @@ struct server_context {
19952001
auto res = std::make_unique<server_task_result_embd>();
19962002
res->id = slot.id_task;
19972003
res->index = slot.index;
2004+
res->n_tokens = slot.n_prompt_tokens;
19982005

19992006
const int n_embd = llama_n_embd(model);
20002007

@@ -2030,6 +2037,7 @@ struct server_context {
20302037
auto res = std::make_unique<server_task_result_rerank>();
20312038
res->id = slot.id_task;
20322039
res->index = slot.index;
2040+
res->n_tokens = slot.n_prompt_tokens;
20332041

20342042
for (int i = 0; i < batch.n_tokens; ++i) {
20352043
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {

examples/server/tests/unit/test_embedding.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,33 @@ def test_same_prompt_give_same_result():
9797
vi = res.body['data'][i]['embedding']
9898
for x, y in zip(v0, vi):
9999
assert abs(x - y) < EPSILON
100+
101+
102+
@pytest.mark.parametrize(
103+
"content,n_tokens",
104+
[
105+
("I believe the meaning of life is", 7),
106+
("This is a test", 4),
107+
]
108+
)
109+
def test_embedding_usage_single(content, n_tokens):
110+
global server
111+
server.start()
112+
res = server.make_request("POST", "/embeddings", data={"input": content})
113+
assert res.status_code == 200
114+
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
115+
assert res.body['usage']['prompt_tokens'] == n_tokens
116+
117+
118+
def test_embedding_usage_multiple():
119+
global server
120+
server.start()
121+
res = server.make_request("POST", "/embeddings", data={
122+
"input": [
123+
"I believe the meaning of life is",
124+
"I believe the meaning of life is",
125+
],
126+
})
127+
assert res.status_code == 200
128+
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
129+
assert res.body['usage']['prompt_tokens'] == 2 * 7

examples/server/tests/unit/test_rerank.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,26 @@ def test_invalid_rerank_req(documents):
5353
})
5454
assert res.status_code == 400
5555
assert "error" in res.body
56+
57+
58+
@pytest.mark.parametrize(
59+
"query,doc1,doc2,n_tokens",
60+
[
61+
("Machine learning is", "A machine", "Learning is", 19),
62+
("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
63+
]
64+
)
65+
def test_rerank_usage(query, doc1, doc2, n_tokens):
66+
global server
67+
server.start()
68+
69+
res = server.make_request("POST", "/rerank", data={
70+
"query": query,
71+
"documents": [
72+
doc1,
73+
doc2,
74+
]
75+
})
76+
assert res.status_code == 200
77+
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
78+
assert res.body['usage']['prompt_tokens'] == n_tokens

examples/server/utils.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -560,21 +560,24 @@ static json oaicompat_completion_params_parse(
560560

561561
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
562562
json data = json::array();
563+
int32_t n_tokens = 0;
563564
int i = 0;
564565
for (const auto & elem : embeddings) {
565566
data.push_back(json{
566567
{"embedding", json_value(elem, "embedding", json::array())},
567568
{"index", i++},
568569
{"object", "embedding"}
569570
});
571+
572+
n_tokens += json_value(elem, "tokens_evaluated", 0);
570573
}
571574

572575
json res = json {
573576
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
574577
{"object", "list"},
575-
{"usage", json { // TODO: fill
576-
{"prompt_tokens", 0},
577-
{"total_tokens", 0}
578+
{"usage", json {
579+
{"prompt_tokens", n_tokens},
580+
{"total_tokens", n_tokens}
578581
}},
579582
{"data", data}
580583
};
@@ -584,20 +587,23 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
584587

585588
static json format_response_rerank(const json & request, const json & ranks) {
586589
json data = json::array();
590+
int32_t n_tokens = 0;
587591
int i = 0;
588592
for (const auto & rank : ranks) {
589593
data.push_back(json{
590594
{"index", i++},
591595
{"relevance_score", json_value(rank, "score", 0.0)},
592596
});
597+
598+
n_tokens += json_value(rank, "tokens_evaluated", 0);
593599
}
594600

595601
json res = json {
596602
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
597603
{"object", "list"},
598-
{"usage", json { // TODO: fill
599-
{"prompt_tokens", 0},
600-
{"total_tokens", 0}
604+
{"usage", json {
605+
{"prompt_tokens", n_tokens},
606+
{"total_tokens", n_tokens}
601607
}},
602608
{"results", data}
603609
};

0 commit comments

Comments
 (0)