
Commit f762501

server : fix crash when system prompt is bigger than batch size (#5714)
The system prompt is now decoded in batches.

* server : fix off-by-one n_past when start of prompt matches whole cache

  The tokens right after the matching part would otherwise skip a pos value.
1 parent abbabc5 commit f762501

1 file changed (+25, -3 lines)

examples/server/server.cpp

Lines changed: 25 additions & 3 deletions
@@ -902,10 +902,24 @@ struct llama_server_context
                 llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
             }
 
-            if (llama_decode(ctx, batch) != 0)
+            for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
             {
-                LOG_TEE("%s: llama_decode() failed\n", __func__);
-                return;
+                const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
+                llama_batch batch_view = {
+                    n_tokens,
+                    batch.token    + i,
+                    nullptr,
+                    batch.pos      + i,
+                    batch.n_seq_id + i,
+                    batch.seq_id   + i,
+                    batch.logits   + i,
+                    0, 0, 0, // unused
+                };
+                if (llama_decode(ctx, batch_view) != 0)
+                {
+                    LOG_TEE("%s: llama_decode() failed\n", __func__);
+                    return;
+                }
             }
 
             // assign the system KV cache to all parallel sequences
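
For readers outside the diff context, here is a minimal standalone sketch of the chunking pattern the hunk above introduces: a filled llama_batch is split into views of at most n_batch tokens and decoded one chunk at a time. It assumes the llama.cpp C API as it existed at this commit (llama_decode, and a llama_batch whose trailing all_pos_0/all_pos_1/all_seq_id fields are zeroed, exactly as in the diff); the helper name decode_in_chunks and its error handling are illustrative, not part of the patch.

```cpp
// Sketch only: decode a batch in chunks of at most n_batch tokens.
#include <algorithm>
#include <cstdio>

#include "llama.h"

static bool decode_in_chunks(llama_context * ctx, const llama_batch & batch, int32_t n_batch) {
    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
        // number of tokens in this chunk (the last chunk may be shorter)
        const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

        // a view into the original batch: pointers are offset, nothing is copied
        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr,            // token ids only, no embeddings
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
            0, 0, 0,            // unused
        };

        if (llama_decode(ctx, batch_view) != 0) {
            fprintf(stderr, "llama_decode() failed on chunk starting at token %d\n", (int) i);
            return false;
        }
    }
    return true;
}
```

Because each view only offsets the pointers of the original batch, the chunks share positions and sequence ids with the full batch, so the KV cache ends up identical to a single oversized decode, without ever exceeding the context's batch limit.
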
@@ -1785,6 +1799,14 @@ struct llama_server_context
                     }
 
                     slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+
+                    // the last token of the cache is not in the KV cache until the next call to llama_decode
+                    // (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
+                    if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size())
+                    {
+                        slot.n_past -= 1;
+                    }
+
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
 
                     if (slot.ga_n != 1)
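
To make the off-by-one concrete, here is a small self-contained example, not taken from the patch. common_prefix is a stand-in for server.cpp's common_part helper and is assumed to return the length of the shared prefix of two token sequences; the token ids are made up.

```cpp
// Standalone illustration of the n_past off-by-one fixed above.
#include <cstdint>
#include <cstdio>
#include <vector>

// stand-in for common_part(): length of the shared prefix of a and b
static int32_t common_prefix(const std::vector<int32_t> & a, const std::vector<int32_t> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) { i++; }
    return (int32_t) i;
}

int main() {
    // cache: prompt tokens already decoded plus the last sampled token (id 40),
    // which is in cache_tokens but not yet in the KV cache
    const std::vector<int32_t> cache_tokens  = { 10, 20, 30, 40 };
    const std::vector<int32_t> prompt_tokens = { 10, 20, 30, 40, 50, 60 };

    int32_t n_past = common_prefix(cache_tokens, prompt_tokens); // 4: the whole cache matches

    // without the fix, decoding would resume at pos 4 and token 40 would never
    // get a KV entry; stepping back by one re-decodes it so positions stay contiguous
    if (n_past > 0 && n_past == (int32_t) cache_tokens.size()) {
        n_past -= 1;
    }

    printf("resume decoding at pos %d\n", (int) n_past); // prints 3
    return 0;
}
```
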
