Skip to content

Commit 60c6663

Browse files
authored
batch : rework llama_batch_allocr (#14153)
* batch : rework llama_batch_allocr ggml-ci * cont : move validation inside class ggml-ci * cont : move output counting to class ggml-ci * cont : minor ggml-ci * batch : add TODOs ggml-ci
1 parent b7cc774 commit 60c6663

File tree

7 files changed

+162
-106
lines changed

7 files changed

+162
-106
lines changed

src/llama-batch.cpp

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#include "llama-batch.h"
22

3+
#include "llama-impl.h"
4+
#include "llama-cparams.h"
5+
#include "llama-vocab.h"
6+
37
#include <cassert>
48
#include <cstring>
59
#include <algorithm>
@@ -279,9 +283,42 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
279283
);
280284
}
281285

282-
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
283-
batch = in_batch;
286+
llama_batch_allocr::llama_batch_allocr() = default;
287+
288+
bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
289+
clear();
290+
291+
batch = batch_inp;
292+
284293
GGML_ASSERT(batch.n_tokens > 0);
294+
295+
if (!batch.pos) {
296+
if (batch.seq_id) {
297+
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
298+
return false;
299+
}
300+
}
301+
302+
if (batch.token) {
303+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
304+
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
305+
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
306+
return false;
307+
}
308+
}
309+
}
310+
311+
if (batch.seq_id) {
312+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
313+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
314+
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
315+
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
316+
return false;
317+
}
318+
}
319+
}
320+
}
321+
285322
if (!batch.pos) {
286323
assert(p0 >= 0);
287324
pos.resize(batch.n_tokens);
@@ -290,13 +327,15 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
290327
}
291328
batch.pos = pos.data();
292329
}
330+
293331
if (!batch.n_seq_id) {
294332
n_seq_id.resize(batch.n_tokens);
295333
for (int32_t i = 0; i < batch.n_tokens; i++) {
296334
n_seq_id[i] = seq_id_0.size();
297335
}
298336
batch.n_seq_id = n_seq_id.data();
299337
}
338+
300339
if (!batch.seq_id) {
301340
seq_id.resize(batch.n_tokens + 1);
302341
seq_id[batch.n_tokens] = NULL;
@@ -305,12 +344,37 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
305344
}
306345
batch.seq_id = seq_id.data();
307346
}
347+
308348
if (!batch.logits) {
309349
// by default return the output only for the last token
310350
output.resize(batch.n_tokens);
311351
output[output.size() - 1] = true;
312352
batch.logits = output.data();
313353
}
354+
355+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
356+
n_outputs += batch.logits[i] != 0;
357+
}
358+
359+
return true;
360+
}
361+
362+
const llama_batch & llama_batch_allocr::get_batch() const {
363+
return batch;
364+
}
365+
366+
uint32_t llama_batch_allocr::get_n_outputs() const {
367+
return n_outputs;
368+
}
369+
370+
void llama_batch_allocr::clear() {
371+
n_outputs = 0;
372+
373+
batch = {};
374+
pos.clear();
375+
n_seq_id.clear();
376+
seq_id.clear();
377+
output.clear();
314378
}
315379

316380
//

src/llama-batch.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ struct llama_ubatch {
1818
llama_token * token; // [n_tokens]
1919
float * embd; // [n_embd, n_tokens]
2020
llama_pos * pos; // [n_tokens]
21-
int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
22-
llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
21+
int32_t * n_seq_id; // [n_seqs]
22+
llama_seq_id ** seq_id; // [n_seqs]
2323
int8_t * output; // [n_tokens]
2424
};
2525

@@ -78,15 +78,28 @@ struct llama_sbatch {
7878
};
7979

8080
// temporary allocate memory for the input batch if needed
81-
struct llama_batch_allocr {
82-
struct llama_batch batch;
81+
class llama_batch_allocr {
82+
public:
83+
llama_batch_allocr();
84+
85+
// optionally fulfill the batch returned by llama_batch_get_one
86+
bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
87+
88+
const llama_batch & get_batch() const;
89+
90+
uint32_t get_n_outputs() const;
91+
92+
private:
93+
void clear();
94+
95+
llama_batch batch;
96+
97+
uint32_t n_outputs;
8398

8499
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
100+
85101
std::vector<llama_pos> pos;
86102
std::vector<int32_t> n_seq_id;
87103
std::vector<llama_seq_id *> seq_id;
88104
std::vector<int8_t> output;
89-
90-
// optionally fulfill the batch returned by llama_batch_get_one
91-
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
92105
};

0 commit comments

Comments
 (0)