@@ -400,8 +400,11 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
     bool success = true;
 
     for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const auto sinfo_new = find_slot(ubatch);
+        const auto sinfo_new = find_slot(ubatch, cont);
         if (sinfo_new.empty()) {
             success = false;
             break;
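Note on the new flag above: the ternary supports_set_rows ? false : true is a logical negation, so a contiguous slot is requested exactly when ggml_set_rows() support is not available. An equivalent one-liner (illustrative sketch, not part of the commit):

// sketch only - equivalent to the ternary introduced in the hunk above
const bool cont = !supports_set_rows;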
@@ -521,7 +524,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     uint32_t head_cur = this->head;
@@ -595,17 +598,25 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
         }
     }
 
+    uint32_t n_found  = 0;
     uint32_t n_tested = 0;
 
+    const uint32_t n_test = cont ? n_tokens : 1;
+
+    slot_info res;
+
+    res.idxs.resize(n_tokens);
+
     while (true) {
-        if (head_cur + n_tokens > cells.size()) {
+        if (head_cur + n_test > cells.size()) {
             n_tested += cells.size() - head_cur;
             head_cur = 0;
             continue;
         }
 
-        bool found = true;
-        for (uint32_t i = 0; i < n_tokens; i++) {
+        for (uint32_t i = 0; i < n_test; i++) {
+            const auto idx = head_cur;
+
             //const llama_pos    pos    = ubatch.pos[i];
             //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
@@ -615,19 +626,19 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
             //  - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //  - mask SWA, using current max pos for that sequence in the cache
             //    always insert in the cell with minimum pos
-            bool can_use = cells.is_empty(head_cur + i);
+            bool can_use = cells.is_empty(idx);
 
-            if (!can_use && cells.seq_count(head_cur + i) == 1) {
-                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+            if (!can_use && cells.seq_count(idx) == 1) {
+                const llama_pos pos_cell = cells.pos_get(idx);
 
                 // (disabled) causal mask
                 // note: it's better to purge any "future" tokens beforehand
-                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //if (cells.seq_has(idx, seq_id)) {
                 //    can_use = pos_cell >= pos;
                 //}
 
                 if (!can_use) {
-                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+                    const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
                     // SWA mask
                     if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -636,29 +647,35 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
                 }
             }
 
-            if (!can_use) {
-                found = false;
-                head_cur += i + 1;
-                n_tested += i + 1;
+            head_cur++;
+            n_tested++;
+
+            if (can_use) {
+                res.idxs[n_found] = idx;
+
+                n_found++;
+            } else {
                 break;
             }
         }
 
-        if (found) {
+        if (n_found == n_tokens) {
             break;
         }
 
+        if (cont) {
+            n_found = 0;
+        }
+
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
             return { };
         }
     }
 
-    slot_info res;
-
-    res.idxs.resize(n_tokens);
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        res.idxs[i] = head_cur + i;
+    // we didn't find a suitable slot - return empty result
+    if (n_found < n_tokens) {
+        res.clear();
     }
 
     return res;
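The control flow of the rewritten search loop above, reduced to its essentials. This is an illustrative, self-contained sketch rather than code from the commit: the hypothetical usable bitmap stands in for the per-cell emptiness/SWA checks in find_slot(), and the function returns the chosen cell indices, or an empty vector when no slot is found.

// Illustrative sketch (not part of the commit): same two-mode search as find_slot(),
// with the per-cell checks replaced by a precomputed "usable" bitmap.
#include <cstdint>
#include <vector>

std::vector<uint32_t> find_cells(const std::vector<bool> & usable, uint32_t n_tokens, bool cont) {
    const uint32_t n_cells = (uint32_t) usable.size();

    if (n_tokens == 0 || n_tokens > n_cells) {
        return {};
    }

    std::vector<uint32_t> idxs(n_tokens);

    uint32_t head     = 0;
    uint32_t n_found  = 0;
    uint32_t n_tested = 0;

    // cont == true : probe runs of n_tokens consecutive cells, restarting the run on any miss
    // cont == false: probe one cell at a time and keep every usable cell that is found
    const uint32_t n_test = cont ? n_tokens : 1;

    while (true) {
        if (head + n_test > n_cells) {
            n_tested += n_cells - head;
            head = 0;
            continue;
        }

        for (uint32_t i = 0; i < n_test; i++) {
            const uint32_t idx = head;

            head++;
            n_tested++;

            if (usable[idx]) {
                idxs[n_found++] = idx;
            } else {
                break;
            }
        }

        if (n_found == n_tokens) {
            return idxs; // enough cells collected (contiguous only if cont == true)
        }

        if (cont) {
            n_found = 0; // a contiguous slot cannot contain gaps - restart the run
        }

        if (n_tested >= n_cells) {
            return {}; // scanned the whole cache without finding a suitable slot
        }
    }
}

With cont == false the returned indices may be scattered across the cache, which is why the caller in prepare() only drops the contiguity requirement when ggml_set_rows() is supported.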
@@ -1592,7 +1609,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             ubatch.seq_id[i] = &dest_seq_id;
         }
 
-        const auto sinfo = find_slot(ubatch);
+        const auto sinfo = find_slot(ubatch, true);
         if (sinfo.empty()) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;