@@ -136,10 +136,6 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

     std::vector<std::string> antiprompt;
-
-    json input_prefix;
-    json input_suffix;
-    json extra_context;
 };

 struct server_slot {
@@ -169,6 +165,10 @@ struct server_slot {

     json prompt; // can be either a string, array of strings or array of token ids

+    json input_prefix;
+    json input_suffix;
+    json input_extra;
+
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
     std::vector<llama_token> extra_tokens;
@@ -908,12 +908,12 @@ struct server_context {
         }

         // infill
-        slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
-        slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
-        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+        slot.input_prefix = json_value(data, "input_prefix", json());
+        slot.input_suffix = json_value(data, "input_suffix", json());
+        slot.input_extra  = json_value(data, "input_extra",  json());

-        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
-        for (const auto & chunk : slot.params.extra_context) {
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
+        for (const auto & chunk : slot.input_extra) {
             // { "text": string, "filename": string }
             if (!chunk.contains("text") || !chunk["text"].is_string()) {
                 send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
@@ -930,7 +930,7 @@ struct server_context {
         }

         // get prompt
-        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
+        {
             const auto & prompt = data.find("prompt");
             if (prompt == data.end()) {
                 send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
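
The two hunks above change the request surface: the infill inputs now live on the slot under the keys "input_prefix", "input_suffix" and "input_extra" (formerly "extra_context"), and a "prompt" field is checked for every completion type, including infill. Below is a minimal sketch of a request body exercising these keys, built with nlohmann::json the same way the server reads them; the concrete values and the helper.h chunk are illustrative, not from the patch.

```cpp
// Sketch of an infill request body after this change; values are illustrative.
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    json chunk;
    chunk["text"]     = "int helper(int v);\n";   // extra repo-level context
    chunk["filename"] = "helper.h";

    json data;
    data["prompt"]       = "";                              // now required for infill requests too
    data["input_prefix"] = "int main() {\n    int x = ";    // code before the cursor
    data["input_suffix"] = ";\n    return x;\n}\n";         // code after the cursor
    data["input_extra"]  = json::array({ chunk });          // formerly "extra_context"

    // the server reads these with json_value(data, "input_prefix", json()) etc.
    return data.dump().empty() ? 1 : 0;
}
```
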
@@ -1954,6 +1954,8 @@ struct server_context {
                                 } break;
                             case SERVER_TASK_CMPL_TYPE_INFILL:
                                 {
+                                    // TODO: optimize this block by reducing memory allocations and movement
+
                                     // use FIM repo-level pattern:
                                     // ref: https://arxiv.org/pdf/2409.12186
                                     //
@@ -1964,10 +1966,11 @@ struct server_context {
                                     // extra chunk 1
                                     // ...
                                     // [FIM_SEP]filename
-                                    // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                                    // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
                                     //
-                                    auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
-                                    auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+                                    auto tokens_prefix = tokenize(slot.input_prefix, false, false);
+                                    auto tokens_suffix = tokenize(slot.input_suffix, false, false);
+                                    auto tokens_prompt = tokenize(slot.prompt,       false, false);

                                     slot.extra_tokens.clear();
                                     if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
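
The comment change above is the core of this hunk: the user-supplied prompt is now appended after the [FIM_MID] token, so a partially typed completion continues from it. A rough standalone illustration of where each piece lands in the PSM layout follows; the bracketed token spellings are placeholders (the real code inserts the model's llama_token_fim_pre/suf/mid token ids, and the repo-level chunks that precede this part are omitted).

```cpp
// Rough illustration of the new PSM ordering; the bracketed strings are placeholders
// for the model-specific FIM special tokens, and the repo-level chunks are omitted.
#include <iostream>
#include <string>

int main() {
    const std::string prefix = "int add(int a, int b) {\n    return ";
    const std::string suffix = ";\n}\n";
    const std::string prompt = "a +";   // partially typed completion, continued after [FIM_MID]

    const std::string fim = "[FIM_PRE]" + prefix +
                            "[FIM_SUF]" + suffix +
                            "[FIM_MID]" + prompt;   // the prompt is the new trailing piece

    std::cout << fim << std::endl;
    return 0;
}
```
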
@@ -1977,7 +1980,7 @@ struct server_context {
                                         slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
                                     }

-                                    for (const auto & chunk : slot.params.extra_context) {
+                                    for (const auto & chunk : slot.input_extra) {
                                         // { "text": string, "filename": string }
                                         const std::string text     = chunk.value("text", "");
                                         const std::string filename = chunk.value("filename", "tmp");
@@ -2008,20 +2011,21 @@ struct server_context {
                                     }

                                     // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                                    const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
-                                    const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+                                    const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
+                                    const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);

                                     // fill the rest of the context with extra chunks
                                     const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());

-                                    prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
-                                    suffix_tokens.resize(n_suffix_take);
+                                    tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
+                                    tokens_suffix.resize(n_suffix_take);

-                                    prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
-                                    suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
+                                    tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
+                                    tokens_prefix.insert(tokens_prefix.end(),   tokens_prompt.begin(), tokens_prompt.end());
+                                    tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));

-                                    auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
-                                    auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+                                    auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
+                                    auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;

                                     if (llama_add_bos_token(model)) {
                                         embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
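
A small worked example of the new budget split, assuming n_batch = 2048: the fixed 3*(n_batch/4) - 3 cap keeps the prefix at roughly 3/4 of a batch even when the suffix is short, whereas the old formula let the prefix absorb the unused suffix budget; this plausibly leaves headroom for the prompt tokens that are now appended to the prefix. The token counts below are made up for illustration.

```cpp
// Worked example of the new prefix/suffix budget (assumption: n_batch = 2048).
// The "- 3" presumably reserves room for the FIM special tokens.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_batch       = 2048;
    const int n_prefix_full = 3000;  // illustrative token counts
    const int n_suffix_full = 100;

    const int n_suffix_take     = std::min(n_suffix_full, n_batch/4);                      // 100
    const int n_prefix_take     = std::min(n_prefix_full, 3*(n_batch/4) - 3);              // 1533
    const int n_prefix_take_old = std::min(n_prefix_full, (n_batch - 3) - n_suffix_take);  // 1945 (old formula)

    std::printf("suffix: %d, prefix: %d (old prefix: %d)\n", n_suffix_take, n_prefix_take, n_prefix_take_old);
    return 0;
}
```
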
@@ -2136,34 +2140,11 @@ struct server_context {

                             while (head_c < slot.cache_tokens.size() &&
                                    head_p < prompt_tokens.size()) {
-                                if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
-                                    slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
-                                    slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
-                                    break;
-                                }
-
-                                if (llama_token_is_control(model, prompt_tokens[head_p]) &&
-                                    prompt_tokens[head_p] != llama_token_fim_rep(model) &&
-                                    prompt_tokens[head_p] != llama_token_fim_sep(model)) {
-                                    break;
-                                }

                                 size_t n_match = 0;
-
                                 while (head_c + n_match < slot.cache_tokens.size() &&
                                        head_p + n_match < prompt_tokens.size() &&
                                        slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
-                                    if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
-                                        slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
-                                        slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
-                                        break;
-                                    }
-
-                                    if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
-                                        prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
-                                        prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
-                                        break;
-                                    }
-
                                     n_match++;
                                 }
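
With the control-token checks removed, the reuse scan reduces to measuring, at the current heads, the longest run of identical tokens between the cached prompt and the new prompt. The following is a standalone sketch of that scan over plain vectors; it is not the server code, and the token values and head positions are made up.

```cpp
// Standalone sketch of the simplified scan: measure the longest identical run
// starting at the current heads. Not the server code; values are made up.
#include <cstdio>
#include <vector>

using llama_token = int;  // stand-in for the real typedef

int main() {
    const std::vector<llama_token> cache_tokens  = { 1, 7, 9, 4, 4, 4, 2 };  // cached prompt
    const std::vector<llama_token> prompt_tokens = { 7, 9, 4, 4, 4, 8 };     // new prompt

    size_t head_c = 1;  // position in the cache (e.g. slot.n_past)
    size_t head_p = 0;  // position in the new prompt

    size_t n_match = 0;
    while (head_c + n_match < cache_tokens.size() &&
           head_p + n_match < prompt_tokens.size() &&
           cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
        n_match++;
    }

    // the server then shifts and reuses this chunk of KV cache if it is long enough
    std::printf("matched chunk of %zu tokens\n", n_match);  // prints 5
    return 0;
}
```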