Skip to content

Commit 2d7203b

Browse files
committed
Commit title: llama : remove llama_kv_cache_compress
Commit body: will add in a separate PR — ggml-ci
1 parent: 1b6aeb8 · commit: 2d7203b

File tree

3 files changed: +0 additions, −262 deletions

examples/passkey/passkey.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ int main(int argc, char ** argv) {
148148

149149
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
150150
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
151-
llama_kv_cache_compress(ctx, 0);
152151
llama_kv_cache_update (ctx);
153152

154153
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;

llama.cpp

Lines changed: 0 additions & 253 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,9 +1737,6 @@ struct llama_kv_cache {
17371737
ggml_type type_k = GGML_TYPE_F16;
17381738
ggml_type type_v = GGML_TYPE_F16;
17391739

1740-
// if non-negative, compress data on next update
1741-
llama_pos compress_delta = -1;
1742-
17431740
std::vector<llama_kv_cell> cells;
17441741

17451742
std::vector<struct ggml_tensor *> k_l; // per layer
@@ -2275,10 +2272,6 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
22752272
return result;
22762273
}
22772274

2278-
static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos delta) {
2279-
cache.compress_delta = delta;
2280-
}
2281-
22822275
static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
22832276
cache.do_defrag = true;
22842277
}
@@ -8034,240 +8027,6 @@ static int llama_decode_internal(
80348027
return 0;
80358028
}
80368029

8037-
// summary:
8038-
//
8039-
// - determine which KV cell pairs (i0, i1) to merge:
8040-
//
8041-
// abs(cell[i0].pos - cell[i1].pos) <= compress_delta
8042-
//
8043-
// - move the KV cache to the host memory for easier manipulation
8044-
// - processing is done layer-by-layer
8045-
// - convert the KV data to F32
8046-
// - merge the KV data (different ways to merge)
8047-
// - convert the KV data back to the original type
8048-
// - move the KV cache back to the device memory
8049-
// - update the KV cache metadata
8050-
//
8051-
// as a side effect, the new KV cache is defragmented
8052-
//
8053-
static void llama_kv_cache_compress_internal(struct llama_context & lctx) {
8054-
auto & kv_self = lctx.kv_self;
8055-
8056-
const auto & hparams = lctx.model.hparams;
8057-
8058-
const uint32_t n_layer = hparams.n_layer;
8059-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
8060-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
8061-
const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
8062-
const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
8063-
const uint32_t n_head_kv = hparams.n_head_kv; GGML_UNUSED(n_head_kv);
8064-
const uint32_t kv_size = kv_self.size;
8065-
8066-
const int64_t t_start = ggml_time_us();
8067-
8068-
std::vector<uint8_t> buf_q;
8069-
8070-
std::vector<float> buf_src_f32;
8071-
std::vector<float> buf_dst_f32;
8072-
8073-
struct c_pair { uint32_t i0, i1; };
8074-
struct c_info { bool merged; uint32_t id, cnt, r; };
8075-
8076-
std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
8077-
8078-
// the destination cell in the new KV cache
8079-
uint32_t id = 0;
8080-
8081-
// number of pairs merged
8082-
uint32_t n_merges = 0;
8083-
8084-
// determine which KV cells to merge
8085-
for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
8086-
const auto & cell0 = kv_self.cells[i0];
8087-
8088-
if (!cell0.is_empty() && !infos[i0].merged) {
8089-
infos[i0] = { true, id, 0, 0 };
8090-
infos[id].cnt = 1;
8091-
8092-
const llama_pos p0 = cell0.pos;
8093-
8094-
for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
8095-
const auto & cell1 = kv_self.cells[i1];
8096-
8097-
if (i0 != i1 && cell0.is_same_seq(cell1)) {
8098-
const llama_pos p1 = cell1.pos;
8099-
8100-
if (std::abs(p0 - p1) <= kv_self.compress_delta) {
8101-
infos[i1] = { true, id, 0, 0 };
8102-
infos[id].cnt++;
8103-
n_merges++;
8104-
}
8105-
}
8106-
}
8107-
8108-
if (i0 != id) {
8109-
kv_self.cells[id] = cell0;
8110-
}
8111-
8112-
id++;
8113-
}
8114-
}
8115-
8116-
kv_self.head = id;
8117-
kv_self.used = id;
8118-
8119-
for (uint32_t i = id; i < kv_size; ++i) {
8120-
kv_self.cells[i] = llama_kv_cell();
8121-
}
8122-
8123-
LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
8124-
8125-
ggml_type_traits_t tt_k;
8126-
ggml_type_traits_t tt_v;
8127-
8128-
tt_k = ggml_internal_get_type_traits(kv_self.type_k);
8129-
tt_v = ggml_internal_get_type_traits(kv_self.type_v);
8130-
8131-
for (uint32_t il = 0; il < n_layer; ++il) {
8132-
for (uint32_t i = 0; i < kv_size; ++i) {
8133-
infos[i].r = 0;
8134-
}
8135-
8136-
// update keys
8137-
{
8138-
const int64_t ne = n_embd_k_gqa*kv_size;
8139-
8140-
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
8141-
8142-
buf_q.resize(k_size);
8143-
8144-
buf_src_f32.resize(ne);
8145-
buf_dst_f32.resize(ne);
8146-
8147-
ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
8148-
8149-
tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
8150-
8151-
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
8152-
8153-
for (uint32_t i = 0; i < kv_size; ++i) {
8154-
if (!infos[i].merged) {
8155-
continue;
8156-
}
8157-
8158-
const uint32_t id = infos[i].id;
8159-
8160-
#if 1
8161-
// merge using averaging
8162-
{
8163-
const float scale = 1.0f/float(infos[id].cnt);
8164-
8165-
const int64_t os = i*n_embd_k_gqa;
8166-
const int64_t od = id*n_embd_k_gqa;
8167-
8168-
for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
8169-
buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
8170-
}
8171-
}
8172-
#else
8173-
// merge separate heads
8174-
{
8175-
for (uint32_t h = 0; h < n_head_kv; ++h) {
8176-
if ((h + il) % infos[id].cnt != infos[id].r) {
8177-
continue;
8178-
}
8179-
8180-
const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k;
8181-
const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
8182-
8183-
for (uint32_t j = 0; j < n_embd_head_k; ++j) {
8184-
buf_dst_f32[od + j] = buf_src_f32[os + j];
8185-
}
8186-
}
8187-
}
8188-
8189-
infos[id].r++;
8190-
#endif
8191-
}
8192-
8193-
tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
8194-
8195-
ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
8196-
}
8197-
8198-
for (uint32_t i = 0; i < kv_size; ++i) {
8199-
infos[i].r = 0;
8200-
}
8201-
8202-
// update values (note: they are transposed)
8203-
{
8204-
const int64_t ne = n_embd_v_gqa*kv_size;
8205-
8206-
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
8207-
8208-
buf_q.resize(v_size);
8209-
8210-
buf_src_f32.resize(ne);
8211-
buf_dst_f32.resize(ne);
8212-
8213-
ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
8214-
8215-
tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
8216-
8217-
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
8218-
8219-
for (uint32_t i = 0; i < kv_size; ++i) {
8220-
if (!infos[i].merged) {
8221-
continue;
8222-
}
8223-
8224-
const uint32_t id = infos[i].id;
8225-
8226-
#if 1
8227-
// merge using averaging
8228-
{
8229-
const float scale = 1.0f/float(infos[id].cnt);
8230-
//printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
8231-
8232-
const int64_t os = i;
8233-
const int64_t od = id;
8234-
8235-
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
8236-
buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
8237-
}
8238-
}
8239-
#else
8240-
// merge separate heads
8241-
{
8242-
for (uint32_t h = 0; h < n_head_kv; ++h) {
8243-
if ((h + il) % infos[id].cnt != infos[id].r) {
8244-
continue;
8245-
}
8246-
8247-
const int64_t os = i;
8248-
const int64_t od = id;
8249-
8250-
for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
8251-
buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
8252-
}
8253-
}
8254-
}
8255-
8256-
infos[id].r++;
8257-
#endif
8258-
}
8259-
8260-
tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
8261-
8262-
ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
8263-
}
8264-
}
8265-
8266-
const int64_t t_end = ggml_time_us();
8267-
8268-
LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
8269-
}
8270-
82718030
// copy the KV cache to the host memory and reshuffle the cells to the beginning of the cache
82728031
// this way we eliminate any empty segments that may have been left by previous KV cache operations
82738032
//
@@ -8412,14 +8171,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
84128171
}
84138172
}
84148173

