server: Ensure batches are either all embed or all completion (#8076) #8420


Merged 2 commits on Jul 12, 2024

Changes from all commits
38 changes: 32 additions & 6 deletions examples/server/server.cpp
@@ -2003,6 +2003,11 @@ struct server_context {
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);

// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

// next, batch any pending prompts without exceeding n_batch
if (params.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
@@ -2173,6 +2178,14 @@ struct server_context {
}
}

// check that we are in the right batch_type, if not defer the slot
bool slot_type = slot.embedding ? 1 : 0;
if (batch_type == -1) {
batch_type = slot_type;
} else if (batch_type != slot_type) {
continue;
}
Comment on lines +2181 to +2187
Collaborator:

Should this be done earlier? There are some things like slot.n_past which should probably not be changed when the batch_type is not the right one.

I feel like this check should be done at least before tokens are put in the batch, but the multitude of loops over slots does make it hard to find one right spot.

Collaborator Author:

Indeed. But we also want to make sure we don't assign the batch to a particular type but then bail out later on (say, because the prompt is bigger than the batch size).

My current read is that everything before where the check is currently is tokenization, and it's fine to do that, bail, and then pick up the slot on the next go around. That's why it's okay to have the conditional continue just above it (2173-2178) too.

Collaborator:

> everything before where the check is currently is tokenization and it's fine to do that, bail, and then pick up the slot on the next go around.

Right, and slot.n_past is recalculated anyway, so you're right.

But I think if batch.n_tokens is non-zero, the initial batch_type should be non-embedding, or else this could potentially attempt continuous batching with embeddings and previous text completions.

Collaborator Author:

Ah I see, just pushed the fix!
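
To summarize the thread, here is a minimal, self-contained sketch of the batch-type rule as merged (the toy_slot struct and try_join_batch helper are illustrative only, not from the PR; the real logic lives inline in the server's batching loop above):

// Illustrative sketch only, not from the PR.
// batch_type: -1 = not decided yet, 0 = completion (non-embedding), 1 = embedding.
#include <cstdio>
#include <vector>

struct toy_slot {
    int  id;
    bool embedding; // does this slot want embeddings?
};

// Returns true if the slot may join the current batch, false if it must be
// deferred to a later batch because its type does not match.
static bool try_join_batch(int & batch_type, const toy_slot & slot) {
    const int slot_type = slot.embedding ? 1 : 0;
    if (batch_type == -1) {
        batch_type = slot_type; // first slot decides the batch type
        return true;
    }
    return batch_type == slot_type;
}

int main() {
    // Follow-up fix from this thread: if sampled completion tokens are already
    // in the batch, the type is fixed to non-embedding up front, i.e.
    // batch_type = batch.n_tokens > 0 ? 0 : -1;
    int batch_type = -1;

    const std::vector<toy_slot> slots = {{0, false}, {1, true}, {2, false}};
    for (const auto & slot : slots) {
        if (try_join_batch(batch_type, slot)) {
            std::printf("slot %d joins the batch\n", slot.id);
        } else {
            std::printf("slot %d deferred (type mismatch)\n", slot.id);
        }
    }
    return 0;
}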


// keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
@@ -2274,6 +2287,9 @@ struct server_context {
{"n_tokens", batch.n_tokens},
});

// make sure we're in the right embedding mode
llama_set_embeddings(ctx, batch_type == 1);

// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
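
Before each decode the context is switched into the mode that matches the batch type, which is what lets a single server context serve both kinds of requests. A hedged fragment of the idea (the helper name is illustrative; llama_set_embeddings and llama_decode are the real llama.cpp calls used in the diff):

// Illustrative fragment only, not from the PR; assumes ctx and batch were
// prepared elsewhere and batch_type follows the -1/0/1 convention above.
#include "llama.h"

static int32_t decode_with_mode(llama_context * ctx, llama_batch batch, int32_t batch_type) {
    llama_set_embeddings(ctx, batch_type == 1); // enable embeddings only for embedding batches
    return llama_decode(ctx, batch);            // 0 on success
}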
@@ -2988,6 +3004,11 @@ int main(int argc, char ** argv) {
};

const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}

res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

json data = json::parse(req.body);
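
The same guard is repeated below for chat completions and infill. As a hedged sketch, it could be factored into a small helper (the helper name is illustrative and not part of the PR; res_error, format_error_response and ERROR_TYPE_NOT_SUPPORTED are the utilities already used in the diff):

// Illustrative sketch only, not from the PR: reject completion-style endpoints
// when the server was started with --embeddings.
const auto reject_if_embedding_server = [&ctx_server, &res_error](httplib::Response & res, const std::string & what) -> bool {
    if (ctx_server.params.embedding) {
        res_error(res, format_error_response(
            "This server does not support " + what + ". Start it without `--embeddings`",
            ERROR_TYPE_NOT_SUPPORTED));
        return true;  // rejected, the caller should return immediately
    }
    return false;     // ok to continue handling the request
};

// Hypothetical usage inside a handler:
//     if (reject_if_embedding_server(res, "completions")) { return; }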
@@ -3083,6 +3104,11 @@ int main(int argc, char ** argv) {
};

const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}

res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

@@ -3155,6 +3181,11 @@ int main(int argc, char ** argv) {
};

const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}

res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

json data = json::parse(req.body);
@@ -3241,13 +3272,8 @@ int main(int argc, char ** argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8");
};

- const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+ const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
- if (!params.embedding) {
- res.status = 501;
- res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
- return;
- }

const json body = json::parse(req.body);
bool is_openai = false;