Commit 1adcd4b

llama : auto-batch
ggml-ci
1 parent 2252eef commit 1adcd4b

File tree

3 files changed: +80 additions, -87 deletions


src/llama-context.cpp

Lines changed: 54 additions & 36 deletions
@@ -420,9 +420,9 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
@@ -438,7 +438,11 @@ void llama_context::kv_self_update() {
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
+
+        return true;
     }
+
+    return false;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -891,25 +895,53 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!decode_state) {
-        return -2;
-    }
+    llama_memory_decode_state_ptr decode_state;
 
-    switch (decode_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
-    }
+    bool did_defrag = false;
+    auto n_ubatch = cparams.n_ubatch;
+
+    do {
+        decode_state = kv_self->init(batch, n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!decode_state) {
+            return -2;
+        }
+
+        switch (decode_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    if (n_ubatch > 1) {
+                        n_ubatch /= 2;
+
+                        LLAMA_LOG_DEBUG("%s: failed to find free space in the KV cache, retrying with smaller ubatch size: n_ubatch = %d\n", __func__, n_ubatch);
+                        continue;
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
+    } while(true);
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -2588,22 +2620,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
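With this change, the retry-on-full-KV-cache logic moves out of the public llama_decode() wrapper and into llama_context::decode(): on LLAMA_MEMORY_STATUS_FAILED_PREPARE the context first forces a defrag (defrag_sched(-1.0f)) and retries only if kv_self_update() reports that the cache actually changed, then halves n_ubatch down to 1, and only then returns 1. The sketch below is illustrative only (the helper decode_or_report is hypothetical, not part of this commit); it shows how a caller might treat the return codes now that llama_decode() no longer retries on its own.

// Hedged caller-side sketch: decode_or_report is a hypothetical helper, not
// llama.cpp code. It assumes only the public llama_decode() API and its return
// codes: 0 = success, 1 = no KV cache slot found, negative = fatal error.
#include "llama.h"

static bool decode_or_report(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);

    if (ret == 0) {
        return true; // batch decoded successfully
    }

    if (ret == 1) {
        // with auto-batch, decode() has already forced a defrag and halved the
        // ubatch size before giving up, so the context is genuinely out of space
        return false;
    }

    // ret < 0: invalid input batch or compute error
    return false;
}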

src/llama-context.h

Lines changed: 2 additions & 1 deletion
@@ -47,8 +47,9 @@ struct llama_context {
           llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
+    // return true if the KV cache was updated
     // TODO: remove
-    void kv_self_update();
+    bool kv_self_update();
 
     enum llama_pooling_type pooling_type() const;
 
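The header now records the new contract: kv_self_update() returns true only when a pending shift or defrag was actually applied to the KV cache, which is what lets decode() skip pointless retries. Below is a minimal standalone sketch of that contract using a mock type (mock_context and its pending_defrag flag are illustrative, not llama.cpp code).

#include <cstdio>

// Mock type illustrating only the return-value contract; the real logic lives
// in llama_context::kv_self_update() in src/llama-context.cpp above.
struct mock_context {
    bool pending_defrag = false;

    bool kv_self_update() {
        if (!pending_defrag) {
            return false;       // nothing applied, the cache is unchanged
        }
        pending_defrag = false; // pretend the defrag was performed
        return true;            // the cache changed, so a retry may now succeed
    }
};

int main() {
    mock_context ctx;
    ctx.pending_defrag = true;
    std::printf("updated: %d\n", ctx.kv_self_update()); // updated: 1
    std::printf("updated: %d\n", ctx.kv_self_update()); // updated: 0
    return 0;
}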

tools/server/server.cpp

Lines changed: 24 additions & 50 deletions
@@ -3384,75 +3384,49 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-
-            metrics.on_decoded(slots);
+        {
+            const int ret = llama_decode(ctx, batch);
 
             if (ret != 0) {
-                {
-                    std::string err;
-
-                    if (n_batch == 1 && ret == 1) {
-                        err = "Context size has been exceeded.";
-                    }
-
-                    if (ret == -1) {
-                        err = "Invalid input batch.";
-                    }
+                std::string err;
 
-                    if (ret < -1) {
-                        err = "Compute error.";
-                    }
-
-                    if (!err.empty()) {
-                        SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
-                        for (auto & slot : slots) {
-                            slot.release();
-                            send_error(slot, err);
-                        }
-                        break;
-                    }
+                if (ret == 1) {
+                    err = "Context size has been exceeded.";
                 }
 
-                // retry with half the batch size to try to find a free slot in the KV cache
-                n_batch /= 2;
+                if (ret == -1) {
+                    err = "Invalid input batch.";
+                }
 
-                SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+                if (ret < -1) {
+                    err = "Compute error.";
+                }
 
-                i -= n_batch;
+                if (!err.empty()) {
+                    SRV_ERR("%s, n_batch = %d, ret = %d\n", err.c_str(), n_batch, ret);
+                    for (auto & slot : slots) {
+                        slot.release();
+                        send_error(slot, err);
+                    }
 
-                continue; // continue loop of n_batch
+                    return;
+                }
             }
 
-            for (auto & slot : slots) {
-                if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
-                    continue; // continue loop of slots
-                }
+            metrics.on_decoded(slots);
 
+            for (auto & slot : slots) {
                 if (slot.state == SLOT_STATE_DONE_PROMPT) {
                     if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                         // prompt evaluated for embedding
-                        send_embedding(slot, batch_view);
+                        send_embedding(slot, batch);
                        slot.release();
                         slot.i_batch = -1;
                         continue; // continue loop of slots
                     }
 
                     if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
-                        send_rerank(slot, batch_view);
+                        send_rerank(slot, batch);
                         slot.release();
                         slot.i_batch = -1;
                         continue; // continue loop of slots
@@ -3464,7 +3438,7 @@ struct server_context {
                     continue; // continue loop of slots
                 }
 
-                const int tok_idx = slot.i_batch - i;
+                const int tok_idx = slot.i_batch;
 
                 llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
 
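With the context doing its own auto-batching, the server no longer splits the batch into n_batch-sized views or retries with a halved batch size; it submits the whole batch to llama_decode() once and maps the return code directly to a user-facing error. A small illustrative helper (decode_error_string is a hypothetical name, not in the commit) capturing that mapping:

#include <string>

// Hypothetical helper mirroring the error mapping done inline in the server
// loop above: 1 = context full, -1 = invalid batch, < -1 = compute error.
static std::string decode_error_string(int ret) {
    if (ret == 1) {
        return "Context size has been exceeded.";
    }
    if (ret == -1) {
        return "Invalid input batch.";
    }
    if (ret < -1) {
        return "Compute error.";
    }
    return {}; // ret == 0: success, no error to report
}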
