Skip to content

Commit c3ebcfa

Browse files
authored
server : ensure batches are either all embed or all completion (#8420)
* make sure batches are all embed or all non-embed * non-embedding batch for sampled tokens; fix unused params warning
1 parent 8a4441e commit c3ebcfa

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

examples/server/server.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,6 +2005,11 @@ struct server_context {
20052005
int32_t n_batch = llama_n_batch(ctx);
20062006
int32_t n_ubatch = llama_n_ubatch(ctx);
20072007

2008+
// track if this is an embedding or non-embedding batch
2009+
// if we've added sampled tokens above, we are in non-embedding mode
2010+
// -1: none, 0: non-embedding, 1: embedding
2011+
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
2012+
20082013
// next, batch any pending prompts without exceeding n_batch
20092014
if (params.cont_batching || batch.n_tokens == 0) {
20102015
for (auto & slot : slots) {
@@ -2175,6 +2180,14 @@ struct server_context {
21752180
}
21762181
}
21772182

2183+
// check that we are in the right batch_type, if not defer the slot
2184+
bool slot_type = slot.embedding ? 1 : 0;
2185+
if (batch_type == -1) {
2186+
batch_type = slot_type;
2187+
} else if (batch_type != slot_type) {
2188+
continue;
2189+
}
2190+
21782191
// keep only the common part
21792192
int p0 = (int) system_tokens.size() + slot.n_past;
21802193
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
@@ -2276,6 +2289,9 @@ struct server_context {
22762289
{"n_tokens", batch.n_tokens},
22772290
});
22782291

2292+
// make sure we're in the right embedding mode
2293+
llama_set_embeddings(ctx, batch_type == 1);
2294+
22792295
// process the created batch of tokens
22802296
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
22812297
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -2990,6 +3006,11 @@ int main(int argc, char ** argv) {
29903006
};
29913007

29923008
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3009+
if (ctx_server.params.embedding) {
3010+
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
3011+
return;
3012+
}
3013+
29933014
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
29943015

29953016
json data = json::parse(req.body);
@@ -3085,6 +3106,11 @@ int main(int argc, char ** argv) {
30853106
};
30863107

30873108
const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
3109+
if (ctx_server.params.embedding) {
3110+
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
3111+
return;
3112+
}
3113+
30883114
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
30893115
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
30903116

@@ -3157,6 +3183,11 @@ int main(int argc, char ** argv) {
31573183
};
31583184

31593185
const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3186+
if (ctx_server.params.embedding) {
3187+
res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
3188+
return;
3189+
}
3190+
31603191
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
31613192

31623193
json data = json::parse(req.body);
@@ -3243,13 +3274,8 @@ int main(int argc, char ** argv) {
32433274
return res.set_content(data.dump(), "application/json; charset=utf-8");
32443275
};
32453276

3246-
const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3277+
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
32473278
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3248-
if (!params.embedding) {
3249-
res.status = 501;
3250-
res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
3251-
return;
3252-
}
32533279

32543280
const json body = json::parse(req.body);
32553281
bool is_openai = false;

0 commit comments

Comments
 (0)