@@ -1737,9 +1737,6 @@ struct llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
-    // if non-negative, compress data on next update
-    llama_pos compress_delta = -1;
-
     std::vector<llama_kv_cell> cells;
 
     std::vector<struct ggml_tensor *> k_l; // per layer
@@ -2275,10 +2272,6 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
     return result;
 }
 
-static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos delta) {
-    cache.compress_delta = delta;
-}
-
 static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
     cache.do_defrag = true;
 }
@@ -8034,240 +8027,6 @@ static int llama_decode_internal(
     return 0;
 }
 
-// summary:
-//
-// - determine which KV cell pairs (i0, i1) to merge:
-//
-//   abs(cell[i0].pos - cell[i1].pos) <= compress_delta
-//
-// - move the KV cache to the host memory for easier manipulation
-// - processing is done layer-by-layer
-// - convert the KV data to F32
-// - merge the KV data (different ways to merge)
-// - convert the KV data back to the original type
-// - move the KV cache back to the device memory
-// - update the KV cache metadata
-//
-// as a side effect, the new KV cache is defragmented
-//
-static void llama_kv_cache_compress_internal(struct llama_context & lctx) {
-    auto & kv_self = lctx.kv_self;
-
-    const auto & hparams = lctx.model.hparams;
-
-    const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-    const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
-    const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
-    const uint32_t n_head_kv = hparams.n_head_kv; GGML_UNUSED(n_head_kv);
-    const uint32_t kv_size = kv_self.size;
-
-    const int64_t t_start = ggml_time_us();
-
-    std::vector<uint8_t> buf_q;
-
-    std::vector<float> buf_src_f32;
-    std::vector<float> buf_dst_f32;
-
-    struct c_pair { uint32_t i0, i1; };
-    struct c_info { bool merged; uint32_t id, cnt, r; };
-
-    std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
-
-    // the destination cell in the new KV cache
-    uint32_t id = 0;
-
-    // number of pairs merged
-    uint32_t n_merges = 0;
-
-    // determine which KV cells to merge
-    for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
-        const auto & cell0 = kv_self.cells[i0];
-
-        if (!cell0.is_empty() && !infos[i0].merged) {
-            infos[i0] = { true, id, 0, 0 };
-            infos[id].cnt = 1;
-
-            const llama_pos p0 = cell0.pos;
-
-            for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
-                const auto & cell1 = kv_self.cells[i1];
-
-                if (i0 != i1 && cell0.is_same_seq(cell1)) {
-                    const llama_pos p1 = cell1.pos;
-
-                    if (std::abs(p0 - p1) <= kv_self.compress_delta) {
-                        infos[i1] = { true, id, 0, 0 };
-                        infos[id].cnt++;
-                        n_merges++;
-                    }
-                }
-            }
-
-            if (i0 != id) {
-                kv_self.cells[id] = cell0;
-            }
-
-            id++;
-        }
-    }
-
-    kv_self.head = id;
-    kv_self.used = id;
-
-    for (uint32_t i = id; i < kv_size; ++i) {
-        kv_self.cells[i] = llama_kv_cell();
-    }
-
-    LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
-
-    ggml_type_traits_t tt_k;
-    ggml_type_traits_t tt_v;
-
-    tt_k = ggml_internal_get_type_traits(kv_self.type_k);
-    tt_v = ggml_internal_get_type_traits(kv_self.type_v);
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        for (uint32_t i = 0; i < kv_size; ++i) {
-            infos[i].r = 0;
-        }
-
-        // update keys
-        {
-            const int64_t ne = n_embd_k_gqa*kv_size;
-
-            const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
-
-            buf_q.resize(k_size);
-
-            buf_src_f32.resize(ne);
-            buf_dst_f32.resize(ne);
-
-            ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
-
-            tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
-
-            std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
-
-            for (uint32_t i = 0; i < kv_size; ++i) {
-                if (!infos[i].merged) {
-                    continue;
-                }
-
-                const uint32_t id = infos[i].id;
-
-#if 1
-                // merge using averaging
-                {
-                    const float scale = 1.0f/float(infos[id].cnt);
-
-                    const int64_t os = i*n_embd_k_gqa;
-                    const int64_t od = id*n_embd_k_gqa;
-
-                    for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
-                        buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
-                    }
-                }
-#else
-                // merge separate heads
-                {
-                    for (uint32_t h = 0; h < n_head_kv; ++h) {
-                        if ((h + il) % infos[id].cnt != infos[id].r) {
-                            continue;
-                        }
-
-                        const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k;
-                        const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
-
-                        for (uint32_t j = 0; j < n_embd_head_k; ++j) {
-                            buf_dst_f32[od + j] = buf_src_f32[os + j];
-                        }
-                    }
-                }
-
-                infos[id].r++;
-#endif
-            }
-
-            tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
-
-            ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
-        }
-
-        for (uint32_t i = 0; i < kv_size; ++i) {
-            infos[i].r = 0;
-        }
-
-        // update values (note: they are transposed)
-        {
-            const int64_t ne = n_embd_v_gqa*kv_size;
-
-            const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
-
-            buf_q.resize(v_size);
-
-            buf_src_f32.resize(ne);
-            buf_dst_f32.resize(ne);
-
-            ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
-
-            tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
-
-            std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
-
-            for (uint32_t i = 0; i < kv_size; ++i) {
-                if (!infos[i].merged) {
-                    continue;
-                }
-
-                const uint32_t id = infos[i].id;
-
-#if 1
-                // merge using averaging
-                {
-                    const float scale = 1.0f/float(infos[id].cnt);
-                    //printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
-
-                    const int64_t os = i;
-                    const int64_t od = id;
-
-                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                        buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
-                    }
-                }
-#else
-                // merge separate heads
-                {
-                    for (uint32_t h = 0; h < n_head_kv; ++h) {
-                        if ((h + il) % infos[id].cnt != infos[id].r) {
-                            continue;
-                        }
-
-                        const int64_t os = i;
-                        const int64_t od = id;
-
-                        for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
-                            buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
-                        }
-                    }
-                }
-
-                infos[id].r++;
-#endif
-            }
-
-            tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
-
-            ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
-        }
-    }
-
-    const int64_t t_end = ggml_time_us();
-
-    LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
-}
-
 // copy the KV cache to the host memory and reshuffle the cells to the beginning of the cache
 // this way we eliminate any empty segments that may have been left by previous KV cache operations
 //
@@ -8412,14 +8171,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         }
     }
 
-    // compress the KV cache data if needed
-    if (lctx.kv_self.compress_delta >= 0) {
-        llama_kv_cache_compress_internal(lctx);
-
-        lctx.kv_self.compress_delta = -1;
-        lctx.kv_self.do_defrag = false;
-    }
-
     // defragment the KV cache if needed
     if (lctx.kv_self.do_defrag) {
         llama_kv_cache_defrag_internal(lctx);
@@ -12496,10 +12247,6 @@ llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id se
     return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
 }
 
-void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) {
-    llama_kv_cache_compress(ctx->kv_self, delta);
-}
-
 void llama_kv_cache_defrag(struct llama_context * ctx) {
     llama_kv_cache_defrag(ctx->kv_self);
 }
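For orientation, the removed llama_kv_cache_compress_internal() above does two things per layer: it groups non-empty cells of the same sequence whose positions differ by at most compress_delta, then (in its #if 1 path) replaces each group's K/V rows with the average of the group's rows. The standalone C++ sketch below illustrates only that merge rule, under a deliberately simplified and hypothetical layout (single sequence, one layer, K rows as plain floats in host memory); the real code operates on the quantized per-layer k_l/v_l tensors through the ggml type traits, and the names n_embd, k, k_new here are illustrative, not llama.cpp identifiers.

// sketch of the "merge by averaging" rule from the removed code (simplified, hypothetical layout)
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

struct cell { int32_t pos; }; // stand-in for llama_kv_cell (single sequence assumed)

int main() {
    const uint32_t n_embd         = 4; // row width (n_embd_k_gqa in the real code)
    const int32_t  compress_delta = 1; // merge cells whose positions differ by <= delta

    std::vector<cell>  cells = { {0}, {1}, {2}, {5} };
    std::vector<float> k(cells.size()*n_embd, 0.0f);
    for (size_t i = 0; i < cells.size(); ++i) {
        for (uint32_t j = 0; j < n_embd; ++j) {
            k[i*n_embd + j] = float(i); // dummy K data
        }
    }

    // pass 1: assign each source cell i to a destination cell id
    std::vector<uint32_t> dst(cells.size());
    std::vector<uint32_t> cnt;
    std::vector<bool>     merged(cells.size(), false);

    uint32_t id = 0;
    for (size_t i0 = 0; i0 < cells.size(); ++i0) {
        if (merged[i0]) {
            continue;
        }
        cnt.push_back(0);
        for (size_t i1 = i0; i1 < cells.size(); ++i1) {
            if (!merged[i1] && std::abs(cells[i1].pos - cells[i0].pos) <= compress_delta) {
                merged[i1] = true;
                dst[i1]    = id;
                cnt[id]++;
            }
        }
        id++;
    }

    // pass 2: average all source rows that map to the same destination row
    std::vector<float> k_new(id*n_embd, 0.0f);
    for (size_t i = 0; i < cells.size(); ++i) {
        const float scale = 1.0f/float(cnt[dst[i]]);
        for (uint32_t j = 0; j < n_embd; ++j) {
            k_new[dst[i]*n_embd + j] += k[i*n_embd + j]*scale;
        }
    }

    std::printf("compressed %zu cells into %u rows\n", cells.size(), id);
    return 0;
}

With the values above (positions 0, 1, 2, 5 and delta 1), cells 0 and 1 collapse into one averaged row while cells 2 and 5 stay separate, which is the same lossy, defragmenting effect the removed per-layer K/V update performed on the actual cache tensors.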