@@ -1280,6 +1280,7 @@ struct llama_kv_cache {
     // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)

     // computed before each graph build
     uint32_t n = 0;
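The new `used` counter tracks how many cells are currently occupied. It is distinct from `head` (a hint for where the slot search starts) and `n` (the upper bound recomputed before each graph build). As a minimal sketch of the invariant it maintains, the counter should always equal the following recount; `llama_kv_cache_count_used` is a hypothetical helper shown for illustration only, the patch updates the counter incrementally instead of rescanning:

// Hypothetical recount of `used` from the cell array: a cell counts as used
// while at least one sequence still references it.
static uint32_t llama_kv_cache_count_used(const struct llama_kv_cache & cache) {
    uint32_t used = 0;
    for (uint32_t i = 0; i < cache.size; ++i) {
        if (!cache.cells[i].seq_id.empty()) {
            used++;
        }
    }
    return used;
}

llama_kv_cache_view_update further down performs essentially this recount and logs an error if it ever disagrees with the incremental counter.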
@@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(

     cache.head = 0;
     cache.size = n_ctx;
+    cache.used = 0;

     cache.cells.clear();
     cache.cells.resize(n_ctx);
@@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
         }
     }

+    cache.used += n_tokens;
+
     return true;
 }
@@ -1625,6 +1629,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
         cache.cells[i].seq_id.clear();
     }
     cache.head = 0;
+    cache.used = 0;
 }

 static void llama_kv_cache_seq_rm(
@@ -1647,14 +1652,17 @@ static void llama_kv_cache_seq_rm(
                 continue;
             }
             if (cache.cells[i].seq_id.empty()) {
+                // keep count of the number of used cells
+                if (cache.cells[i].pos >= 0) cache.used--;
+
                 cache.cells[i].pos = -1;
                 if (new_head == cache.size) new_head = i;
             }
         }
     }

     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }

 static void llama_kv_cache_seq_cp(
@@ -1680,6 +1688,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id

     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
+            if (cache.cells[i].pos >= 0) cache.used--;
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
             if (new_head == cache.size) new_head = i;
@@ -1690,7 +1699,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
     }

     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }

 static void llama_kv_cache_seq_shift(
@@ -1711,6 +1720,7 @@ static void llama_kv_cache_seq_shift(
             cache.cells[i].delta += delta;

             if (cache.cells[i].pos < 0) {
+                if (!cache.cells[i].seq_id.empty()) cache.used--;
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
                 if (new_head == cache.size) new_head = i;
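The decrement guards in the three mutators above are intentionally asymmetric. An occupied cell satisfies both `pos >= 0` and `!seq_id.empty()`, and each call site tests whichever half of that predicate is still observable at the point of the decrement: `llama_kv_cache_seq_rm` and `llama_kv_cache_seq_keep` decrement where the sequence set is empty or about to be cleared, so they check `pos >= 0`, while `llama_kv_cache_seq_shift` decrements after the shift has already driven `pos` negative, so it checks `!seq_id.empty()`. A hedged sketch of the shared predicate (hypothetical helper; the patch inlines the two halves):

// Hypothetical predicate: a cell counts toward cache.used exactly when it
// holds a token and at least one sequence references it.
static bool llama_kv_cell_is_used(const struct llama_kv_cell & cell) {
    return cell.pos >= 0 && !cell.seq_id.empty();
}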
@@ -5469,6 +5479,12 @@ static int llama_decode_internal(
         batch.seq_id = seq_id_arr.data();
     }

+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (kv_self.head > kv_self.used + 2*n_tokens) {
+        kv_self.head = 0;
+    }
+
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
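The heuristic is plain arithmetic: the cells in [0, head) number exactly head, and at most used of them can be occupied, so at least head - used of them are free. The condition head > used + 2*n_tokens therefore guarantees more than 2*n_tokens free cells below the current head, and restarting the search at cell 0 is likely both to succeed and to re-fill holes left behind by removed sequences. As a made-up example: with head = 512, used = 64 and n_tokens = 32, we have 512 > 64 + 64, so the search restarts from the beginning instead of creeping toward the end of the cache. This works together with the `new_head < cache.head` tightening in the seq_* functions above: those now only ever move head backward to a freed slot, leaving the decode path to decide when a full jump back to 0 pays off.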
@@ -5479,7 +5495,7 @@ static int llama_decode_internal(
     // kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
     kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));

-    //printf("kv_self.n = %d\n", kv_self.n);
+    //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

     ggml_allocr_reset(lctx.alloc);
@@ -8789,8 +8805,107 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
     }
 }

+struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
+    struct llama_kv_cache_view result = {
+        /* .n_cells            = */ 0,
+        /* .n_max_seq          = */ n_max_seq,
+        /* .token_count        = */ 0,
+        /* .used_cells         = */ llama_get_kv_cache_used_cells(ctx),
+        /* .max_contiguous     = */ 0,
+        /* .max_contiguous_idx = */ -1,
+        /* .cells              = */ nullptr,
+        /* .cells_sequences    = */ nullptr,
+    };
+    return result;
+}
+
+void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+    if (view->cells != nullptr) {
+        free(view->cells);
+        view->cells = nullptr;
+    }
+    if (view->cells_sequences != nullptr) {
+        free(view->cells_sequences);
+        view->cells_sequences = nullptr;
+    }
+}
+
+void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
+    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
+        view->n_cells = int32_t(ctx->kv_self.size);
+        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
+        view->cells = (struct llama_kv_cache_view_cell *)p;
+        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
+        view->cells_sequences = (llama_seq_id *)p;
+    }
+
+    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
+    llama_kv_cache_view_cell * c_curr = view->cells;
+    llama_seq_id * cs_curr = view->cells_sequences;
+    int32_t used_cells = 0;
+    int32_t token_count = 0;
+    int32_t curr_contig_idx = -1;
+    uint32_t max_contig = 0;
+    int32_t max_contig_idx = -1;
+
+    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
+        const size_t curr_size = kv_cells[i].seq_id.size();
+        token_count += curr_size;
+        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
+
+        if (curr_size > 0) {
+            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
+                max_contig = i - curr_contig_idx;
+                max_contig_idx = curr_contig_idx;
+            }
+            curr_contig_idx = -1;
+        } else if (curr_contig_idx < 0) {
+            curr_contig_idx = i;
+        }
+
+        int seq_idx = 0;
+        for (const llama_seq_id it : kv_cells[i].seq_id) {
+            if (seq_idx >= view->n_max_seq) {
+                break;
+            }
+            cs_curr[seq_idx] = it;
+            seq_idx++;
+        }
+        if (seq_idx != 0) {
+            used_cells++;
+        }
+        for (; seq_idx < view->n_max_seq; seq_idx++) {
+            cs_curr[seq_idx] = -1;
+        }
+    }
+    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
+        max_contig_idx = curr_contig_idx;
+        max_contig = kv_cells.size() - curr_contig_idx;
+    }
+    view->max_contiguous = max_contig;
+    view->max_contiguous_idx = max_contig_idx;
+    view->token_count = token_count;
+    view->used_cells = used_cells;
+    if (uint32_t(used_cells) != ctx->kv_self.used) {
+        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
+            __func__, ctx->kv_self.used, used_cells);
+    }
+}
+
 int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    return ctx->kv_self.head;
+    int result = 0;
+
+    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
+        result += ctx->kv_self.cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
+    return ctx->kv_self.used;
 }

 void llama_kv_cache_clear(struct llama_context * ctx) {
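Taken together, the new functions form a small KV-cache monitoring API. Below is a minimal usage sketch, assuming a valid `llama_context * ctx` and the matching declarations this change adds to the public header; `dump_kv_cache_stats` is a hypothetical helper name:

#include "llama.h"
#include <cstdio>

// Sketch: snapshot the KV cache and print the aggregate statistics.
static void dump_kv_cache_stats(const struct llama_context * ctx) {
    // record up to 4 sequence ids per cell
    struct llama_kv_cache_view view = llama_kv_cache_view_init(ctx, 4);

    // (re)allocates the buffers if needed and copies the current cell state
    llama_kv_cache_view_update(ctx, &view);

    printf("cells = %d, tokens = %d, used cells = %d, max contiguous free = %d (from cell %d)\n",
        view.n_cells, view.token_count, view.used_cells,
        view.max_contiguous, view.max_contiguous_idx);

    // per-cell data: view.cells[i].pos plus the flattened sequence table
    // view.cells_sequences[i*view.n_max_seq + s], padded with -1
    for (int32_t i = 0; i < view.n_cells; i++) {
        if (view.cells_sequences[i*view.n_max_seq] < 0) {
            continue; // no sequences -> cell i is free
        }
        // ... inspect occupied cell i here ...
    }

    llama_kv_cache_view_free(&view);
}

The view can be refreshed between decode calls and only reallocates when the cache grows. Note also the semantic fix above: llama_get_kv_cache_token_count used to return head, which stops being a token count once slots are reused; it now sums the seq_id references over all cells, and the used-cell count gets its own accessor.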
@@ -8960,10 +9075,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const size_t   kv_buf_size = kv_self.buf.size;
         const uint32_t kv_head     = kv_self.head;
         const uint32_t kv_size     = kv_self.size;
+        const uint32_t kv_used     = kv_self.used;

         data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
         data_ctx->write(&kv_head,     sizeof(kv_head));
         data_ctx->write(&kv_size,     sizeof(kv_size));
+        data_ctx->write(&kv_used,     sizeof(kv_used));

         if (kv_buf_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
@@ -9086,10 +9203,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         size_t   kv_buf_size;
         uint32_t kv_head;
         uint32_t kv_size;
+        uint32_t kv_used;

         memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
         memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
         memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
+        memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);

         if (kv_buf_size) {
             GGML_ASSERT(kv_self.buf.size == kv_buf_size);
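With the writer and reader both updated, the serialized KV section carries a fourth header field. An illustrative layout follows (the code above writes and reads the raw fields in this order; no such struct exists in the source):

// Illustrative only: order of the KV-cache header fields in the session state.
struct kv_state_header {
    size_t   kv_buf_size; // bytes of K/V tensor data that follow
    uint32_t kv_head;     // slot-search hint
    uint32_t kv_size;     // number of cells
    uint32_t kv_used;     // new: count of occupied cells
};

Since the field lands in the middle of the stream, session data written before this change will not round-trip through builds that include it, which presumably calls for bumping the session version constant.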
@@ -9124,6 +9243,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {

     ctx->kv_self.head = kv_head;
     ctx->kv_self.size = kv_size;
+    ctx->kv_self.used = kv_used;

     ctx->kv_self.cells.resize(kv_size);