
Commit 8086b61

kv_cache : minor

1 parent 254de88 commit 8086b61

3 files changed: +47 -27 lines

src/llama-kv-cache.cpp

Lines changed: 31 additions & 7 deletions

@@ -73,17 +73,22 @@ bool llama_kv_cache::init(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);
+        const char * dev_name = "CPU";
 
         ggml_backend_buffer_type_t buft;
         if (offload) {
             auto * dev = model.dev_layer(i);
             buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
         } else {
             buft = ggml_backend_cpu_buffer_type();
         }
-        ggml_context * ctx = ctx_for_buft(buft);
 
+        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
+                i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+
+        ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
             LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
             return false;

@@ -134,14 +139,13 @@ size_t llama_kv_cache::total_size() const {
     return size;
 }
 
-// TODO: better data structures to reduce the cost of this operation
-llama_pos llama_kv_cache::max_pos() const {
-    llama_pos max_pos = -1;
+llama_pos llama_kv_cache::pos_max() const {
+    llama_pos pos_max = -1;
     for (const auto & cell : cells) {
-        max_pos = std::max(max_pos, cell.pos);
+        pos_max = std::max(pos_max, cell.pos);
     }
 
-    return max_pos;
+    return pos_max;
 }
 
 void llama_kv_cache::clear() {

@@ -672,6 +676,26 @@ uint32_t llama_kv_cache::cell_max() const {
     return 0;
 }
 
+size_t llama_kv_cache::size_k_bytes() const {
+    size_t size_k_bytes = 0;
+
+    for (const auto & k : k_l) {
+        size_k_bytes += ggml_nbytes(k);
+    }
+
+    return size_k_bytes;
+}
+
+size_t llama_kv_cache::size_v_bytes() const {
+    size_t size_v_bytes = 0;
+
+    for (const auto & v : v_l) {
+        size_v_bytes += ggml_nbytes(v);
+    }
+
+    return size_v_bytes;
+}
+
 void llama_kv_cache_clear(llama_kv_cache * kv) {
     kv->clear();
 }
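
The two new accessors are just sums of ggml_nbytes() over the per-layer K and V tensors, so callers no longer need to walk k_l/v_l themselves (the llama.cpp hunk further down switches to them). Below is a self-contained sketch of that pattern, using stand-in types rather than the real ggml tensors so it compiles on its own; it is an illustration, not the actual llama.cpp code.

// Sketch only: tensor_stub/nbytes_stub stand in for ggml_tensor/ggml_nbytes.
#include <cstddef>
#include <cstdio>
#include <vector>

struct tensor_stub { size_t nbytes; };                    // stands in for ggml_tensor
static size_t nbytes_stub(const tensor_stub & t) { return t.nbytes; }

struct kv_cache_stub {
    std::vector<tensor_stub> k_l;                         // per-layer K tensors
    std::vector<tensor_stub> v_l;                         // per-layer V tensors

    size_t size_k_bytes() const {
        size_t size = 0;
        for (const auto & k : k_l) { size += nbytes_stub(k); }
        return size;
    }

    size_t size_v_bytes() const {
        size_t size = 0;
        for (const auto & v : v_l) { size += nbytes_stub(v); }
        return size;
    }
};

int main() {
    // two layers, 8 MiB of K and 8 MiB of V per layer (made-up numbers)
    kv_cache_stub kv;
    kv.k_l = { {8u << 20}, {8u << 20} };
    kv.v_l = { {8u << 20}, {8u << 20} };

    const size_t size_k = kv.size_k_bytes();
    const size_t size_v = kv.size_v_bytes();

    std::printf("KV self size = %7.2f MiB, K: %7.2f MiB, V: %7.2f MiB\n",
            (size_k + size_v) / (1024.0f * 1024.0f),
            size_k            / (1024.0f * 1024.0f),
            size_v            / (1024.0f * 1024.0f));
}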

src/llama-kv-cache.h

Lines changed: 11 additions & 7 deletions

@@ -61,17 +61,11 @@ struct llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;
 
-    ggml_type type_k = GGML_TYPE_F16;
-    ggml_type type_v = GGML_TYPE_F16;
-
     std::vector<llama_kv_cell> cells;
 
    std::vector<struct ggml_tensor *> k_l; // per layer
    std::vector<struct ggml_tensor *> v_l;
 
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
     // TODO: become constructor
     bool init(
             const llama_model & model,

@@ -86,7 +80,7 @@
     size_t total_size() const;
 
     // TODO: better data structures to reduce the cost of this operation
-    llama_pos max_pos() const;
+    llama_pos pos_max() const;
 
     void clear();
 

@@ -112,6 +106,16 @@
 
     // find how many cells are currently in use
     uint32_t cell_max() const;
+
+    size_t size_k_bytes() const;
+    size_t size_v_bytes() const;
+
+private:
+    ggml_type type_k = GGML_TYPE_F16;
+    ggml_type type_v = GGML_TYPE_F16;
+
+    std::vector<ggml_context_ptr> ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
 };
 
 //
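
For context on the max_pos → pos_max rename and its call sites below: unused cells in the cache keep pos == -1, so an empty cache reports -1 and pos_max() + 1 is a usable first position for a batch that carries no explicit positions. The following is a minimal self-contained sketch with a simplified cell type (not the real llama_kv_cell), just to show the semantics.

// Sketch only: cell_stub stands in for llama_kv_cell.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos = int32_t;

struct cell_stub {
    llama_pos pos = -1;                  // -1 marks an unused cell
};

// same linear scan as llama_kv_cache::pos_max() in the .cpp diff above
static llama_pos pos_max(const std::vector<cell_stub> & cells) {
    llama_pos pos_max = -1;
    for (const auto & cell : cells) {
        pos_max = std::max(pos_max, cell.pos);
    }
    return pos_max;
}

int main() {
    std::vector<cell_stub> cells(8);                                  // fresh cache
    std::printf("empty cache: pos_max = %d\n", pos_max(cells));       // -1

    cells[0].pos = 0;
    cells[1].pos = 1;
    cells[2].pos = 2;                                                 // three tokens stored
    std::printf("after 3 tokens: pos_max = %d\n", pos_max(cells));    // 2
    std::printf("implicit next position: %d\n", pos_max(cells) + 1);  // 3
}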

src/llama.cpp

Lines changed: 5 additions & 13 deletions

@@ -1973,7 +1973,7 @@ struct llm_build_context {
            if (il == n_layer - 1) {
                // skip computing output for unused tokens
                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-               cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+               cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }
 

@@ -8456,7 +8456,7 @@ static int llama_decode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
     const uint32_t n_tokens_all = batch.n_tokens;

@@ -8792,7 +8792,7 @@ static int llama_encode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
     const uint32_t n_tokens = batch.n_tokens;

@@ -9706,16 +9706,8 @@ struct llama_context * llama_init_from_model(
        }
 
        {
-           size_t memory_size_k = 0;
-           size_t memory_size_v = 0;
-
-           for (auto & k : ctx->kv_self.k_l) {
-               memory_size_k += ggml_nbytes(k);
-           }
-
-           for (auto & v : ctx->kv_self.v_l) {
-               memory_size_v += ggml_nbytes(v);
-           }
+           const size_t memory_size_k = ctx->kv_self.size_k_bytes();
+           const size_t memory_size_v = ctx->kv_self.size_v_bytes();
 
            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
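
As a rough sanity check on what this log line reports: for the default F16 cache, each layer holds kv_size × n_embd_k_gqa K elements and kv_size × n_embd_v_gqa V elements at 2 bytes apiece. The sketch below uses assumed example dimensions (32 layers, a 4096-token cache, 1024-wide GQA K/V); the numbers are illustrative and not taken from the commit.

// Back-of-envelope sketch with assumed dimensions (not from this commit):
// an F16 KV cache stores 2 bytes per element for both K and V.
#include <cstdio>

int main() {
    const unsigned long long n_layer      = 32;    // assumed
    const unsigned long long kv_size      = 4096;  // cache length in tokens, assumed
    const unsigned long long n_embd_k_gqa = 1024;  // e.g. 8 KV heads x 128 head dim, assumed
    const unsigned long long n_embd_v_gqa = 1024;
    const unsigned long long bytes_f16    = 2;

    const unsigned long long size_k = n_layer * kv_size * n_embd_k_gqa * bytes_f16;
    const unsigned long long size_v = n_layer * kv_size * n_embd_v_gqa * bytes_f16;

    std::printf("K: %.2f MiB, V: %.2f MiB, total: %.2f MiB\n",
            size_k / (1024.0 * 1024.0),
            size_v / (1024.0 * 1024.0),
            (size_k + size_v) / (1024.0 * 1024.0));
    // prints: K: 256.00 MiB, V: 256.00 MiB, total: 512.00 MiB
}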
