Skip to content

Commit 1cfb8bb

Browse files
committed
cont : assert that sequence positions are not decreasing
ggml-ci
1 parent 7286558 commit 1cfb8bb

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

src/llama-batch.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ bool llama_batch_allocr::init(
282282
}
283283
}
284284

285-
// disallow disjoint sequence sets:
285+
// disallow partial sequence sub-sets:
286286
//
287287
// invalid: x
288288
// i: 0 1 2 ...
@@ -291,28 +291,46 @@ bool llama_batch_allocr::init(
291291
// seq_id[i][1]: 1 1 2
292292
// seq_id[i][2]: 2
293293
//
294+
// disallow decreasing sequence positions:
295+
//
296+
// invalid: x
297+
// i: 0 1 2 3 4 5 6 ...
298+
// ---------------------------------------
299+
// pos[i]: 4 5 0 1 6 2 3
300+
// seq_id[i][0]: 0 0 1 1 0 1 0
301+
//
294302
{
295303
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
296304
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
297305
cur_seq_set[s].set();
298306
}
299307

308+
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
309+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
310+
cur_seq_pos[s] = -1;
311+
}
312+
300313
for (int32_t i = 0; i < batch.n_tokens; ++i) {
314+
const llama_pos pos = batch.pos[i];
315+
301316
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
302317
const llama_seq_id seq_id = batch.seq_id[i][s];
303318

304319
cur_seq_set[seq_id] &= seq_set[i];
305320

306321
if (cur_seq_set[seq_id].none()) {
307-
LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets\n", __func__, seq_id);
322+
LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
323+
return false;
324+
}
325+
326+
if (pos < cur_seq_pos[seq_id]) {
327+
LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
308328
return false;
309329
}
310330
}
311331
}
312332
}
313333

314-
// TODO: check that positions are increasing
315-
316334
split_reset();
317335

318336
return true;

0 commit comments

Comments
 (0)