@@ -2337,17 +2337,17 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
2337
2337
// Makes the used-cell range span the whole cache: n is set to size and head
// is reset to 0 (find_slot describes the used range as "head to head + n").
// The line marked '+' in this diff additionally resets rs_z to 0.
void llama_kv_cache_recurrent::set_full () {
2338
2338
n = size; // the range length now equals the full cache size
2339
2339
head = 0 ; // the range starts at the first cell
2340
+ rs_z = 0 ; // NOTE(review): find_slot uses rs_z as the index of the first to-be-cleared cell (-1 if none); why 0 is the right value here is not visible in this hunk - confirm
2340
2341
}
2341
2342
2342
2343
bool llama_kv_cache_recurrent::find_slot (const llama_ubatch & ubatch) {
2343
- const uint32_t n_tokens = ubatch.n_tokens ;
2344
- const uint32_t n_seqs = ubatch.n_seqs ;
2344
+ const uint32_t n_seqs = ubatch.n_seqs ;
2345
2345
2346
2346
const uint32_t n_seq_tokens = ubatch.n_seq_tokens ;
2347
2347
2348
2348
// if we have enough unused cells before the current head ->
2349
2349
// better to start searching from the beginning of the cache, hoping to fill it
2350
- if (head > used + 2 *n_tokens ) {
2350
+ if (head > used + 2 *n_seqs ) {
2351
2351
head = 0 ;
2352
2352
}
2353
2353
@@ -2443,16 +2443,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
2443
2443
empty_cell.src = orig_cell.src ;
2444
2444
orig_cell.seq_id .erase (seq_id);
2445
2445
empty_cell.seq_id .insert (seq_id); // will be overwritten
2446
+ GGML_ASSERT (!orig_cell.is_empty ()); // has at least one remaining seq_id
2446
2447
}
2447
2448
seq_meta.tail = next_empty_cell;
2448
2449
// find next empty cell
2449
2450
if (s + 1 < n_seqs) {
2450
- next_empty_cell += 1 ;
2451
2451
for (uint32_t i = 0 ; i < size; ++i) {
2452
+ next_empty_cell += 1 ;
2452
2453
if (next_empty_cell >= size) { next_empty_cell -= size; }
2453
2454
kv_cell & cell = cells[next_empty_cell];
2454
2455
if (cell.is_empty ()) { break ; }
2455
- next_empty_cell += 1 ;
2456
2456
}
2457
2457
}
2458
2458
}
@@ -2472,12 +2472,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
2472
2472
std::swap (dst_cell.src , src_cell.src );
2473
2473
std::swap (dst_cell.seq_id , src_cell.seq_id );
2474
2474
2475
- // swap tails (assuming they NEVER overlap)
2476
- for (const llama_seq_id seq_id : src_cell.seq_id ) {
2477
- cells[seq_id].tail = src_id;
2478
- }
2479
- for (const llama_seq_id seq_id : dst_cell.seq_id ) {
2480
- cells[seq_id].tail = dst_id;
2475
+ // swap tails
2476
+ for (uint32_t i = 0 ; i < size; ++i) {
2477
+ int32_t & tail = cells[i].tail ;
2478
+ if (tail == src_id) {
2479
+ tail = dst_id;
2480
+ } else if (tail == dst_id) {
2481
+ tail = src_id;
2482
+ }
2481
2483
}
2482
2484
}
2483
2485
}
@@ -2506,13 +2508,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
2506
2508
// Find first to-be-cleared cell
2507
2509
rs_z = -1 ;
2508
2510
for (int i = min; i <= max; ++i) {
2509
- if (rs_z < 0 && cells[i].src == -1 ) {
2510
- rs_z = i;
2511
+ if (cells[i].src == -1 ) {
2512
+ if (rs_z < 0 ) {
2513
+ rs_z = i;
2514
+ }
2515
+
2516
+ cells[i].src0 = rs_z;
2517
+ } else {
2518
+ // Stage the source ids for all used cells to allow correct seq_* behavior
2519
+ // and still make these values available when setting the inputs
2520
+ cells[i].src0 = cells[i].src ;
2511
2521
}
2512
- // Stage the source ids for all used cells to allow correct seq_* behavior
2513
- // and still make these values available when setting the inputs
2514
- cells[i].src0 = cells[i].src ;
2515
- cells[i].src = i;
2522
+ cells[i].src = i; // avoid moving or clearing twice
2516
2523
}
2517
2524
2518
2525
// allow getting the range of used cells, from head to head + n
0 commit comments