Skip to content

Commit 8171312

Browse files
authored
kv-cells : track min/max used cells and per-sequence positions (#13808)
* kv-cells : track min/max used cells and per-sequence positions (ggml-ci)
* kv-cells : fix pos-modification updates for seq_pos (ggml-ci)
* kv-cells : add comments (ggml-ci)
1 parent f9cd683 commit 8171312

File tree

3 files changed

+123
-51
lines changed

3 files changed

+123
-51
lines changed

src/llama-kv-cache.cpp

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -286,31 +286,11 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
286286
}
287287

288288
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
289-
llama_pos result = std::numeric_limits<llama_pos>::max();
290-
291-
for (uint32_t i = 0; i < cells.size(); ++i) {
292-
if (cells.seq_has(i, seq_id)) {
293-
result = std::min(result, cells.pos_get(i));
294-
}
295-
}
296-
297-
if (result == std::numeric_limits<llama_pos>::max()) {
298-
result = -1;
299-
}
300-
301-
return result;
289+
return cells.seq_pos_min(seq_id);
302290
}
303291

304292
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
305-
llama_pos result = -1;
306-
307-
for (uint32_t i = 0; i < cells.size(); ++i) {
308-
if (cells.seq_has(i, seq_id)) {
309-
result = std::max(result, cells.pos_get(i));
310-
}
311-
}
312-
313-
return result;
293+
return cells.seq_pos_max(seq_id);
314294
}
315295

316296
void llama_kv_cache_unified::restore() {
@@ -504,7 +484,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
504484
// a heuristic, to avoid attending the full cache if it is not yet utilized
505485
// after enough generations, the benefit from this heuristic disappears
506486
// if we start defragmenting the cache, the benefit from this will be more important
507-
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
487+
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
508488

509489
#ifdef FIND_SLOT_DEBUG
510490
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
@@ -1018,7 +998,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1018998
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
1019999
const uint32_t n_layer = layers.size();
10201000

1021-
const uint32_t n_kv = cell_max();
1001+
const uint32_t n_kv = cells.used_max_p1();
10221002
const uint32_t n_used = cells.get_used();
10231003

10241004
assert(n_used <= n_kv);
@@ -1144,16 +1124,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
11441124
return true;
11451125
}
11461126

1147-
uint32_t llama_kv_cache_unified::cell_max() const {
1148-
for (uint32_t i = cells.size(); i > 0; --i) {
1149-
if (!cells.is_empty(i - 1)) {
1150-
return i;
1151-
}
1152-
}
1153-
1154-
return 0;
1155-
}
1156-
11571127
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
11581128
assert(p0 >= 0 && p1 >= 0);
11591129

src/llama-kv-cache.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
246246
// return true if cells have been moved
247247
bool defrag_prepare(int32_t n_max_nodes);
248248

249-
// find how many cells are currently in use
250-
// TODO: optimize
251-
uint32_t cell_max() const;
252-
253249
size_t total_size() const;
254250

255251
size_t size_k_bytes() const;

src/llama-kv-cells.h

Lines changed: 119 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <bitset>
77
#include <cassert>
88
#include <vector>
9+
#include <set>
910

1011
// meta information about KV cells that can be part of multiple sequences at the same time
1112
// TODO: add unit tests
@@ -18,8 +19,13 @@ class llama_kv_cells_unified {
1819
seq[i].reset();
1920
}
2021

21-
used = 0;
2222
has_shift = false;
23+
24+
used.clear();
25+
26+
for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
27+
seq_pos[s].clear();
28+
}
2329
}
2430

