@@ -2005,6 +2005,11 @@ struct server_context {
2005
2005
int32_t n_batch = llama_n_batch (ctx);
2006
2006
int32_t n_ubatch = llama_n_ubatch (ctx);
2007
2007
2008
+ // track if this is an embedding or non-embedding batch
2009
+ // if we've added sampled tokens above, we are in non-embedding mode
2010
+ // -1: none, 0: non-embedding, 1: embedding
2011
+ int32_t batch_type = batch.n_tokens > 0 ? 0 : -1 ;
2012
+
2008
2013
// next, batch any pending prompts without exceeding n_batch
2009
2014
if (params.cont_batching || batch.n_tokens == 0 ) {
2010
2015
for (auto & slot : slots) {
@@ -2175,6 +2180,14 @@ struct server_context {
2175
2180
}
2176
2181
}
2177
2182
2183
+ // check that we are in the right batch_type, if not defer the slot
2184
+ bool slot_type = slot.embedding ? 1 : 0 ;
2185
+ if (batch_type == -1 ) {
2186
+ batch_type = slot_type;
2187
+ } else if (batch_type != slot_type) {
2188
+ continue ;
2189
+ }
2190
+
2178
2191
// keep only the common part
2179
2192
int p0 = (int ) system_tokens.size () + slot.n_past ;
2180
2193
if (!llama_kv_cache_seq_rm (ctx, slot.id + 1 , p0, -1 )) {
@@ -2276,6 +2289,9 @@ struct server_context {
2276
2289
{" n_tokens" , batch.n_tokens },
2277
2290
});
2278
2291
2292
+ // make sure we're in the right embedding mode
2293
+ llama_set_embeddings (ctx, batch_type == 1 );
2294
+
2279
2295
// process the created batch of tokens
2280
2296
for (int32_t i = 0 ; i < batch.n_tokens ; i += n_batch) {
2281
2297
const int32_t n_tokens = std::min (n_batch, batch.n_tokens - i);
@@ -2990,6 +3006,11 @@ int main(int argc, char ** argv) {
2990
3006
};
2991
3007
2992
3008
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3009
+ if (ctx_server.params .embedding ) {
3010
+ res_error (res, format_error_response (" This server does not support completions. Start it without `--embeddings`" , ERROR_TYPE_NOT_SUPPORTED));
3011
+ return ;
3012
+ }
3013
+
2993
3014
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
2994
3015
2995
3016
json data = json::parse (req.body );
@@ -3085,6 +3106,11 @@ int main(int argc, char ** argv) {
3085
3106
};
3086
3107
3087
3108
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
3109
+ if (ctx_server.params .embedding ) {
3110
+ res_error (res, format_error_response (" This server does not support chat completions. Start it without `--embeddings`" , ERROR_TYPE_NOT_SUPPORTED));
3111
+ return ;
3112
+ }
3113
+
3088
3114
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3089
3115
json data = oaicompat_completion_params_parse (ctx_server.model , json::parse (req.body ), params.chat_template );
3090
3116
@@ -3157,6 +3183,11 @@ int main(int argc, char ** argv) {
3157
3183
};
3158
3184
3159
3185
const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3186
+ if (ctx_server.params .embedding ) {
3187
+ res_error (res, format_error_response (" This server does not support infill. Start it without `--embeddings`" , ERROR_TYPE_NOT_SUPPORTED));
3188
+ return ;
3189
+ }
3190
+
3160
3191
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3161
3192
3162
3193
json data = json::parse (req.body );
@@ -3243,13 +3274,8 @@ int main(int argc, char ** argv) {
3243
3274
return res.set_content (data.dump (), " application/json; charset=utf-8" );
3244
3275
};
3245
3276
3246
- const auto handle_embeddings = [¶ms, & ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3277
+ const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3247
3278
res.set_header (" Access-Control-Allow-Origin" , req.get_header_value (" Origin" ));
3248
- if (!params.embedding ) {
3249
- res.status = 501 ;
3250
- res.set_content (" This server does not support embeddings. Start it with `--embeddings`" , " text/plain; charset=utf-8" );
3251
- return ;
3252
- }
3253
3279
3254
3280
const json body = json::parse (req.body );
3255
3281
bool is_openai = false ;
0 commit comments