6
6
#include < bitset>
7
7
#include < cassert>
8
8
#include < vector>
9
+ #include < set>
9
10
10
11
// meta information about KV cells that can be part of multiple sequences at the same time
11
12
// TODO: add unit tests
@@ -18,8 +19,13 @@ class llama_kv_cells_unified {
18
19
seq[i].reset ();
19
20
}
20
21
21
- used = 0 ;
22
22
has_shift = false ;
23
+
24
+ used.clear ();
25
+
26
+ for (uint32_t s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
27
+ seq_pos[s].clear ();
28
+ }
23
29
}
24
30
25
31
void reset_shift () {
@@ -50,7 +56,25 @@ class llama_kv_cells_unified {
50
56
}
51
57
52
58
uint32_t get_used () const {
53
- return used;
59
+ return used.size ();
60
+ }
61
+
62
+ // the index of the first cell that is used
63
+ // return 0 if no cells are used
64
+ uint32_t used_min () const {
65
+ return used.empty () ? 0 : *used.begin ();
66
+ }
67
+
68
+ // the index of the last cell that is used + 1
69
+ // return 0 if no cells are used
70
+ uint32_t used_max_p1 () const {
71
+ #if 0
72
+ if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
73
+ if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
74
+ if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
75
+ #endif
76
+
77
+ return used.empty () ? 0 : *used.rbegin () + 1 ;
54
78
}
55
79
56
80
bool get_has_shift () const {
@@ -69,6 +93,9 @@ class llama_kv_cells_unified {
69
93
pos [isrc] = -1 ;
70
94
shift[isrc] = 0 ;
71
95
seq [isrc].reset ();
96
+
97
+ used.erase (isrc);
98
+ used.insert (idst);
72
99
}
73
100
74
101
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
@@ -95,16 +122,24 @@ class llama_kv_cells_unified {
95
122
96
123
for (uint32_t j = 0 ; j < other.pos .size (); ++j) {
97
124
if (pos[i + j] == -1 && other.pos [j] != -1 ) {
98
- used++ ;
125
+ used. insert (i + j) ;
99
126
}
100
127
101
128
if (pos[i + j] != -1 && other.pos [j] == -1 ) {
102
- used--;
129
+ used.erase (i + j);
130
+ }
131
+
132
+ if (pos[i + j] != -1 ) {
133
+ seq_pos_rm (i + j);
103
134
}
104
135
105
136
pos[i + j] = other.pos [j];
106
137
seq[i + j] = other.seq [j];
107
138
139
+ if (pos[i + j] != -1 ) {
140
+ seq_pos_add (i + j);
141
+ }
142
+
108
143
assert (shift[i + j] == 0 );
109
144
}
110
145
}
@@ -118,11 +153,12 @@ class llama_kv_cells_unified {
118
153
assert (seq_id >= 0 );
119
154
120
155
seq[i].reset (seq_id);
156
+ seq_pos[seq_id].erase (pos[i]);
121
157
122
158
if (seq[i].none ()) {
123
159
pos[i] = -1 ;
124
160
125
- used-- ;
161
+ used. erase (i) ;
126
162
127
163
return true ;
128
164
}
@@ -135,17 +171,22 @@ class llama_kv_cells_unified {
135
171
assert (i < pos.size ());
136
172
137
173
if (seq[i].test (seq_id)) {
174
+ seq_pos_rm (i);
138
175
seq[i].reset ();
176
+
139
177
seq[i].set (seq_id);
178
+ seq_pos[seq_id].insert (pos[i]);
140
179
141
180
return false ;
142
181
}
143
182
144
183
if (seq[i].any ()) {
184
+ seq_pos_rm (i);
145
185
seq[i].reset ();
186
+
146
187
pos[i] = -1 ;
147
188
148
- used-- ;
189
+ used. erase (i) ;
149
190
150
191
return true ;
151
192
}
@@ -169,6 +210,33 @@ class llama_kv_cells_unified {
169
210
assert (!seq[i].test (seq_id));
170
211
171
212
seq[i].set (seq_id);
213
+ seq_pos[seq_id].insert (pos[i]);
214
+ }
215
+
216
+ // the minimum position of sequence seq_id currently present in any of the cells
217
+ // return -1 if the sequence is not present
218
+ llama_pos seq_pos_min (llama_seq_id seq_id) const {
219
+ assert (seq_id >= 0 );
220
+ assert (seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
221
+
222
+ if (seq_pos[seq_id].empty ()) {
223
+ return -1 ;
224
+ }
225
+
226
+ return *seq_pos[seq_id].begin ();
227
+ }
228
+
229
+ // the maximum position of sequence seq_id currently present in any of the cells
230
+ // return -1 if the sequence is not present
231
+ llama_pos seq_pos_max (llama_seq_id seq_id) const {
232
+ assert (seq_id >= 0 );
233
+ assert (seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
234
+
235
+ if (seq_pos[seq_id].empty ()) {
236
+ return -1 ;
237
+ }
238
+
239
+ return *seq_pos[seq_id].rbegin ();
172
240
}
173
241
174
242
// note: call only if the cell is not empty
@@ -202,7 +270,8 @@ class llama_kv_cells_unified {
202
270
assert (pos[i] == -1 );
203
271
204
272
pos[i] = p;
205
- used++;
273
+
274
+ used.insert (i);
206
275
}
207
276
208
277
// pos[i] = pos[i] + d
@@ -212,16 +281,22 @@ class llama_kv_cells_unified {
212
281
assert (i < pos.size ());
213
282
assert (pos[i] != -1 );
214
283
284
+ seq_pos_rm (i);
285
+
215
286
pos[i] += d;
216
287
shift[i] += d;
217
288
289
+ seq_pos_add (i);
290
+
218
291
has_shift = true ;
219
292
220
293
if (pos[i] < 0 ) {
221
- pos[i] = -1 ;
294
+ seq_pos_rm (i);
295
+
222
296
seq[i].reset ();
297
+ pos[i] = -1 ;
223
298
224
- used-- ;
299
+ used. erase (i) ;
225
300
226
301
return true ;
227
302
}
@@ -238,17 +313,22 @@ class llama_kv_cells_unified {
238
313
239
314
const llama_pos p_old = pos[i];
240
315
316
+ seq_pos_rm (i);
317
+
241
318
pos[i] /= d;
242
319
shift[i] += p_old - pos[i];
243
320
321
+ seq_pos_add (i);
322
+
244
323
has_shift = true ;
245
324
}
246
325
247
326
private:
248
- uint32_t used = 0 ; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
249
-
250
327
bool has_shift = false ;
251
328
329
+ // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
330
+ std::set<uint32_t > used;
331
+
252
332
std::vector<llama_pos> pos;
253
333
254
334
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
@@ -268,6 +348,32 @@ class llama_kv_cells_unified {
268
348
//
269
349
std::vector<llama_pos> shift;
270
350
271
- std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
272
- };
351
+ using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
352
+
353
+ // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
354
+ std::vector<bits_t > seq;
355
+
356
+ // the set seq_pos[s] tells us which positions are currently present for sequence s
357
+ // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
358
+ std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
359
+
360
+ // helper functions for updating `seq_pos`, one cell at a time:
361
+
362
+ // remove cell i
363
+ void seq_pos_rm (uint32_t i) {
364
+ for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
365
+ if (seq[i].test (s)) {
366
+ seq_pos[s].erase (pos[i]);
367
+ }
368
+ }
369
+ }
273
370
371
+ // add cell i
372
+ void seq_pos_add (uint32_t i) {
373
+ for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
374
+ if (seq[i].test (s)) {
375
+ seq_pos[s].insert (pos[i]);
376
+ }
377
+ }
378
+ }
379
+ };
0 commit comments