@@ -406,21 +406,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
406
406
407
407
bool success = true ;
408
408
409
- // TODO: here we have to verify that all ubatches can fit in the cells
410
- // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
411
- // during the compute of each ubatch. to reproduce, uncomment the following loop and run:
412
- //
413
- // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
414
- //
415
- // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
416
- //
417
- GGML_UNUSED (ubatches);
418
- // for (const auto & ubatch : ubatches) {
419
- // if (!find_slot(ubatch)) {
420
- // success = false;
421
- // break;
422
- // }
423
- // }
409
+ for (const auto & ubatch : ubatches) {
410
+ if (!find_slot (ubatch)) {
411
+ success = false ;
412
+ break ;
413
+ }
414
+ }
424
415
425
416
// restore the original state
426
417
cells = std::move (org_cells);
@@ -431,14 +422,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
431
422
}
432
423
433
424
bool llama_kv_cache_recurrent::find_slot (const llama_ubatch & ubatch) {
434
- const uint32_t n_tokens = ubatch.n_tokens ;
435
- const uint32_t n_seqs = ubatch.n_seqs ;
425
+ const uint32_t n_seqs = ubatch.n_seqs ;
436
426
437
427
const uint32_t n_seq_tokens = ubatch.n_seq_tokens ;
438
428
439
429
// if we have enough unused cells before the current head ->
440
430
// better to start searching from the beginning of the cache, hoping to fill it
441
- if (head > used + 2 *n_tokens ) {
431
+ if (head > used + 2 *n_seqs ) {
442
432
head = 0 ;
443
433
}
444
434
@@ -534,16 +524,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
534
524
empty_cell.src = orig_cell.src ;
535
525
orig_cell.seq_id .erase (seq_id);
536
526
empty_cell.seq_id .insert (seq_id); // will be overwritten
527
+ GGML_ASSERT (!orig_cell.is_empty ()); // has at least one remaining seq_id
537
528
}
538
529
seq_meta.tail = next_empty_cell;
539
530
// find next empty cell
540
531
if (s + 1 < n_seqs) {
541
- next_empty_cell += 1 ;
542
532
for (uint32_t i = 0 ; i < size; ++i) {
533
+ next_empty_cell += 1 ;
543
534
if (next_empty_cell >= size) { next_empty_cell -= size; }
544
535
kv_cell & cell = cells[next_empty_cell];
545
536
if (cell.is_empty ()) { break ; }
546
- next_empty_cell += 1 ;
547
537
}
548
538
}
549
539
}
@@ -553,8 +543,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
553
543
554
544
// gather and re-order
555
545
for (uint32_t s = 0 ; s < n_seqs; ++s) {
556
- int32_t dst_id = s + min;
557
- int32_t src_id = cells[ubatch.seq_id [s][0 ]].tail ;
546
+ const int32_t dst_id = s + min;
547
+ const int32_t src_id = cells[ubatch.seq_id [s][0 ]].tail ;
558
548
if (dst_id != src_id) {
559
549
kv_cell & dst_cell = cells[dst_id];
560
550
kv_cell & src_cell = cells[src_id];
@@ -563,20 +553,22 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
563
553
std::swap (dst_cell.src , src_cell.src );
564
554
std::swap (dst_cell.seq_id , src_cell.seq_id );
565
555
566
- // swap tails (assuming they NEVER overlap)
567
- for (const llama_seq_id seq_id : src_cell.seq_id ) {
568
- cells[seq_id].tail = src_id;
569
- }
570
- for (const llama_seq_id seq_id : dst_cell.seq_id ) {
571
- cells[seq_id].tail = dst_id;
556
+ // swap tails
557
+ for (uint32_t i = 0 ; i < size; ++i) {
558
+ int32_t & tail = cells[i].tail ;
559
+ if (tail == src_id) {
560
+ tail = dst_id;
561
+ } else if (tail == dst_id) {
562
+ tail = src_id;
563
+ }
572
564
}
573
565
}
574
566
}
575
567
576
568
// update the pos of the used seqs
577
569
for (uint32_t s = 0 ; s < n_seqs; ++s) {
578
570
const llama_pos last_pos = ubatch.pos [n_seq_tokens * s + n_seq_tokens - 1 ];
579
- int32_t cell_id = s + min;
571
+ const int32_t cell_id = s + min;
580
572
kv_cell & cell = cells[cell_id];
581
573
582
574
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@@ -594,6 +586,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
594
586
}
595
587
}
596
588
589
+ // Find first cell without src refs, to use as the zero-ed state
590
+ {
591
+ // TODO: bake-in src refcounts in the cell metadata
592
+ std::vector<int32_t > refcounts (size, 0 );
593
+ for (size_t i = 0 ; i < size; ++i) {
594
+ const int32_t src = cells[i].src ;
595
+ if (src >= 0 ) {
596
+ refcounts[src] += 1 ;
597
+ }
598
+ }
599
+
600
+ rs_z = -1 ;
601
+ for (int i = min; i <= max; ++i) {
602
+ if (refcounts[i] == 0 ) {
603
+ rs_z = i;
604
+ break ;
605
+ }
606
+ }
607
+
608
+ for (int i = min; i <= max; ++i) {
609
+ if (cells[i].src < 0 ) {
610
+ GGML_ASSERT (rs_z >= 0 );
611
+ cells[i].src0 = rs_z;
612
+ } else {
613
+ // Stage the source ids for all used cells to allow correct seq_* behavior
614
+ // and still make these values available when setting the inputs
615
+ cells[i].src0 = cells[i].src ;
616
+ }
617
+ cells[i].src = i; // avoid moving or clearing twice
618
+ }
619
+ }
620
+
597
621
// allow getting the range of used cells, from head to head + n
598
622
head = min;
599
623
n = max - min + 1 ;
@@ -605,47 +629,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
605
629
}
606
630
607
631
bool llama_kv_cache_recurrent::get_can_shift () const {
608
- return false ;
609
- }
610
-
611
- int32_t llama_kv_cache_recurrent::s_copy (int i) const {
612
- const uint32_t cell_id = i + head;
613
-
614
- // ////////////////////////////////////////////
615
- // TODO: this should not mutate the KV cache !
616
- kv_cell & cell = const_cast <kv_cell &>(cells[cell_id]);
617
-
618
- // prevent out-of-bound sources
619
- if (cell.src < 0 || (uint32_t ) cell.src >= size) {
620
- cell.src = cell_id;
621
- }
622
-
623
- int32_t res = cell.src ;
624
-
625
- // TODO: do not mutate the KV cache
626
- // ensure copy only happens once
627
- if (cell.src != (int32_t ) cell_id) {
628
- cell.src = cell_id;
629
- }
630
-
631
- return res;
632
- }
633
-
634
- float llama_kv_cache_recurrent::s_mask (int i) const {
635
- const uint32_t cell_id = i + head;
636
-
637
- // ////////////////////////////////////////////
638
- // TODO: this should not mutate the KV cache !
639
- kv_cell & cell = const_cast <kv_cell &>(cells[cell_id]);
640
-
641
- float res = (float ) (cell.src >= 0 );
642
-
643
- // only clear once
644
- if (cell.src < 0 ) {
645
- cell.src = cell_id;
646
- }
647
-
648
- return res;
632
+ // shifting the pos is trivial for recurrent models
633
+ return true ;
649
634
}
650
635
651
636
size_t llama_kv_cache_recurrent::total_size () const {
@@ -1111,6 +1096,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
1111
1096
return is_full ? 0 : kv->head ;
1112
1097
}
1113
1098
1099
+ int32_t llama_kv_cache_recurrent_state::get_rs_z () const {
1100
+ return is_full ? 0 : kv->rs_z ;
1101
+ }
1102
+
1114
1103
uint32_t llama_kv_cache_recurrent_state::get_size () const {
1115
1104
return kv->size ;
1116
1105
}
@@ -1124,9 +1113,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
1124
1113
}
1125
1114
1126
1115
int32_t llama_kv_cache_recurrent_state::s_copy (int i) const {
1127
- return kv->s_copy (i);
1128
- }
1129
-
1130
- float llama_kv_cache_recurrent_state::s_mask (int i) const {
1131
- return kv->s_mask (i);
1116
+ return kv->cells [i + kv->head ].src0 ;
1132
1117
}
0 commit comments