@@ -464,8 +464,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
464
464
head_cur = 0 ;
465
465
}
466
466
467
- // otherwise, one cell per token.
468
-
469
467
if (n_tokens > cells.size ()) {
470
468
LLAMA_LOG_ERROR (" %s: n_tokens = %d > size = %u\n " , __func__, n_tokens, cells.size ());
471
469
return -1 ;
@@ -2344,21 +2342,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
2344
2342
2345
2343
bool success = true ;
2346
2344
2347
- // TODO: here we have to verify that all ubatches can fit in the cells
2348
- // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
2349
- // during the compute of each ubatch. to reproduce, uncomment the following loop and run:
2350
- //
2351
- // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
2352
- //
2353
- // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
2354
- //
2355
- GGML_UNUSED (ubatches);
2356
- // for (const auto & ubatch : ubatches) {
2357
- // if (!find_slot(ubatch)) {
2358
- // success = false;
2359
- // break;
2360
- // }
2361
- // }
2345
+ for (const auto & ubatch : ubatches) {
2346
+ if (!find_slot (ubatch)) {
2347
+ success = false ;
2348
+ break ;
2349
+ }
2350
+ }
2362
2351
2363
2352
// restore the original state
2364
2353
cells = std::move (org_cells);
@@ -2380,14 +2369,13 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
2380
2369
}
2381
2370
2382
2371
bool llama_kv_cache_recurrent::find_slot (const llama_ubatch & ubatch) {
2383
- const uint32_t n_tokens = ubatch.n_tokens ;
2384
- const uint32_t n_seqs = ubatch.n_seqs ;
2372
+ const uint32_t n_seqs = ubatch.n_seqs ;
2385
2373
2386
2374
const uint32_t n_seq_tokens = ubatch.n_seq_tokens ;
2387
2375
2388
2376
// if we have enough unused cells before the current head ->
2389
2377
// better to start searching from the beginning of the cache, hoping to fill it
2390
- if (head > used + 2 *n_tokens ) {
2378
+ if (head > used + 2 *n_seqs ) {
2391
2379
head = 0 ;
2392
2380
}
2393
2381
@@ -2483,16 +2471,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
2483
2471
empty_cell.src = orig_cell.src ;
2484
2472
orig_cell.seq_id .erase (seq_id);
2485
2473
empty_cell.seq_id .insert (seq_id); // will be overwritten
2474
+ GGML_ASSERT (!orig_cell.is_empty ()); // has at least one remaining seq_id
2486
2475
}
2487
2476
seq_meta.tail = next_empty_cell;
2488
2477
// find next empty cell
2489
2478
if (s + 1 < n_seqs) {
2490
- next_empty_cell += 1 ;
2491
2479
for (uint32_t i = 0 ; i < size; ++i) {
2480
+ next_empty_cell += 1 ;
2492
2481
if (next_empty_cell >= size) { next_empty_cell -= size; }
2493
2482
kv_cell & cell = cells[next_empty_cell];
2494
2483
if (cell.is_empty ()) { break ; }
2495
- next_empty_cell += 1 ;
2496
2484
}
2497
2485
}
2498
2486
}
@@ -2502,8 +2490,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
2502
2490
2503
2491
// gather and re-order
2504
2492
for (uint32_t s = 0 ; s < n_seqs; ++s) {
2505
- int32_t dst_id = s + min;
2506
- int32_t src_id = cells[ubatch.seq_id [s][0 ]].tail ;
2493
+ const int32_t dst_id = s + min;
2494
+ const int32_t src_id = cells[ubatch.seq_id [s][0 ]].tail ;
2507
2495
if (dst_id != src_id) {
2508
2496
kv_cell & dst_cell = cells[dst_id];
2509
2497
kv_cell & src_cell = cells[src_id];
@@ -2512,20 +2500,22 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
2512
2500
std::swap (dst_cell.src , src_cell.src );
2513
2501
std::swap (dst_cell.seq_id , src_cell.seq_id );
2514
2502
2515
- // swap tails (assuming they NEVER overlap)
2516
- for (const llama_seq_id seq_id : src_cell.seq_id ) {
2517
- cells[seq_id].tail = src_id;
2518
- }
2519
- for (const llama_seq_id seq_id : dst_cell.seq_id ) {
2520
- cells[seq_id].tail = dst_id;
2503
+ // swap tails
2504
+ for (uint32_t i = 0 ; i < size; ++i) {
2505
+ int32_t & tail = cells[i].tail ;
2506
+ if (tail == src_id) {
2507
+ tail = dst_id;
2508
+ } else if (tail == dst_id) {
2509
+ tail = src_id;
2510
+ }
2521
2511
}
2522
2512
}
2523
2513
}
2524
2514
2525
2515
// update the pos of the used seqs
2526
2516
for (uint32_t s = 0 ; s < n_seqs; ++s) {
2527
2517
const llama_pos last_pos = ubatch.pos [n_seq_tokens * s + n_seq_tokens - 1 ];
2528
- int32_t cell_id = s + min;
2518
+ const int32_t cell_id = s + min;
2529
2519
kv_cell & cell = cells[cell_id];
2530
2520
2531
2521
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@@ -2543,6 +2533,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
2543
2533
}
2544
2534
}
2545
2535
2536
+ // Find first cell without src refs, to use as the zero-ed state
2537
+ {
2538
+ // TODO: bake-in src refcounts in the cell metadata
2539
+ std::vector<int32_t > refcounts (size, 0 );
2540
+ for (size_t i = 0 ; i < size; ++i) {
2541
+ const int32_t src = cells[i].src ;
2542
+ if (src >= 0 ) {
2543
+ refcounts[src] += 1 ;
2544
+ }
2545
+ }
2546
+
2547
+ rs_z = -1 ;
2548
+ for (int i = min; i <= max; ++i) {
2549
+ if (refcounts[i] == 0 ) {
2550
+ rs_z = i;
2551
+ break ;
2552
+ }
2553
+ }
2554
+
2555
+ for (int i = min; i <= max; ++i) {
2556
+ if (cells[i].src < 0 ) {
2557
+ GGML_ASSERT (rs_z >= 0 );
2558
+ cells[i].src0 = rs_z;
2559
+ } else {
2560
+ // Stage the source ids for all used cells to allow correct seq_* behavior
2561
+ // and still make these values available when setting the inputs
2562
+ cells[i].src0 = cells[i].src ;
2563
+ }
2564
+ cells[i].src = i; // avoid moving or clearing twice
2565
+ }
2566
+ }
2567
+
2546
2568
// allow getting the range of used cells, from head to head + n
2547
2569
head = min;
2548
2570
n = max - min + 1 ;
@@ -2554,47 +2576,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
2554
2576
}
2555
2577
2556
2578
bool llama_kv_cache_recurrent::get_can_shift () const {
2557
- return false ;
2558
- }
2559
-
2560
- int32_t llama_kv_cache_recurrent::s_copy (int i) const {
2561
- const uint32_t cell_id = i + head;
2562
-
2563
- // ////////////////////////////////////////////
2564
- // TODO: this should not mutate the KV cache !
2565
- kv_cell & cell = const_cast <kv_cell &>(cells[cell_id]);
2566
-
2567
- // prevent out-of-bound sources
2568
- if (cell.src < 0 || (uint32_t ) cell.src >= size) {
2569
- cell.src = cell_id;
2570
- }
2571
-
2572
- int32_t res = cell.src ;
2573
-
2574
- // TODO: do not mutate the KV cache
2575
- // ensure copy only happens once
2576
- if (cell.src != (int32_t ) cell_id) {
2577
- cell.src = cell_id;
2578
- }
2579
-
2580
- return res;
2581
- }
2582
-
2583
- float llama_kv_cache_recurrent::s_mask (int i) const {
2584
- const uint32_t cell_id = i + head;
2585
-
2586
- // ////////////////////////////////////////////
2587
- // TODO: this should not mutate the KV cache !
2588
- kv_cell & cell = const_cast <kv_cell &>(cells[cell_id]);
2589
-
2590
- float res = (float ) (cell.src >= 0 );
2591
-
2592
- // only clear once
2593
- if (cell.src < 0 ) {
2594
- cell.src = cell_id;
2595
- }
2596
-
2597
- return res;
2579
+ // shifting the pos is trivial for recurrent models
2580
+ return true ;
2598
2581
}
2599
2582
2600
2583
size_t llama_kv_cache_recurrent::total_size () const {
@@ -3060,6 +3043,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
3060
3043
return is_full ? 0 : kv->head ;
3061
3044
}
3062
3045
3046
+ int32_t llama_kv_cache_recurrent_state::get_rs_z () const {
3047
+ return is_full ? 0 : kv->rs_z ;
3048
+ }
3049
+
3063
3050
uint32_t llama_kv_cache_recurrent_state::get_size () const {
3064
3051
return kv->size ;
3065
3052
}
@@ -3073,9 +3060,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
3073
3060
}
3074
3061
3075
3062
int32_t llama_kv_cache_recurrent_state::s_copy (int i) const {
3076
- return kv->s_copy (i);
3077
- }
3078
-
3079
- float llama_kv_cache_recurrent_state::s_mask (int i) const {
3080
- return kv->s_mask (i);
3063
+ return kv->cells [i + kv->head ].src0 ;
3081
3064
}
0 commit comments