Commit 6b85160 (parent: 6c01514)

kv-cache : use separate KV cell structs for unified/recurrent

ggml-ci

File tree: 2 files changed (+77, -67 lines)

src/llama-kv-cache.cpp

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,6 @@ void llama_kv_cache_unified::clear() {
152152
for (int32_t i = 0; i < (int32_t) size; ++i) {
153153
cells[i].pos = -1;
154154
cells[i].seq_id.clear();
155-
cells[i].src = -1;
156-
cells[i].tail = -1;
157155
}
158156
head = 0;
159157
used = 0;
@@ -190,7 +188,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
190188
}
191189

192190
cells[i].pos = -1;
193-
cells[i].src = -1;
194191

195192
if (new_head == size) {
196193
new_head = i;
@@ -245,7 +242,6 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
245242
}
246243

247244
cells[i].pos = -1;
248-
cells[i].src = -1;
249245
cells[i].seq_id.clear();
250246

251247
if (new_head == size){
@@ -380,7 +376,6 @@ void llama_kv_cache_unified::restore() {
380376
}
381377

382378
cells[i].pos = -1;
383-
cells[i].src = -1;
384379
}
385380

386381
new_head = std::min(new_head, range.c0);
@@ -847,7 +842,7 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
847842

848843
uint32_t llama_kv_cache_unified::cell_max() const {
849844
for (uint32_t i = size; i > 0; --i) {
850-
const llama_kv_cell & cell = cells[i - 1];
845+
const kv_cell & cell = cells[i - 1];
851846

852847
if (cell.pos >= 0 && !cell.is_empty()) {
853848
return i;
@@ -983,7 +978,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
983978
cells[i0 + nf] = cell1;
984979

985980
// clear the old cell and move the head there
986-
cell1 = llama_kv_cell();
981+
cell1 = kv_cell();
987982
head = n_used;
988983

989984
if (!cont) {
@@ -1226,7 +1221,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
12261221
clear();
12271222

12281223
for (uint32_t i = 0; i < cell_count; ++i) {
1229-
llama_kv_cell & cell = cells[i];
1224+
kv_cell & cell = cells[i];
12301225

12311226
llama_pos pos;
12321227
uint32_t n_seq_id;
@@ -1538,7 +1533,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
15381533
if (0 <= seq_id) {
15391534
int32_t & tail_id = cells[seq_id].tail;
15401535
if (tail_id >= 0) {
1541-
const llama_kv_cell & cell = cells[tail_id];
1536+
const kv_cell & cell = cells[tail_id];
15421537
// partial intersection is invalid
15431538
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
15441539
return false;
@@ -1572,23 +1567,22 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
15721567
}
15731568

15741569
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
1575-
llama_kv_cell & tail_src = cells[seq_id_src];
1576-
llama_kv_cell & tail_dst = cells[seq_id_dst];
1570+
kv_cell & tail_src = cells[seq_id_src];
1571+
kv_cell & tail_dst = cells[seq_id_dst];
15771572
if (tail_dst.tail >= 0) {
15781573
// clear destination seq_id if it wasn't empty
1579-
llama_kv_cell & cell_dst = cells[tail_dst.tail];
1574+
kv_cell & cell_dst = cells[tail_dst.tail];
15801575

15811576
cell_dst.seq_id.erase(seq_id_dst);
15821577
tail_dst.tail = -1;
15831578
if (cell_dst.seq_id.empty()) {
15841579
cell_dst.pos = -1;
1585-
cell_dst.delta = -1;
15861580
cell_dst.src = -1;
15871581
used -= 1;
15881582
}
15891583
}
15901584
if (tail_src.tail >= 0) {
1591-
llama_kv_cell & cell_src = cells[tail_src.tail];
1585+
kv_cell & cell_src = cells[tail_src.tail];
15921586

15931587
cell_src.seq_id.insert(seq_id_dst);
15941588
tail_dst.tail = tail_src.tail;
@@ -1650,7 +1644,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
16501644
if (0 <= seq_id && seq_id < (int64_t) size) {
16511645
const int32_t tail_id = cells[seq_id].tail;
16521646
if (tail_id >= 0) {
1653-
llama_kv_cell & cell = cells[tail_id];
1647+
kv_cell & cell = cells[tail_id];
16541648
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
16551649
cell.pos += delta;
16561650
}
@@ -1680,7 +1674,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
16801674
if (0 <= seq_id && seq_id < (int64_t) size) {
16811675
const int32_t tail_id = cells[seq_id].tail;
16821676
if (tail_id >= 0) {
1683-
llama_kv_cell & cell = cells[tail_id];
1677+
kv_cell & cell = cells[tail_id];
16841678
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
16851679
cell.pos /= d;
16861680
}
@@ -1731,19 +1725,19 @@ int32_t llama_kv_cache_recurrent::s_copy(int i) const {
17311725

17321726
//////////////////////////////////////////////
17331727
// TODO: this should not mutate the KV cache !
1734-
llama_kv_cell & kv_cell = const_cast<llama_kv_cell &>(cells[i]);
1728+
kv_cell & cell = const_cast<kv_cell &>(cells[i]);
17351729

17361730
// prevent out-of-bound sources
1737-
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= size) {
1738-
kv_cell.src = cell_id;
1731+
if (cell.src < 0 || (uint32_t) cell.src >= size) {
1732+
cell.src = cell_id;
17391733
}
17401734

1741-
int32_t res = kv_cell.src;
1735+
int32_t res = cell.src;
17421736

17431737
// TODO: do not mutate the KV cache
17441738
// ensure copy only happens once
1745-
if (kv_cell.src != (int32_t) cell_id) {
1746-
kv_cell.src = cell_id;
1739+
if (cell.src != (int32_t) cell_id) {
1740+
cell.src = cell_id;
17471741
}
17481742

17491743
return res;
@@ -1754,13 +1748,13 @@ float llama_kv_cache_recurrent::s_mask(int i) const {
17541748

17551749
//////////////////////////////////////////////
17561750
// TODO: this should not mutate the KV cache !
1757-
llama_kv_cell & kv_cell = const_cast<llama_kv_cell &>(cells[i]);
1751+
kv_cell & cell = const_cast<kv_cell &>(cells[i]);
17581752

1759-
float res = (float) (kv_cell.src >= 0);
1753+
float res = (float) (cell.src >= 0);
17601754

17611755
// only clear once
1762-
if (kv_cell.src < 0) {
1763-
kv_cell.src = cell_id;
1756+
if (cell.src < 0) {
1757+
cell.src = cell_id;
17641758
}
17651759

17661760
return res;
@@ -1802,9 +1796,9 @@ bool llama_kv_cache_recurrent::find_slot(
18021796
return false;
18031797
}
18041798
if (j > 0) {
1805-
llama_kv_cell & seq = cells[seq_id];
1799+
kv_cell & seq = cells[seq_id];
18061800
if (seq.tail >= 0) {
1807-
llama_kv_cell & cell = cells[seq.tail];
1801+
kv_cell & cell = cells[seq.tail];
18081802
// clear cells from seq_ids that become shared
18091803
// (should not normally happen, but let's handle it anyway)
18101804
cell.seq_id.erase(seq_id);
@@ -1824,7 +1818,7 @@ bool llama_kv_cache_recurrent::find_slot(
18241818
std::vector<int32_t> tails_verif;
18251819
tails_verif.assign(size, -1);
18261820
for (uint32_t i = 0; i < size; ++i) {
1827-
llama_kv_cell & cell = cells[i];
1821+
kv_cell & cell = cells[i];
18281822
for (llama_seq_id seq_id : cell.seq_id) {
18291823
if (tails_verif[seq_id] != -1) {
18301824
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -1845,28 +1839,28 @@ bool llama_kv_cache_recurrent::find_slot(
18451839

18461840
for (uint32_t i = 0; i < size; ++i) {
18471841
if (next_empty_cell >= size) { next_empty_cell -= size; }
1848-
llama_kv_cell & cell = cells[next_empty_cell];
1842+
kv_cell & cell = cells[next_empty_cell];
18491843
if (cell.is_empty()) { break; }
18501844
next_empty_cell += 1;
18511845
}
18521846

18531847
// find usable cell range
18541848
for (uint32_t s = 0; s < n_seqs; ++s) {
18551849
const llama_seq_id seq_id = ubatch.seq_id[s][0];
1856-
llama_kv_cell & seq_meta = cells[seq_id];
1850+
kv_cell & seq_meta = cells[seq_id];
18571851
bool has_cell = false;
18581852
if (seq_meta.tail >= 0) {
1859-
llama_kv_cell & cell = cells[seq_meta.tail];
1853+
kv_cell & cell = cells[seq_meta.tail];
18601854
GGML_ASSERT(cell.has_seq_id(seq_id));
18611855
// does this seq_id "own" the cell?
18621856
if (cell.seq_id.size() == 1) { has_cell = true; }
18631857
}
18641858
if (!has_cell) {
1865-
llama_kv_cell & empty_cell = cells[next_empty_cell];
1859+
kv_cell & empty_cell = cells[next_empty_cell];
18661860
GGML_ASSERT(empty_cell.is_empty());
18671861
// copy old tail into the empty cell
18681862
if (seq_meta.tail >= 0) {
1869-
llama_kv_cell & orig_cell = cells[seq_meta.tail];
1863+
kv_cell & orig_cell = cells[seq_meta.tail];
18701864
empty_cell.pos = orig_cell.pos;
18711865
empty_cell.src = orig_cell.src;
18721866
orig_cell.seq_id.erase(seq_id);
@@ -1878,7 +1872,7 @@ bool llama_kv_cache_recurrent::find_slot(
18781872
next_empty_cell += 1;
18791873
for (uint32_t i = 0; i < size; ++i) {
18801874
if (next_empty_cell >= size) { next_empty_cell -= size; }
1881-
llama_kv_cell & cell = cells[next_empty_cell];
1875+
kv_cell & cell = cells[next_empty_cell];
18821876
if (cell.is_empty()) { break; }
18831877
next_empty_cell += 1;
18841878
}
@@ -1893,8 +1887,8 @@ bool llama_kv_cache_recurrent::find_slot(
18931887
int32_t dst_id = s + min;
18941888
int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
18951889
if (dst_id != src_id) {
1896-
llama_kv_cell & dst_cell = cells[dst_id];
1897-
llama_kv_cell & src_cell = cells[src_id];
1890+
kv_cell & dst_cell = cells[dst_id];
1891+
kv_cell & src_cell = cells[src_id];
18981892

18991893
std::swap(dst_cell.pos, src_cell.pos);
19001894
std::swap(dst_cell.src, src_cell.src);
@@ -1914,7 +1908,7 @@ bool llama_kv_cache_recurrent::find_slot(
19141908
for (uint32_t s = 0; s < n_seqs; ++s) {
19151909
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
19161910
int32_t cell_id = s + min;
1917-
llama_kv_cell & cell = cells[cell_id];
1911+
kv_cell & cell = cells[cell_id];
19181912

19191913
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
19201914
// What should happen when the pos backtracks or skips a value?
@@ -1935,7 +1929,7 @@ bool llama_kv_cache_recurrent::find_slot(
19351929
head = min;
19361930
n = max - min + 1;
19371931
used = std::count_if(cells.begin(), cells.end(),
1938-
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
1932+
[](const kv_cell & cell){ return !cell.is_empty(); });
19391933

19401934
// sanity check
19411935
return n >= n_seqs;
@@ -1958,7 +1952,7 @@ llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32
19581952

19591953
uint32_t llama_kv_cache_recurrent::cell_max() const {
19601954
for (uint32_t i = size; i > 0; --i) {
1961-
const llama_kv_cell & cell = cells[i - 1];
1955+
const kv_cell & cell = cells[i - 1];
19621956

19631957
if (cell.pos >= 0 && !cell.is_empty()) {
19641958
return i;
@@ -2200,7 +2194,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
22002194
clear();
22012195

22022196
for (uint32_t i = 0; i < cell_count; ++i) {
2203-
llama_kv_cell & cell = cells[i];
2197+
kv_cell & cell = cells[i];
22042198

22052199
llama_pos pos;
22062200
uint32_t n_seq_id;
@@ -2412,7 +2406,7 @@ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache
24122406
view->cells_sequences = (llama_seq_id *)p;
24132407
}
24142408

2415-
const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
2409+
const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
24162410
llama_kv_cache_view_cell * c_curr = view->cells;
24172411
llama_seq_id * cs_curr = view->cells_sequences;
24182412
int32_t used_cells = 0;

src/llama-kv-cache.h

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -111,29 +111,6 @@ struct llama_kv_cache_guard {
111111
llama_kv_cache * kv;
112112
};
113113

114-
// TODO: create separate cells for unified/recurrent caches
115-
// TODO: move in the source file
116-
struct llama_kv_cell {
117-
llama_pos pos = -1;
118-
llama_pos delta = 0;
119-
int32_t src = -1; // used by recurrent state models to copy states
120-
int32_t tail = -1;
121-
122-
std::set<llama_seq_id> seq_id;
123-
124-
bool has_seq_id(const llama_seq_id & id) const {
125-
return seq_id.find(id) != seq_id.end();
126-
}
127-
128-
bool is_empty() const {
129-
return seq_id.empty();
130-
}
131-
132-
bool is_same_seq(const llama_kv_cell & other) const {
133-
return seq_id == other.seq_id;
134-
}
135-
};
136-
137114
//
138115
// llama_kv_cache_unified
139116
// ring-buffer of cached KV data
@@ -143,6 +120,25 @@ struct llama_kv_cell {
143120
// TODO: add notion of max sequences
144121
class llama_kv_cache_unified : public llama_kv_cache {
145122
public:
123+
struct kv_cell {
124+
llama_pos pos = -1;
125+
llama_pos delta = 0;
126+
127+
std::set<llama_seq_id> seq_id;
128+
129+
bool has_seq_id(const llama_seq_id & id) const {
130+
return seq_id.find(id) != seq_id.end();
131+
}
132+
133+
bool is_empty() const {
134+
return seq_id.empty();
135+
}
136+
137+
bool is_same_seq(const kv_cell & other) const {
138+
return seq_id == other.seq_id;
139+
}
140+
};
141+
146142
llama_kv_cache_unified(
147143
const llama_hparams & hparams,
148144
callbacks cbs,
@@ -251,7 +247,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
251247
// required padding
252248
uint32_t padding = 1;
253249

254-
std::vector<llama_kv_cell> cells;
250+
std::vector<kv_cell> cells;
255251

256252
std::vector<ggml_tensor *> k_l; // per layer
257253
std::vector<ggml_tensor *> v_l;
@@ -294,6 +290,26 @@ class llama_kv_cache_unified : public llama_kv_cache {
294290

295291
class llama_kv_cache_recurrent : public llama_kv_cache {
296292
public:
293+
struct kv_cell {
294+
llama_pos pos = -1;
295+
int32_t src = -1; // used by recurrent state models to copy states
296+
int32_t tail = -1;
297+
298+
std::set<llama_seq_id> seq_id;
299+
300+
bool has_seq_id(const llama_seq_id & id) const {
301+
return seq_id.find(id) != seq_id.end();
302+
}
303+
304+
bool is_empty() const {
305+
return seq_id.empty();
306+
}
307+
308+
bool is_same_seq(const kv_cell & other) const {
309+
return seq_id == other.seq_id;
310+
}
311+
};
312+
297313
llama_kv_cache_recurrent(
298314
const llama_hparams & hparams,
299315
callbacks cbs,
@@ -384,7 +400,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
384400
// computed before each graph build
385401
uint32_t n = 0;
386402

387-
std::vector<llama_kv_cell> cells;
403+
std::vector<kv_cell> cells;
388404

389405
std::vector<ggml_tensor *> k_l; // per layer
390406
std::vector<ggml_tensor *> v_l;

Comments: 0