Skip to content

Commit 65323bc

Browse files
committed
llama : defragment via non-overlapping moves
1 parent 2d7203b commit 65323bc

File tree

1 file changed

+58
-13
lines changed

1 file changed

+58
-13
lines changed

llama.cpp

Lines changed: 58 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -8028,7 +8028,7 @@ static int llama_decode_internal(
80288028
}
80298029

80308030
// copy the KV cache to the host memory and reshuffle the cells to the beginning of the cache
8031-
// this way we eliminate any empty segments that may have been left by previous KV cache operations
8031+
// this way we eliminate any empty holes that may have been left by previous KV cache operations
80328032
//
80338033
// TODO: optimizations are possible:
80348034
// - multiple threads
@@ -8045,36 +8045,81 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
80458045
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
80468046
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
80478047
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
8048+
const uint32_t n_used = kv_self.used;
80488049

80498050
const uint32_t kv_size = kv_self.size;
80508051

8052+
assert(n_used <= n_kv);
8053+
80518054
const int64_t t_start = ggml_time_us();
80528055

80538056
std::vector<uint8_t> buf_k;
80548057
std::vector<uint8_t> buf_v;
80558058

8056-
// the destination cell in the new KV cache
8057-
uint32_t id = 0;
8058-
80598059
// number of cells moved
80608060
uint32_t n_moves = 0;
80618061

80628062
// determine which KV cells to move where
80638063
std::vector<uint32_t> ids(n_kv, n_kv);
80648064

8065-
for (uint32_t i0 = 0; i0 < n_kv; ++i0) {
8065+
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
80668066
const auto & cell0 = kv_self.cells[i0];
80678067

80688068
if (!cell0.is_empty()) {
8069-
ids[i0] = id;
8069+
ids[i0] = i0;
8070+
8071+
continue;
8072+
}
8073+
8074+
// found a hole - fill it with data from the end of the cache
8075+
8076+
// determine the size of the hole
8077+
uint32_t nh = 1;
8078+
while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
8079+
nh++;
8080+
}
8081+
8082+
// starting from the end, find nh non-empty cells
8083+
uint32_t nf = 0;
8084+
uint32_t is = n_kv - 1;
8085+
for (; is > i0; --is) {
8086+
const auto & cell1 = kv_self.cells[is];
8087+
8088+
if (cell1.is_empty() || ids[is] != n_kv) {
8089+
continue;
8090+
}
80708091

8071-
if (i0 != id) {
8072-
kv_self.cells[id] = cell0;
8073-
n_moves++;
8092+
// non-empty cell which is not yet moved
8093+
nf++;
8094+
if (nf == nh) {
8095+
break;
8096+
}
8097+
}
8098+
8099+
GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
8100+
8101+
nf = 0;
8102+
8103+
// go back and move the nf cells to the hole
8104+
for (uint32_t i1 = is; i1 < n_kv; ++i1) {
8105+
const auto & cell1 = kv_self.cells[i1];
8106+
8107+
if (cell1.is_empty() || ids[i1] != n_kv) {
8108+
continue;
80748109
}
80758110

8076-
id++;
8111+
ids[i1] = i0 + nf;
8112+
8113+
// move the cell meta data
8114+
kv_self.cells[i0 + nf] = cell1;
8115+
8116+
n_moves++;
8117+
nf++;
80778118
}
8119+
8120+
LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh);
8121+
8122+
i0 += nh - 1;
80788123
}
80798124

80808125
if (n_moves == 0) {
@@ -8083,11 +8128,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
80838128

80848129
LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
80858130

8086-
kv_self.head = id;
8087-
kv_self.used = id;
8131+
kv_self.head = n_used;
8132+
kv_self.used = n_used;
80888133

80898134
// zero the rest of the cells
8090-
for (uint32_t i = id; i < n_kv; ++i) {
8135+
for (uint32_t i = n_used; i < n_kv; ++i) {
80918136
kv_self.cells[i] = llama_kv_cell();
80928137
}
80938138

0 commit comments

Comments
 (0)