8415-
// compress the KV cache data if needed
8416-
if (lctx.kv_self.compress_delta >= 0) {
8417-
llama_kv_cache_compress_internal(lctx);
8418-
8419-
lctx.kv_self.compress_delta = -1;
8420-
lctx.kv_self.do_defrag = false;
8421-
}
8422-
84238174
// defragment the KV cache if needed
84248175
if (lctx.kv_self.do_defrag) {
84258176
llama_kv_cache_defrag_internal(lctx);
@@ -12496,10 +12247,6 @@ llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id se
1249612247
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
1249712248
}
1249812249

12499-
void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) {
12500-
llama_kv_cache_compress(ctx->kv_self, delta);
12501-
}
12502-
1250312250
void llama_kv_cache_defrag(struct llama_context * ctx) {
1250412251
llama_kv_cache_defrag(ctx->kv_self);
1250512252
}

llama.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -554,14 +554,6 @@ extern "C" {
554554
struct llama_context * ctx,
555555
llama_seq_id seq_id);
556556

557-
// [EXPERIMENTAL] Compress the data in the KV cache
558-
// This will be applied:
559-
// - lazily on next llama_decode()
560-
// - explicitly with llama_kv_cache_update()
561-
LLAMA_API void llama_kv_cache_compress(
562-
struct llama_context * ctx,
563-
llama_pos delta);
564-
565557
// Defragment the KV cache
566558
// This will be applied:
567559
// - lazily on next llama_decode()

0 commit comments

Comments (0)