@@ -282,7 +282,7 @@ bool llama_batch_allocr::init(
282
282
}
283
283
}
284
284
285
- // disallow disjoint sequence sets:
285
+ // disallow partial sequence sub- sets:
286
286
//
287
287
// invalid: x
288
288
// i: 0 1 2 ...
@@ -291,28 +291,46 @@ bool llama_batch_allocr::init(
291
291
// seq_id[i][1]: 1 1 2
292
292
// seq_id[i][2]: 2
293
293
//
294
+ // disallow decreasing sequence positions:
295
+ //
296
+ // invalid: x
297
+ // i: 0 1 2 3 4 5 6 ...
298
+ // ---------------------------------------
299
+ // pos[i]: 4 5 0 1 6 2 3
300
+ // seq_id[i][0]: 0 0 1 1 0 1 0
301
+ //
294
302
{
295
303
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
296
304
for (int32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
297
305
cur_seq_set[s].set ();
298
306
}
299
307
308
+ llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
309
+ for (int32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
310
+ cur_seq_pos[s] = -1 ;
311
+ }
312
+
300
313
for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
314
+ const llama_pos pos = batch.pos [i];
315
+
301
316
for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
302
317
const llama_seq_id seq_id = batch.seq_id [i][s];
303
318
304
319
cur_seq_set[seq_id] &= seq_set[i];
305
320
306
321
if (cur_seq_set[seq_id].none ()) {
307
- LLAMA_LOG_ERROR (" %s: sequence %d belongs to incompatible sequence sets\n " , __func__, seq_id);
322
+ LLAMA_LOG_ERROR (" %s: sequence %d belongs to incompatible sequence sets (not allowed)\n " , __func__, seq_id);
323
+ return false ;
324
+ }
325
+
326
+ if (pos < cur_seq_pos[seq_id]) {
327
+ LLAMA_LOG_ERROR (" %s: sequence %d positions are decreasing (not allowed)\n " , __func__, seq_id);
308
328
return false ;
309
329
}
310
330
}
311
331
}
312
332
}
313
333
314
- // TODO: check that positions are increasing
315
-
316
334
split_reset ();
317
335
318
336
return true ;
0 commit comments