Commit fde9b8d

server : improve infill context reuse
ggml-ci
1 parent 4c42f93 commit fde9b8d

2 files changed: +32 additions, -49 deletions

examples/server/README.md

Lines changed: 6 additions & 4 deletions

````diff
@@ -524,10 +524,12 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
 
 - `input_prefix`: Set the prefix of the code to infill.
 - `input_suffix`: Set the suffix of the code to infill.
-- `prompt`: Added after the `FIM_MID` token
-- `extra_context`: Additional context inserted before the FIM prefix. See https://github.com/ggerganov/llama.cpp/pull/9874
+- `input_extra`: Additional context inserted before the FIM prefix.
+- `prompt`: Added after the `FIM_MID` token
 
-It also accepts all the options of `/completion`.
+`input_extra` is array of `{"filename": string, "text": string}` objects.
+
+The endpoint also accepts all the options of `/completion`.
 
 If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](https://arxiv.org/pdf/2409.12186) is used:
 
@@ -545,7 +547,7 @@ If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](
 If the tokens are missing, then the extra context is simply prefixed at the start:
 
 ```txt
-[extra_context]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
+[input_extra]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
 ```
 
 ### **GET** `/props`: Get server global properties.
````
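
For orientation, here is a minimal sketch of how a client could assemble a request body with the fields documented above, using nlohmann::json (the same JSON library the server uses). The code snippets, the filename, and the assumption that the result is POSTed to the server's `/infill` route are illustrative, not part of this change.

```cpp
// Minimal sketch (not part of the patch): building an /infill request body
// from the documented fields. How the body is sent (curl, libcurl, ...) is
// left out; only the JSON layout is shown.
#include <iostream>
#include <nlohmann/json.hpp>

int main() {
    nlohmann::json body = {
        {"input_prefix", "int add(int a, int b) {\n    return "},
        {"input_suffix", ";\n}\n"},
        {"prompt",       ""},  // appended after the FIM_MID token
        {"input_extra",  nlohmann::json::array({
            {{"filename", "math.h"}, {"text", "int add(int a, int b);\n"}},
        })},
    };

    // print the JSON that would be sent as the POST body
    std::cout << body.dump(2) << std::endl;
    return 0;
}
```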

examples/server/server.cpp

Lines changed: 26 additions & 45 deletions

```diff
@@ -136,10 +136,6 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
-
-    json input_prefix;
-    json input_suffix;
-    json extra_context;
 };
 
 struct server_slot {
```

```diff
@@ -169,6 +165,10 @@ struct server_slot {
 
     json prompt; // can be either a string, array of strings or array of token ids
 
+    json input_prefix;
+    json input_suffix;
+    json input_extra;
+
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
     std::vector<llama_token> extra_tokens;
@@ -908,12 +908,12 @@ struct server_context {
         }
 
         // infill
-        slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
-        slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
-        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+        slot.input_prefix = json_value(data, "input_prefix", json());
+        slot.input_suffix = json_value(data, "input_suffix", json());
+        slot.input_extra = json_value(data, "input_extra", json());
 
-        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
-        for (const auto & chunk : slot.params.extra_context) {
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
+        for (const auto & chunk : slot.input_extra) {
             // { "text": string, "filename": string }
             if (!chunk.contains("text") || !chunk["text"].is_string()) {
                 send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
```
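
As a rough sketch of what this parsing step does (the server's own `json_value` helper and error plumbing are not shown in the diff, so a simplified stand-in is used here), the per-slot fields fall back to an empty `json` value when a key is absent, and every `input_extra` chunk must carry a string `"text"` field:

```cpp
// Standalone sketch of the defaulting/validation pattern above. json_value is
// a simplified stand-in, not the server's actual helper; the request data is
// made up.
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

static json json_value(const json & data, const std::string & key, const json & def) {
    return data.contains(key) ? data.at(key) : def;
}

int main() {
    const json data = json::parse(R"({
        "input_prefix": "int add(",
        "input_suffix": ") { return a + b; }",
        "input_extra": [ { "filename": "math.h", "text": "int add(int a, int b);\n" } ]
    })");

    const json input_prefix = json_value(data, "input_prefix", json());
    const json input_suffix = json_value(data, "input_suffix", json());
    const json input_extra  = json_value(data, "input_extra",  json());

    for (const auto & chunk : input_extra) {
        if (!chunk.contains("text") || !chunk["text"].is_string()) {
            std::cerr << "extra_context chunk must contain a \"text\" field with a string value\n";
            return 1;
        }
    }

    std::cout << "parsed " << input_extra.size() << " extra chunk(s)" << std::endl;
    return 0;
}
```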

```diff
@@ -930,7 +930,7 @@ struct server_context {
         }
 
         // get prompt
-        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
+        {
             const auto & prompt = data.find("prompt");
             if (prompt == data.end()) {
                 send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
@@ -1954,6 +1954,8 @@ struct server_context {
                    } break;
                case SERVER_TASK_CMPL_TYPE_INFILL:
                    {
+                        // TODO: optimize this block by reducing memory allocations and movement
+
                        // use FIM repo-level pattern:
                        // ref: https://arxiv.org/pdf/2409.12186
                        //
@@ -1964,10 +1966,11 @@ struct server_context {
                        // extra chunk 1
                        // ...
                        // [FIM_SEP]filename
-                        // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                        // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
                        //
-                        auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
-                        auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+                        auto tokens_prefix = tokenize(slot.input_prefix, false, false);
+                        auto tokens_suffix = tokenize(slot.input_suffix, false, false);
+                        auto tokens_prompt = tokenize(slot.prompt, false, false);
 
                        slot.extra_tokens.clear();
                        if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
@@ -1977,7 +1980,7 @@ struct server_context {
                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
                        }
 
-                        for (const auto & chunk : slot.params.extra_context) {
+                        for (const auto & chunk : slot.input_extra) {
                            // { "text": string, "filename": string }
                            const std::string text = chunk.value("text", "");
                            const std::string filename = chunk.value("filename", "tmp");
```
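
To help picture what this loop builds, the sketch below lays out the same repo-level structure as a plain string, using readable `<FIM_*>` placeholders instead of the real special tokens (the server inserts token ids via `llama_token_fim_*`, not strings), with a made-up repo name and file list:

```cpp
// Illustrative sketch of the repo-level FIM layout assembled above; the tags,
// repo name and file contents are placeholders, the real code works on
// token ids.
#include <cstdio>
#include <string>
#include <vector>

struct chunk {
    std::string filename;
    std::string text;
};

int main() {
    // made-up extra context, mirroring the {"filename", "text"} chunks
    const std::vector<chunk> input_extra = {
        {"math.h",   "int add(int a, int b);\n"},
        {"util.cpp", "static int counter = 0;\n"},
    };

    std::string ctx = "<FIM_REP>myrepo\n";               // repo header (name is arbitrary here)
    for (const auto & c : input_extra) {
        ctx += "<FIM_SEP>" + c.filename + "\n" + c.text; // one section per extra chunk
    }

    // prefix/suffix/prompt of the file being completed
    const std::string prefix = "int add(int a, int b) {\n    return ";
    const std::string suffix = ";\n}\n";
    const std::string prompt = "";

    const std::string fim = ctx + "<FIM_SEP>filename\n" +
                            "<FIM_PRE>" + prefix + "<FIM_SUF>" + suffix + "<FIM_MID>" + prompt;

    printf("%s\n", fim.c_str());
    return 0;
}
```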

```diff
@@ -2008,20 +2011,21 @@ struct server_context {
                        }
 
                        // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                        const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
-                        const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+                        const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
+                        const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
 
                        // fill the rest of the context with extra chunks
                        const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
 
-                        prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
-                        suffix_tokens.resize(n_suffix_take);
+                        tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
+                        tokens_suffix.resize(n_suffix_take);
 
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
-                        suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
+                        tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
+                        tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
+                        tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
 
-                        auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
-                        auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+                        auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
+                        auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
 
                        if (llama_add_bos_token(model)) {
                            embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
```
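
As a rough illustration of this budgeting (all sizes below are made up; the real values come from the slot and the server settings), the arithmetic works out as in this standalone sketch:

```cpp
// Standalone sketch of the FIM token-budget arithmetic above, with made-up
// sizes: the suffix gets at most a quarter of the batch, the prefix roughly
// three quarters, and whatever context remains (minus room for generation)
// is filled with extra chunks.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_batch   = 2048;   // example batch size (assumption)
    const int n_ctx     = 8192;   // example per-slot context size (assumption)
    const int n_predict = 128;    // example generation limit (assumption)

    const int n_suffix = 900;     // pretend tokenized suffix length
    const int n_prefix = 3000;    // pretend tokenized prefix length
    const int n_extra  = 10000;   // pretend length of the concatenated extra chunks

    const int n_suffix_take = std::min(n_suffix,   n_batch/4);        // -> 512
    const int n_prefix_take = std::min(n_prefix, 3*(n_batch/4) - 3);  // -> 1533

    const int n_extra_take = std::min(std::max(0, n_ctx - n_batch - 2*n_predict), n_extra); // -> 5888

    printf("prefix: %d, suffix: %d, extra: %d\n", n_prefix_take, n_suffix_take, n_extra_take);
    return 0;
}
```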

```diff
@@ -2136,34 +2140,11 @@ struct server_context {
 
                    while (head_c < slot.cache_tokens.size() &&
                           head_p < prompt_tokens.size()) {
-                        if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
-                            slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
-                            slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
-                            break;
-                        }
-
-                        if (llama_token_is_control(model, prompt_tokens[head_p]) &&
-                            prompt_tokens[head_p] != llama_token_fim_rep(model) &&
-                            prompt_tokens[head_p] != llama_token_fim_sep(model)) {
-                            break;
-                        }
 
                        size_t n_match = 0;
-
                        while (head_c + n_match < slot.cache_tokens.size() &&
                               head_p + n_match < prompt_tokens.size() &&
                               slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
-                            if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
-                                slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
-                                slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
-                                break;
-                            }
-
-                            if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
-                                prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
-                                prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
-                                break;
-                            }
 
                            n_match++;
                        }
```
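
With the control-token early exits removed, the matching step reduces to counting how long the cached and the new prompt token sequences agree starting from the two heads. A standalone sketch of just that step (the helper name and the test data are illustrative):

```cpp
// Standalone sketch of the simplified matching loop kept above: count how many
// consecutive tokens agree starting at head_c (cache) and head_p (prompt).
#include <cstdio>
#include <vector>

using llama_token = int; // stand-in for the real llama_token type

static size_t count_match(const std::vector<llama_token> & cache_tokens,
                          const std::vector<llama_token> & prompt_tokens,
                          size_t head_c, size_t head_p) {
    size_t n_match = 0;
    while (head_c + n_match < cache_tokens.size() &&
           head_p + n_match < prompt_tokens.size() &&
           cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
        n_match++;
    }
    return n_match;
}

int main() {
    const std::vector<llama_token> cache  = {10, 11, 12, 13, 14};
    const std::vector<llama_token> prompt = {11, 12, 13, 99};

    // cache[1..3] matches prompt[0..2] -> 3 tokens of the KV cache can be reused
    printf("n_match = %zu\n", count_match(cache, prompt, 1, 0));
    return 0;
}
```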
