@@ -8028,7 +8028,7 @@ static int llama_decode_internal(
8028
8028
}
8029
8029
8030
8030
// copy the KV cache to the host memory and reshuffle the cells to the beginning of the cache
8031
- // this way we eliminate any empty segments that may have been left by previous KV cache operations
8031
+ // this way we eliminate any empty holes that may have been left by previous KV cache operations
8032
8032
//
8033
8033
// TODO: optimizations are possible:
8034
8034
// - multiple threads
@@ -8045,36 +8045,81 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
8045
8045
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
8046
8046
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
8047
8047
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
8048
+ const uint32_t n_used = kv_self.used;
8048
8049
8049
8050
const uint32_t kv_size = kv_self.size;
8050
8051
8052
+ assert(n_used <= n_kv);
8053
+
8051
8054
const int64_t t_start = ggml_time_us();
8052
8055
8053
8056
std::vector<uint8_t> buf_k;
8054
8057
std::vector<uint8_t> buf_v;
8055
8058
8056
- // the destination cell in the new KV cache
8057
- uint32_t id = 0;
8058
-
8059
8059
// number of cells moved
8060
8060
uint32_t n_moves = 0;
8061
8061
8062
8062
// determine which KV cells to move where
8063
8063
std::vector<uint32_t> ids(n_kv, n_kv);
8064
8064
8065
- for (uint32_t i0 = 0; i0 < n_kv ; ++i0) {
8065
+ for (uint32_t i0 = 0; i0 < n_used ; ++i0) {
8066
8066
const auto & cell0 = kv_self.cells[i0];
8067
8067
8068
8068
if (!cell0.is_empty()) {
8069
- ids[i0] = id;
8069
+ ids[i0] = i0;
8070
+
8071
+ continue;
8072
+ }
8073
+
8074
+ // found a hole - fill it with data from the end of the cache
8075
+
8076
+ // determine the size of the hole
8077
+ uint32_t nh = 1;
8078
+ while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
8079
+ nh++;
8080
+ }
8081
+
8082
+ // starting from the end, find nh non-empty cells
8083
+ uint32_t nf = 0;
8084
+ uint32_t is = n_kv - 1;
8085
+ for (; is > i0; --is) {
8086
+ const auto & cell1 = kv_self.cells[is];
8087
+
8088
+ if (cell1.is_empty() || ids[is] != n_kv) {
8089
+ continue;
8090
+ }
8070
8091
8071
- if (i0 != id) {
8072
- kv_self.cells[id] = cell0;
8073
- n_moves++;
8092
+ // non-empty cell which is not yet moved
8093
+ nf++;
8094
+ if (nf == nh) {
8095
+ break;
8096
+ }
8097
+ }
8098
+
8099
+ GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
8100
+
8101
+ nf = 0;
8102
+
8103
+ // go back and move the nh cells found above into the hole
8104
+ for (uint32_t i1 = is; i1 < n_kv; ++i1) {
8105
+ const auto & cell1 = kv_self.cells[i1];
8106
+
8107
+ if (cell1.is_empty() || ids[i1] != n_kv) {
8108
+ continue;
8074
8109
}
8075
8110
8076
- id++;
8111
+ ids[i1] = i0 + nf;
8112
+
8113
+ // move the cell metadata
8114
+ kv_self.cells[i0 + nf] = cell1;
8115
+
8116
+ n_moves++;
8117
+ nf++;
8077
8118
}
8119
+
8120
+ LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh);
8121
+
8122
+ i0 += nh - 1;
8078
8123
}
8079
8124
8080
8125
if (n_moves == 0) {
@@ -8083,11 +8128,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
8083
8128
8084
8129
LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
8085
8130
8086
- kv_self.head = id ;
8087
- kv_self.used = id ;
8131
+ kv_self.head = n_used ;
8132
+ kv_self.used = n_used ;
8088
8133
8089
8134
// zero the rest of the cells
8090
- for (uint32_t i = id ; i < n_kv; ++i) {
8135
+ for (uint32_t i = n_used ; i < n_kv; ++i) {
8091
8136
kv_self.cells[i] = llama_kv_cell();
8092
8137
}
8093
8138
0 commit comments