#include "llama-batch.h"

+#include "llama-impl.h"
+#include "llama-cparams.h"
+#include "llama-vocab.h"
+
#include <cassert>
#include <cstring>
#include <algorithm>
@@ -279,9 +283,42 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
    );
}

-llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
-    batch = in_batch;
+llama_batch_allocr::llama_batch_allocr() = default;
+
+bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+    clear();
+
+    batch = batch_inp;
+
    GGML_ASSERT(batch.n_tokens > 0);
+
+    if (!batch.pos) {
+        if (batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return false;
+        }
+    }
+
+    if (batch.token) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return false;
+            }
+        }
+    }
+
+    if (batch.seq_id) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                    return false;
+                }
+            }
+        }
+    }
+
    if (!batch.pos) {
        assert(p0 >= 0);
        pos.resize(batch.n_tokens);
@@ -290,13 +327,15 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
        }
        batch.pos = pos.data();
    }
+
    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }
+
    if (!batch.seq_id) {
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = NULL;
@@ -305,12 +344,37 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
        }
        batch.seq_id = seq_id.data();
    }
+
    if (!batch.logits) {
        // by default return the output only for the last token
        output.resize(batch.n_tokens);
        output[output.size() - 1] = true;
        batch.logits = output.data();
    }
+
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        n_outputs += batch.logits[i] != 0;
+    }
+
+    return true;
+}
+
+const llama_batch & llama_batch_allocr::get_batch() const {
+    return batch;
+}
+
+uint32_t llama_batch_allocr::get_n_outputs() const {
+    return n_outputs;
+}
+
+void llama_batch_allocr::clear() {
+    n_outputs = 0;
+
+    batch = {};
+    pos.clear();
+    n_seq_id.clear();
+    seq_id.clear();
+    output.clear();
}

//
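This change replaces the allocator's one-shot constructor with a default constructor plus a fallible `init()`, along with `get_batch()`, `get_n_outputs()` and `clear()`. As a rough caller-side sketch of the intended pattern (not part of the commit): the wrapper function below, its name, the `p0` value, and how the `llama_vocab` reference is obtained are all assumptions; only the `llama_batch_allocr` members come from the diff.

```cpp
#include "llama-batch.h"
#include "llama-impl.h"
#include "llama-vocab.h"

// Hypothetical helper, not part of the commit: validates and prepares an
// incoming batch, returning the number of requested outputs or -1 on error.
static int32_t prepare_batch(llama_batch_allocr & balloc,
                             const llama_batch  & batch_inp,
                             const llama_vocab  & vocab) {
    // init() now checks token ids and seq_ids and returns false instead of
    // aborting, so a malformed batch can be rejected gracefully
    if (!balloc.init(batch_inp, vocab, /*p0 =*/ 0)) {
        return -1;
    }

    // the allocator owns any defaulted pos/n_seq_id/seq_id/logits arrays,
    // so this view is only valid while balloc is alive and not cleared
    const llama_batch & batch = balloc.get_batch();

    LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_outputs = %u\n",
            __func__, batch.n_tokens, balloc.get_n_outputs());

    return (int32_t) balloc.get_n_outputs();
}
```

Having `init()` return a bool, rather than asserting inside a constructor, presumably lets the decode entry points report an invalid batch back to the caller instead of terminating the process.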