@@ -152,8 +152,6 @@ void llama_kv_cache_unified::clear() {
     for (int32_t i = 0; i < (int32_t) size; ++i) {
         cells[i].pos = -1;
         cells[i].seq_id.clear();
-        cells[i].src = -1;
-        cells[i].tail = -1;
     }
     head = 0;
     used = 0;
@@ -190,7 +188,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 }
 
                 cells[i].pos = -1;
-                cells[i].src = -1;
 
                 if (new_head == size) {
                     new_head = i;
@@ -245,7 +242,6 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
             }
 
             cells[i].pos = -1;
-            cells[i].src = -1;
             cells[i].seq_id.clear();
 
             if (new_head == size){
@@ -380,7 +376,6 @@ void llama_kv_cache_unified::restore() {
             }
 
             cells[i].pos = -1;
-            cells[i].src = -1;
         }
 
         new_head = std::min(new_head, range.c0);
@@ -847,7 +842,7 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
 
 uint32_t llama_kv_cache_unified::cell_max() const {
     for (uint32_t i = size; i > 0; --i) {
-        const llama_kv_cell & cell = cells[i - 1];
+        const kv_cell & cell = cells[i - 1];
 
         if (cell.pos >= 0 && !cell.is_empty()) {
             return i;
@@ -983,7 +978,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             cells[i0 + nf] = cell1;
 
             // clear the old cell and move the head there
-            cell1 = llama_kv_cell();
+            cell1 = kv_cell();
             head = n_used;
 
             if (!cont) {
@@ -1226,7 +1221,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         clear();
 
         for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_kv_cell & cell = cells[i];
+            kv_cell & cell = cells[i];
 
             llama_pos pos;
             uint32_t n_seq_id;
@@ -1538,7 +1533,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
     if (0 <= seq_id) {
         int32_t & tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
-            const llama_kv_cell & cell = cells[tail_id];
+            const kv_cell & cell = cells[tail_id];
             // partial intersection is invalid
             if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
                 return false;
@@ -1572,23 +1567,22 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
     }
 
     if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
-        llama_kv_cell & tail_src = cells[seq_id_src];
-        llama_kv_cell & tail_dst = cells[seq_id_dst];
+        kv_cell & tail_src = cells[seq_id_src];
+        kv_cell & tail_dst = cells[seq_id_dst];
         if (tail_dst.tail >= 0) {
             // clear destination seq_id if it wasn't empty
-            llama_kv_cell & cell_dst = cells[tail_dst.tail];
+            kv_cell & cell_dst = cells[tail_dst.tail];
 
             cell_dst.seq_id.erase(seq_id_dst);
             tail_dst.tail = -1;
             if (cell_dst.seq_id.empty()) {
                 cell_dst.pos = -1;
-                cell_dst.delta = -1;
                 cell_dst.src = -1;
                 used -= 1;
             }
         }
         if (tail_src.tail >= 0) {
-            llama_kv_cell & cell_src = cells[tail_src.tail];
+            kv_cell & cell_src = cells[tail_src.tail];
 
             cell_src.seq_id.insert(seq_id_dst);
             tail_dst.tail = tail_src.tail;
@@ -1650,7 +1644,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
     if (0 <= seq_id && seq_id < (int64_t) size) {
         const int32_t tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
-            llama_kv_cell & cell = cells[tail_id];
+            kv_cell & cell = cells[tail_id];
             if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                 cell.pos += delta;
             }
@@ -1680,7 +1674,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
     if (0 <= seq_id && seq_id < (int64_t) size) {
         const int32_t tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
-            llama_kv_cell & cell = cells[tail_id];
+            kv_cell & cell = cells[tail_id];
             if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                 cell.pos /= d;
             }
@@ -1731,19 +1725,19 @@ int32_t llama_kv_cache_recurrent::s_copy(int i) const {
 
     //////////////////////////////////////////////
     // TODO: this should not mutate the KV cache !
-    llama_kv_cell & kv_cell = const_cast<llama_kv_cell &>(cells[i]);
+    kv_cell & cell = const_cast<kv_cell &>(cells[i]);
 
     // prevent out-of-bound sources
-    if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= size) {
-        kv_cell.src = cell_id;
+    if (cell.src < 0 || (uint32_t) cell.src >= size) {
+        cell.src = cell_id;
     }
 
-    int32_t res = kv_cell.src;
+    int32_t res = cell.src;
 
     // TODO: do not mutate the KV cache
     // ensure copy only happens once
-    if (kv_cell.src != (int32_t) cell_id) {
-        kv_cell.src = cell_id;
+    if (cell.src != (int32_t) cell_id) {
+        cell.src = cell_id;
     }
 
     return res;
@@ -1754,13 +1748,13 @@ float llama_kv_cache_recurrent::s_mask(int i) const {
 
     //////////////////////////////////////////////
     // TODO: this should not mutate the KV cache !
-    llama_kv_cell & kv_cell = const_cast<llama_kv_cell &>(cells[i]);
+    kv_cell & cell = const_cast<kv_cell &>(cells[i]);
 
-    float res = (float) (kv_cell.src >= 0);
+    float res = (float) (cell.src >= 0);
 
     // only clear once
-    if (kv_cell.src < 0) {
-        kv_cell.src = cell_id;
+    if (cell.src < 0) {
+        cell.src = cell_id;
     }
 
     return res;
@@ -1802,9 +1796,9 @@ bool llama_kv_cache_recurrent::find_slot(
                 return false;
             }
             if (j > 0) {
-                llama_kv_cell & seq = cells[seq_id];
+                kv_cell & seq = cells[seq_id];
                 if (seq.tail >= 0) {
-                    llama_kv_cell & cell = cells[seq.tail];
+                    kv_cell & cell = cells[seq.tail];
                     // clear cells from seq_ids that become shared
                     // (should not normally happen, but let's handle it anyway)
                     cell.seq_id.erase(seq_id);
@@ -1824,7 +1818,7 @@ bool llama_kv_cache_recurrent::find_slot(
         std::vector<int32_t> tails_verif;
         tails_verif.assign(size, -1);
         for (uint32_t i = 0; i < size; ++i) {
-            llama_kv_cell & cell = cells[i];
+            kv_cell & cell = cells[i];
             for (llama_seq_id seq_id : cell.seq_id) {
                 if (tails_verif[seq_id] != -1) {
                     LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -1845,28 +1839,28 @@ bool llama_kv_cache_recurrent::find_slot(
 
     for (uint32_t i = 0; i < size; ++i) {
         if (next_empty_cell >= size) { next_empty_cell -= size; }
-        llama_kv_cell & cell = cells[next_empty_cell];
+        kv_cell & cell = cells[next_empty_cell];
         if (cell.is_empty()) { break; }
        next_empty_cell += 1;
     }
 
     // find usable cell range
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_seq_id seq_id = ubatch.seq_id[s][0];
-        llama_kv_cell & seq_meta = cells[seq_id];
+        kv_cell & seq_meta = cells[seq_id];
         bool has_cell = false;
         if (seq_meta.tail >= 0) {
-            llama_kv_cell & cell = cells[seq_meta.tail];
+            kv_cell & cell = cells[seq_meta.tail];
             GGML_ASSERT(cell.has_seq_id(seq_id));
             // does this seq_id "own" the cell?
             if (cell.seq_id.size() == 1) { has_cell = true; }
         }
         if (!has_cell) {
-            llama_kv_cell & empty_cell = cells[next_empty_cell];
+            kv_cell & empty_cell = cells[next_empty_cell];
             GGML_ASSERT(empty_cell.is_empty());
             // copy old tail into the empty cell
             if (seq_meta.tail >= 0) {
-                llama_kv_cell & orig_cell = cells[seq_meta.tail];
+                kv_cell & orig_cell = cells[seq_meta.tail];
                 empty_cell.pos = orig_cell.pos;
                 empty_cell.src = orig_cell.src;
                 orig_cell.seq_id.erase(seq_id);
@@ -1878,7 +1872,7 @@ bool llama_kv_cache_recurrent::find_slot(
                 next_empty_cell += 1;
                 for (uint32_t i = 0; i < size; ++i) {
                     if (next_empty_cell >= size) { next_empty_cell -= size; }
-                    llama_kv_cell & cell = cells[next_empty_cell];
+                    kv_cell & cell = cells[next_empty_cell];
                     if (cell.is_empty()) { break; }
                     next_empty_cell += 1;
                 }
@@ -1893,8 +1887,8 @@ bool llama_kv_cache_recurrent::find_slot(
         int32_t dst_id = s + min;
         int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
         if (dst_id != src_id) {
-            llama_kv_cell & dst_cell = cells[dst_id];
-            llama_kv_cell & src_cell = cells[src_id];
+            kv_cell & dst_cell = cells[dst_id];
+            kv_cell & src_cell = cells[src_id];
 
             std::swap(dst_cell.pos, src_cell.pos);
             std::swap(dst_cell.src, src_cell.src);
@@ -1914,7 +1908,7 @@ bool llama_kv_cache_recurrent::find_slot(
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
         int32_t cell_id = s + min;
-        llama_kv_cell & cell = cells[cell_id];
+        kv_cell & cell = cells[cell_id];
 
         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
             // What should happen when the pos backtracks or skips a value?
@@ -1935,7 +1929,7 @@ bool llama_kv_cache_recurrent::find_slot(
     head = min;
     n = max - min + 1;
     used = std::count_if(cells.begin(), cells.end(),
-        [](const llama_kv_cell & cell){ return !cell.is_empty(); });
+        [](const kv_cell & cell){ return !cell.is_empty(); });
 
     // sanity check
     return n >= n_seqs;
@@ -1958,7 +1952,7 @@ llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32
 
 uint32_t llama_kv_cache_recurrent::cell_max() const {
     for (uint32_t i = size; i > 0; --i) {
-        const llama_kv_cell & cell = cells[i - 1];
+        const kv_cell & cell = cells[i - 1];
 
         if (cell.pos >= 0 && !cell.is_empty()) {
             return i;
@@ -2200,7 +2194,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
         clear();
 
         for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_kv_cell & cell = cells[i];
+            kv_cell & cell = cells[i];
 
             llama_pos pos;
             uint32_t n_seq_id;
@@ -2412,7 +2406,7 @@ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache
         view->cells_sequences = (llama_seq_id *)p;
     }
 
-    const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
+    const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
     llama_kv_cache_view_cell * c_curr = view->cells;
    llama_seq_id * cs_curr = view->cells_sequences;
     int32_t used_cells = 0;
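
The hunks above only touch the .cpp side of the rename; the cell definitions themselves live in the header, outside this diff. Below is a minimal sketch of the split these hunks imply, not the actual header code: the outer `_sketch` names and the concrete field types are assumptions, inferred purely from how the fields are used here. The unified cache's nested `kv_cell` keeps only `pos` and the `seq_id` set (its `src`/`tail`/`delta` uses are the lines being deleted), while the recurrent cache's `kv_cell` retains the `src` and `tail` bookkeeping that `s_copy`, `s_mask`, and `find_slot` still read and write.

// Sketch only: real definitions live in llama-kv-cache.h, which is not part of this diff.
#include <cstdint>
#include <set>

using llama_pos    = int32_t; // assumption: plain integer position type
using llama_seq_id = int32_t; // assumption: plain integer sequence id

struct llama_kv_cache_unified_sketch {
    // per-cell state of the unified cache: no recurrent-only bookkeeping
    struct kv_cell {
        llama_pos pos = -1;

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }
    };
};

struct llama_kv_cache_recurrent_sketch {
    // per-cell state of the recurrent cache: keeps src/tail for state copies
    struct kv_cell {
        llama_pos pos  = -1;
        int32_t   src  = -1; // source cell to copy the recurrent state from
        int32_t   tail = -1; // tail cell of the sequence indexed by this slot

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }
    };
};

Nesting the cell type per cache is what lets the last hunk refer to it explicitly as `llama_kv_cache_unified::kv_cell` in `llama_kv_cache_view_update`.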