
Commit 00ae55b

server : hide ctx_sampling->prev behind API (#3696)
1 parent: 3d6a687

1 file changed (+15 -7 lines)


examples/server/server.cpp

Lines changed: 15 additions & 7 deletions
@@ -1559,7 +1559,8 @@ struct llama_server_context
 
                 if (!slot.params.cache_prompt)
                 {
-                    std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0);
+                    llama_sampling_reset(slot.ctx_sampling);
+
                     slot.n_past = 0;
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                 }
@@ -1570,16 +1571,17 @@ struct llama_server_context
                     slot.params.n_keep = slot.num_prompt_tokens;
                 }
                 slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-                //if input prompt is too big, truncate like normal
+
+                // if input prompt is too big, truncate it
                 if (slot.num_prompt_tokens >= slot.n_ctx)
                 {
-                    // applied bug of #3661
                     const int n_left = slot.n_ctx - slot.params.n_keep;
                     const int n_block_size = n_left / 2;
                     const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+
                     std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
-                    // Use half the left-over space in the context for the prompt
                     new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
                     LOG_VERBOSE("input truncated", {
                         {"n_ctx",  slot.n_ctx},
                         {"n_keep", slot.params.n_keep},
@@ -1588,14 +1590,20 @@ struct llama_server_context
                     });
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
+
                     slot.num_prompt_tokens = prompt_tokens.size();
                     GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
                 }
-                const size_t ps = slot.num_prompt_tokens;
-                std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
-                std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps);
+
+                // push the prompt into the sampling context (do not apply grammar)
+                for (auto &token : prompt_tokens)
+                {
+                    llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
+                }
+
                 slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
+
                 LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
             }
 

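Note on the change itself: instead of writing into ctx_sampling->prev directly, the server now resets the sampling context and replays the prompt tokens through the sampling API. Below is a minimal sketch of that pattern, assuming the sampling helpers in this repo's common code; the include paths and the wrapper function are illustrative, not part of the commit.

// Sketch only: re-prime a sampling context from a prompt via the public API
// instead of touching ctx_sampling->prev. Includes and the wrapper name are
// assumptions for illustration.
#include "common.h"
#include "sampling.h"

#include <vector>

static void prime_sampling_context(llama_sampling_context * ctx_sampling,
                                   llama_context          * ctx,
                                   const std::vector<llama_token> & prompt_tokens) {
    // drop any previously accepted tokens (and any grammar state)
    llama_sampling_reset(ctx_sampling);

    // push the prompt into the sampling context; grammar is not applied here,
    // matching the server change above
    for (const llama_token token : prompt_tokens) {
        llama_sampling_accept(ctx_sampling, ctx, token, false);
    }
}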
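On the truncation branch in the middle hunk: when the prompt does not fit in the context window, the server keeps the first n_keep tokens, drops whole blocks of size n_left / 2, and keeps the tail of the prompt. A standalone sketch of that arithmetic with placeholder sizes (plain int stands in for llama_token, assert stands in for GGML_ASSERT; the numbers are made up):

// Standalone illustration of the truncation arithmetic above; values are
// placeholders, not taken from the commit.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int n_ctx  = 512;  // context size (placeholder)
    const int n_keep = 64;   // tokens always kept from the start (placeholder)

    std::vector<int> prompt_tokens(700);            // an over-long prompt
    for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
        prompt_tokens[i] = i;                       // dummy token ids
    }

    const int num_prompt_tokens = (int) prompt_tokens.size();

    if (num_prompt_tokens >= n_ctx) {
        const int n_left        = n_ctx - n_keep;
        const int n_block_size  = n_left / 2;
        const int erased_blocks = (num_prompt_tokens - n_keep - n_block_size) / n_block_size;

        // keep the first n_keep tokens ...
        std::vector<int> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep);
        // ... skip erased_blocks whole blocks, then keep the rest of the prompt
        new_tokens.insert(new_tokens.end(),
                          prompt_tokens.begin() + n_keep + erased_blocks * n_block_size,
                          prompt_tokens.end());

        prompt_tokens = new_tokens;
        assert((int) prompt_tokens.size() < n_ctx); // must now fit in the context
    }

    // with these numbers: n_left = 448, n_block_size = 224, erased_blocks = 1,
    // so 64 + (700 - 64 - 224) = 476 tokens remain, which is < 512
    printf("truncated prompt: %zu tokens\n", prompt_tokens.size());
    return 0;
}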
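The n_past bookkeeping relies on common_part(slot.cache_tokens, prompt_tokens), which, judging from its use and the log line, yields how many leading tokens the cached sequence and the new prompt share, so only the remainder of the prompt needs to be processed. A sketch of what such a prefix-length helper could look like (an assumption for illustration, not the actual definition in server.cpp):

// Hypothetical common-prefix helper; the real common_part in server.cpp may differ.
#include <cstddef>
#include <vector>

// llama_token is an integer id in llama.h; plain int keeps this standalone
static size_t common_prefix_length(const std::vector<int> & a, const std::vector<int> & b) {
    size_t n = 0;
    while (n < a.size() && n < b.size() && a[n] == b[n]) {
        n++;
    }
    return n; // number of leading tokens the two sequences share
}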