2531
void reset_shift() {
@@ -50,7 +56,25 @@ class llama_kv_cells_unified {
5056
}
5157

5258
uint32_t get_used() const {
53-
return used;
59+
return used.size();
60+
}
61+
62+
// the index of the first cell that is used
63+
// return 0 if no cells are used
64+
uint32_t used_min() const {
65+
return used.empty() ? 0 : *used.begin();
66+
}
67+
68+
// the index of the last cell that is used + 1
69+
// return 0 if no cells are used
70+
uint32_t used_max_p1() const {
71+
#if 0
72+
if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
73+
if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
74+
if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
75+
#endif
76+
77+
return used.empty() ? 0 : *used.rbegin() + 1;
5478
}
5579

5680
bool get_has_shift() const {
@@ -69,6 +93,9 @@ class llama_kv_cells_unified {
6993
pos [isrc] = -1;
7094
shift[isrc] = 0;
7195
seq [isrc].reset();
96+
97+
used.erase (isrc);
98+
used.insert(idst);
7299
}
73100

74101
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
@@ -95,16 +122,24 @@ class llama_kv_cells_unified {
95122

96123
for (uint32_t j = 0; j < other.pos.size(); ++j) {
97124
if (pos[i + j] == -1 && other.pos[j] != -1) {
98-
used++;
125+
used.insert(i + j);
99126
}
100127

101128
if (pos[i + j] != -1 && other.pos[j] == -1) {
102-
used--;
129+
used.erase(i + j);
130+
}
131+
132+
if (pos[i + j] != -1) {
133+
seq_pos_rm(i + j);
103134
}
104135

105136
pos[i + j] = other.pos[j];
106137
seq[i + j] = other.seq[j];
107138

139+
if (pos[i + j] != -1) {
140+
seq_pos_add(i + j);
141+
}
142+
108143
assert(shift[i + j] == 0);
109144
}
110145
}
@@ -118,11 +153,12 @@ class llama_kv_cells_unified {
118153
assert(seq_id >= 0);
119154

120155
seq[i].reset(seq_id);
156+
seq_pos[seq_id].erase(pos[i]);
121157

122158
if (seq[i].none()) {
123159
pos[i] = -1;
124160

125-
used--;
161+
used.erase(i);
126162

127163
return true;
128164
}
@@ -135,17 +171,22 @@ class llama_kv_cells_unified {
135171
assert(i < pos.size());
136172

137173
if (seq[i].test(seq_id)) {
174+
seq_pos_rm(i);
138175
seq[i].reset();
176+
139177
seq[i].set(seq_id);
178+
seq_pos[seq_id].insert(pos[i]);
140179

141180
return false;
142181
}
143182

144183
if (seq[i].any()) {
184+
seq_pos_rm(i);
145185
seq[i].reset();
186+
146187
pos[i] = -1;
147188

148-
used--;
189+
used.erase(i);
149190

150191
return true;
151192
}
@@ -169,6 +210,33 @@ class llama_kv_cells_unified {
169210
assert(!seq[i].test(seq_id));
170211

171212
seq[i].set(seq_id);
213+
seq_pos[seq_id].insert(pos[i]);
214+
}
215+
216+
// the minimum position of sequence seq_id currently present in any of the cells
217+
// return -1 if the sequence is not present
218+
llama_pos seq_pos_min(llama_seq_id seq_id) const {
219+
assert(seq_id >= 0);
220+
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
221+
222+
if (seq_pos[seq_id].empty()) {
223+
return -1;
224+
}
225+
226+
return *seq_pos[seq_id].begin();
227+
}
228+
229+
// the maximum position of sequence seq_id currently present in any of the cells
230+
// return -1 if the sequence is not present
231+
llama_pos seq_pos_max(llama_seq_id seq_id) const {
232+
assert(seq_id >= 0);
233+
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
234+
235+
if (seq_pos[seq_id].empty()) {
236+
return -1;
237+
}
238+
239+
return *seq_pos[seq_id].rbegin();
172240
}
173241

174242
// note: call only if the cell is not empty
@@ -202,7 +270,8 @@ class llama_kv_cells_unified {
202270
assert(pos[i] == -1);
203271

204272
pos[i] = p;
205-
used++;
273+
274+
used.insert(i);
206275
}
207276

208277
// pos[i] = pos[i] + d
@@ -212,16 +281,22 @@ class llama_kv_cells_unified {
212281
assert(i < pos.size());
213282
assert(pos[i] != -1);
214283

284+
seq_pos_rm(i);
285+
215286
pos[i] += d;
216287
shift[i] += d;
217288

289+
seq_pos_add(i);
290+
218291
has_shift = true;
219292

220293
if (pos[i] < 0) {
221-
pos[i] = -1;
294+
seq_pos_rm(i);
295+
222296
seq[i].reset();
297+
pos[i] = -1;
223298

224-
used--;
299+
used.erase(i);
225300

226301
return true;
227302
}
@@ -238,17 +313,22 @@ class llama_kv_cells_unified {
238313

239314
const llama_pos p_old = pos[i];
240315

316+
seq_pos_rm(i);
317+
241318
pos[i] /= d;
242319
shift[i] += p_old - pos[i];
243320

321+
seq_pos_add(i);
322+
244323
has_shift = true;
245324
}
246325

247326
private:
248-
uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
249-
250327
bool has_shift = false;
251328

329+
// set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
330+
std::set<uint32_t> used;
331+
252332
std::vector<llama_pos> pos;
253333

254334
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
@@ -268,6 +348,32 @@ class llama_kv_cells_unified {
268348
//
269349
std::vector<llama_pos> shift;
270350

271-
std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
272-
};
351+
using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
352+
353+
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
354+
std::vector<bits_t> seq;
355+
356+
// the set seq_pos[s] tells us which positions are currently present for sequence s
357+
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
358+
std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
359+
360+
// helper functions for updating `seq_pos`, one cell at a time:
361+
362+
// remove cell i
363+
void seq_pos_rm(uint32_t i) {
364+
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
365+
if (seq[i].test(s)) {
366+
seq_pos[s].erase(pos[i]);
367+
}
368+
}
369+
}
273370

371+
// add cell i
372+
void seq_pos_add(uint32_t i) {
373+
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
374+
if (seq[i].test(s)) {
375+
seq_pos[s].insert(pos[i]);
376+
}
377+
}
378+
}
379+
};

0 commit comments

Comments
 (0)