
Commit c82d18e

server : embeddings compatibility for OpenAI (#5190)
1 parent 14fef85

2 files changed: +75 −0 lines changed

examples/server/oai.hpp

Lines changed: 15 additions & 0 deletions

@@ -206,3 +206,18 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
 
     return std::vector<json>({ret});
 }
+
+inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings)
+{
+    json res =
+        json{
+            {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+            {"object", "list"},
+            {"usage",
+                json{{"prompt_tokens", 0},
+                     {"total_tokens", 0}}},
+            {"data", embeddings}
+        };
+    return res;
+}

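The helper wraps the per-input embedding objects in OpenAI's list envelope; note that usage.prompt_tokens and usage.total_tokens are hard-coded to 0 rather than counted. Below is a minimal standalone sketch of the resulting shape, assuming nlohmann/json; model_or() is a hypothetical stand-in for the server's json_value() helper, and the model name, default, and embedding values are illustrative placeholders:

    #include <iostream>
    #include <string>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    // Hypothetical stand-in for the server's json_value() helper:
    // return the field if present, otherwise the supplied default.
    static std::string model_or(const json &request, const std::string &def) {
        return request.contains("model") ? request["model"].get<std::string>() : def;
    }

    int main() {
        // A request body as /v1/embeddings would receive it (placeholder values).
        json request = {{"model", "all-MiniLM-L6-v2"}, {"input", "hello world"}};

        // One entry per input, as the server.cpp handler below builds them.
        json embeddings = json::array({
            json{{"embedding", json::array({0.1, 0.2, 0.3})},
                 {"index", 0},
                 {"object", "embedding"}}
        });

        // Same envelope as format_embeddings_response_oaicompat();
        // "unknown" stands in for DEFAULT_OAICOMPAT_MODEL.
        json res = json{
            {"model", model_or(request, "unknown")},
            {"object", "list"},
            {"usage", json{{"prompt_tokens", 0}, {"total_tokens", 0}}},
            {"data", embeddings}
        };
        std::cout << res.dump(2) << std::endl;
        return 0;
    }
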
examples/server/server.cpp

Lines changed: 60 additions & 0 deletions

@@ -2929,6 +2929,66 @@ int main(int argc, char **argv)
         return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
     });
 
+    svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
+            {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+                const json body = json::parse(req.body);
+
+                json prompt;
+                if (body.count("input") != 0)
+                {
+                    prompt = body["input"];
+                    // batch
+                    if(prompt.is_array()) {
+                        json data = json::array();
+                        int i = 0;
+                        for (const json &elem : prompt) {
+                            const int task_id = llama.queue_tasks.get_new_id();
+                            llama.queue_results.add_waiting_task_id(task_id);
+                            llama.request_completion(task_id, { {"prompt", elem}, { "n_predict", 0} }, false, true, -1);
+
+                            // get the result
+                            task_result result = llama.queue_results.recv(task_id);
+                            llama.queue_results.remove_waiting_task_id(task_id);
+
+                            json embedding = json{
+                                {"embedding", json_value(result.result_json, "embedding", json::array())},
+                                {"index", i++},
+                                {"object", "embedding"}
+                            };
+                            data.push_back(embedding);
+                        }
+                        json result = format_embeddings_response_oaicompat(body, data);
+                        return res.set_content(result.dump(), "application/json; charset=utf-8");
+                    }
+                }
+                else
+                {
+                    prompt = "";
+                }
+
+                // create and queue the task
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}}, false, true, -1);
+
+                // get the result
+                task_result result = llama.queue_results.recv(task_id);
+                llama.queue_results.remove_waiting_task_id(task_id);
+
+                json data = json::array({json{
+                        {"embedding", json_value(result.result_json, "embedding", json::array())},
+                        {"index", 0},
+                        {"object", "embedding"}
+                    }}
+                );
+
+                json root = format_embeddings_response_oaicompat(body, data);
+
+                // send the result
+                return res.set_content(root.dump(), "application/json; charset=utf-8");
+            });
+
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
     // "Bus error: 10" - this is on macOS, it does not crash on Linux
     //std::thread t2([&]()

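As a usage sketch (not part of the commit): the server already depends on cpp-httplib, so a client-side call to the new endpoint could look like the code below. Host, port, and the input strings are placeholders, and the server would typically need embeddings enabled (e.g. its --embedding flag) for the embedding fields to be populated. With "input" as an array, the handler returns one data entry per element with ascending index values; a plain string falls through to the single-prompt path.

    #include <iostream>
    #include <httplib.h>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // Placeholder host/port; point this at a running llama.cpp server.
        httplib::Client cli("localhost", 8080);

        // Batch form: "input" is an array, so the handler emits one
        // {"embedding", "index", "object"} entry per element.
        json body = {{"input", json::array({"first sentence", "second sentence"})}};

        auto res = cli.Post("/v1/embeddings", body.dump(), "application/json");
        if (res && res->status == 200) {
            // OpenAI-style {"object":"list","data":[...],"usage":{...}} response
            std::cout << res->body << std::endl;
        } else {
            std::cerr << "request failed" << std::endl;
        }
        return 0;
    }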