@@ -334,13 +334,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads = prepare(ubatches);
-        if (heads.empty()) {
+        auto sinfos = prepare(ubatches);
+        if (sinfos.empty()) {
             break;
         }
 
         return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(heads), std::move(ubatches));
+                this, std::move(sinfos), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -383,8 +383,8 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::slot_info_vec_t res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
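
The patch replaces the single head offset per ubatch with a slot_info describing the exact cells the ubatch is assigned to. The struct itself is defined in the header and is not part of this diff; the following is a minimal sketch inferred only from the members used below (idxs, head(), empty(), is_cont()), so the real definition may differ in details.

    #include <cstdint>
    #include <vector>

    // Hypothetical sketch of slot_info, reconstructed from its uses in this diff.
    struct slot_info {
        std::vector<uint32_t> idxs; // KV cells assigned to the ubatch tokens, one per token

        // first assigned cell (the old "head_cur" for contiguous slots)
        uint32_t head() const { return idxs.at(0); }

        bool empty() const { return idxs.empty(); }

        // true when the assigned cells form one contiguous range
        bool is_cont() const {
            for (size_t i = 1; i < idxs.size(); ++i) {
                if (idxs[i] != idxs[i - 1] + 1) {
                    return false;
                }
            }
            return true;
        }
    };

    // prepare() now returns one slot_info per ubatch
    using slot_info_vec_t = std::vector<slot_info>;
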
@@ -400,20 +400,25 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
 
     for (const auto & ubatch : ubatches) {
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const int32_t head_new = find_slot(ubatch);
-        if (head_new < 0) {
+        const auto sinfo_new = find_slot(ubatch);
+        if (sinfo_new.empty()) {
             success = false;
             break;
         }
 
         // remeber the position that we found
-        res.push_back(head_new);
+        res.push_back(sinfo_new);
+
+        // TODO: temporary
+        if (supports_set_rows) {
+            GGML_ASSERT(sinfo_new.is_cont());
+        }
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+        states.push_back({head, sinfo_new.head(), cells.cp(sinfo_new.head(), ubatch.n_tokens)});
 
         // now emplace the ubatch
-        apply_ubatch(head_new, ubatch);
+        apply_ubatch(sinfo_new, ubatch);
     }
 
     // iterate backwards and restore the cells to their original state
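
The new GGML_ASSERT guards the transitional state: while supports_set_rows is enabled, prepare() still expects each slot to be one contiguous run of cells. A tiny standalone check of that invariant, using the sketched slot_info type above (purely illustrative, not part of the patch):

    #include <cassert>

    int main() {
        slot_info cont;
        cont.idxs = {10, 11, 12};       // one contiguous run -> is_cont() holds

        slot_info scattered;
        scattered.idxs = {10, 12, 13};  // gap between cells -> is_cont() is false

        assert(cont.is_cont());
        assert(!scattered.is_cont());

        return 0;
    }
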
@@ -520,7 +525,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     uint32_t head_cur = this->head;
@@ -533,7 +538,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return -1;
+        return { };
     }
 
     if (debug > 0) {
@@ -649,37 +654,48 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return -1;
+            return { };
         }
     }
 
-    return head_cur;
+    slot_info res;
+
+    res.idxs.resize(n_tokens);
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        res.idxs[i] = head_cur + i;
+    }
+
+    return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
     for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
+    assert(ubatch.n_tokens == sinfo.idxs.size());
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+        const auto idx = sinfo.idxs[i];
+
+        if (!cells.is_empty(idx)) {
+            assert(cells.seq_count(idx) == 1);
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos   pos     = cells.pos_get(head_cur + i);
+            const llama_seq_id seq_id = cells.seq_get(idx);
+            const llama_pos   pos     = cells.pos_get(idx);
 
             seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-            cells.rm(head_cur + i);
+            cells.rm(idx);
         }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+        cells.pos_set(idx, ubatch.pos[i]);
 
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
+            cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
     }
 
@@ -700,7 +716,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     }
 
     // move the head at the end of the slot
-    head = head_cur + ubatch.n_tokens;
+    head = sinfo.idxs.back() + 1;
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
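
For the contiguous slots that find_slot() currently produces, the new head update is equivalent to the old one: idxs.back() is head_cur + n_tokens - 1, so idxs.back() + 1 equals head_cur + n_tokens. A standalone check of that equivalence, again assuming the sketched slot_info above (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t head_cur = 128;
        const uint32_t n_tokens = 7;

        // build the slot the same way find_slot() does after this patch
        slot_info sinfo;
        sinfo.idxs.resize(n_tokens);
        for (uint32_t i = 0; i < n_tokens; ++i) {
            sinfo.idxs[i] = head_cur + i;
        }

        // old rule vs new rule for advancing the cache head
        const uint32_t head_old_rule = head_cur + n_tokens;
        const uint32_t head_new_rule = sinfo.idxs.back() + 1;

        assert(head_old_rule == head_new_rule);

        return 0;
    }
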
@@ -753,7 +769,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -772,12 +788,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*n_embd_k_gqa,
-            ggml_row_size(k->type, n_embd_k_gqa)*head_cur);
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -814,19 +830,19 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     if (!v_trans) {
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*n_embd_v_gqa,
-                ggml_row_size(v->type, n_embd_v_gqa)*head_cur);
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
     } else {
         v_cur = ggml_transpose(ctx, v_cur);
 
         v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
-                (v->ne[1])*ggml_element_size(v),
-                (head_cur)*ggml_element_size(v));
+                (v->ne[1]    )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
@@ -837,7 +853,7 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     int64_t * data = (int64_t *) dst->data;
 
     for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = head_cur + i;
+        data[i] = sinfo.idxs[i];
     }
 }
 
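This is where the per-token index buffer for the (optional) scatter path is built: instead of assuming that the destination rows start at head_cur, each token gets the cell recorded in sinfo.idxs. A toy reimplementation of just this fill, assuming the sketched slot_info above (the real function writes into a ggml tensor's data pointer):

    #include <cstdint>
    #include <vector>

    // Hypothetical helper, not part of the patch: collect the destination
    // cell index for every token of the ubatch.
    std::vector<int64_t> build_kv_idxs(const slot_info & sinfo) {
        std::vector<int64_t> data(sinfo.idxs.size());
        for (size_t i = 0; i < data.size(); ++i) {
            data[i] = sinfo.idxs[i]; // row i of the new K/V data goes to cell idxs[i]
        }
        return data;
    }

For a contiguous slot this reproduces the old head_cur + i values exactly; the point of the change is that nothing here requires the cells to be contiguous anymore.
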
@@ -1580,13 +1596,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
            ubatch.seq_id[i] = &dest_seq_id;
        }
 
-        const auto head_cur = find_slot(ubatch);
-        if (head_cur < 0) {
+        const auto sinfo = find_slot(ubatch);
+        if (sinfo.empty()) {
            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
            return false;
        }
 
-        apply_ubatch(head_cur, ubatch);
+        apply_ubatch(sinfo, ubatch);
+
+        const auto head_cur = sinfo.head();
 
        // keep the head at the old position because we will read the KV data into it in state_read_data()
        head = head_cur;
@@ -1772,7 +1790,10 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
-    head = 0;
+
+    sinfos.resize(1);
+    sinfos[0].idxs.resize(1);
+    sinfos[0].idxs[0] = 0;
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1787,16 +1808,16 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 
 bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    if (++i_next >= ubatches.size()) {
+    if (++i_cur >= ubatches.size()) {
         return false;
     }
 
@@ -1813,10 +1834,9 @@ bool llama_kv_cache_unified_context::apply() {
         return true;
     }
 
-    kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
 
     n_kv = kv->get_n_kv();
-    head = heads[i_next];
 
     return true;
 }
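
With heads and i_next gone, the context only tracks the current slot via sinfos[i_cur]. A hedged sketch of how a caller steps through the scheduled ubatches, using only the next()/apply()/get_ubatch() methods visible in this diff (the real decode loop in llama.cpp does considerably more work per step):

    // Illustrative driver loop, not taken from the patch.
    void process_all_ubatches(llama_kv_cache_unified_context & mctx) {
        do {
            if (!mctx.apply()) {
                // writes ubatches[i_cur] into the cells described by sinfos[i_cur]
                break;
            }

            const llama_ubatch & ub = mctx.get_ubatch(); // ubatch that was just placed
            (void) ub; // ... build and evaluate the compute graph for this ubatch ...
        } while (mctx.next()); // advance i_cur; returns false after the last ubatch
    }
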
@@ -1828,7 +1848,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
 const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return ubatches[i_next];
+    return ubatches[i_cur];
 }
 
 uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@@ -1844,19 +1864,19 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
 void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, head);
+    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {