Skip to content

Commit e230e51

Browse files
committed
server : update batching logic to reset n_batch on successful decode
1 parent f2ded9d commit e230e51

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

examples/parallel/parallel.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,15 +362,17 @@ int main(int argc, char ** argv) {
362362
// process in chunks of params.n_batch
363363
int32_t n_batch = params.n_batch;
364364

365-
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
365+
int32_t i_next = 0;
366+
367+
for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
366368
// experiment: process in powers of 2
367369
//if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
368370
// n_batch /= 2;
369371
// i -= n_batch;
370372
// continue;
371373
//}
372374

373-
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
375+
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
374376

375377
llama_batch batch_view = {
376378
n_tokens,
@@ -396,13 +398,18 @@ int main(int argc, char ** argv) {
396398

397399
// retry with half the batch size to try to find a free slot in the KV cache
398400
n_batch /= 2;
399-
i -= n_batch;
400401

401402
continue;
402403
}
403404

404405
LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens);
405406

407+
// move the head of the batch forward with the number of tokens we just processed
408+
i_next = i + n_tokens;
409+
410+
// on successful decode, restore the original batch size
411+
n_batch = params.n_batch;
412+
406413
for (auto & client : clients) {
407414
if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
408415
continue;

tools/server/server.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3381,8 +3381,10 @@ struct server_context {
33813381
}
33823382
}
33833383

3384+
int32_t i_next = 0;
3385+
33843386
// process the created batch of tokens
3385-
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
3387+
for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
33863388
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
33873389

33883390
llama_batch batch_view = {
@@ -3430,11 +3432,15 @@ struct server_context {
34303432

34313433
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
34323434

3433-
i -= n_batch;
3434-
34353435
continue; // continue loop of n_batch
34363436
}
34373437

3438+
// move the head of the batch forward with the number of tokens we just processed
3439+
i_next = i + n_tokens;
3440+
3441+
// on successful decode, restore the original batch size
3442+
n_batch = llama_n_batch(ctx);
3443+
34383444
for (auto & slot : slots) {
34393445
if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
34403446
continue; // continue loop of slots

0 commit comments

Comments (0)