Commit b0d6b66 (parent: 6eaea63)

llama : kv cache

ggml-ci

File tree: 8 files changed (+826 / -703 lines)


include/llama.h

Lines changed: 8 additions & 0 deletions
@@ -556,6 +556,8 @@ extern "C" {
     // KV cache
     //

+    // TODO: remove llama_kv_cache_view_* API
+
     // Information associated with an individual cell in the KV cache view.
     struct llama_kv_cache_view_cell {
         // The position for this cell. Takes KV cache shifts into account.
@@ -602,8 +604,11 @@ extern "C" {
     LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);

     // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
+    // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
     LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);

+    ///
+
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
     LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
@@ -673,6 +678,9 @@ extern "C" {
            struct llama_context * ctx,
                    llama_seq_id   seq_id);

+    // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
+    //       how to avoid this?
+
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
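
For context, here is a minimal debugging sketch (not part of this commit) of how the llama_kv_cache_view API above is used. It assumes a valid llama_context created elsewhere, relies only on the llama_kv_cache_view_init/update/free entry points declared in llama.h, and keeps the current (ctx, view) argument order that the TODO above proposes to swap:

    #include "llama.h"
    #include <cstdio>

    // dump a coarse occupancy summary of the KV cache (debug only)
    static void dump_kv_cache_view(const struct llama_context * ctx) {
        // track at most one sequence per cell; enough for an occupancy dump
        struct llama_kv_cache_view view = llama_kv_cache_view_init(ctx, 1);

        // current signature: (ctx, view); the TODO above suggests (view, ctx)
        llama_kv_cache_view_update(ctx, &view);

        printf("kv cells: %d total, %d used, %d tokens\n",
               view.n_cells, view.used_cells, view.token_count);

        llama_kv_cache_view_free(&view);
    }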

src/llama-context.cpp

Lines changed: 0 additions & 1 deletion
@@ -106,7 +106,6 @@ struct llama_data_write {
     }

     void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
-
         for (const auto & range : cell_ranges) {
             for (uint32_t i = range.first; i < range.second; ++i) {
                 const auto & cell = kv_self.cells[i];
src/llama-context.h

Lines changed: 1 addition & 138 deletions
@@ -2,6 +2,7 @@

 #include "llama-impl.h"
 #include "llama-batch.h"
+#include "llama-cparams.h"
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 #include "llama-adapter.h"
@@ -13,38 +14,6 @@
 #include <vector>
 #include <set>

-struct llama_cparams {
-    uint32_t n_ctx;           // context size used during inference
-    uint32_t n_batch;
-    uint32_t n_ubatch;
-    uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
-
-    float rope_freq_base;
-    float rope_freq_scale;
-
-    uint32_t n_ctx_orig_yarn;
-    // These hyperparameters are not exposed in GGUF, because all
-    // existing YaRN models use the same values for them.
-    float yarn_ext_factor;
-    float yarn_attn_factor;
-    float yarn_beta_fast;
-    float yarn_beta_slow;
-    float defrag_thold;
-
-    bool embeddings;
-    bool causal_attn;
-    bool offload_kqv;
-    bool flash_attn;
-    bool no_perf;
-
-    enum llama_pooling_type pooling_type;
-
-    ggml_backend_sched_eval_callback cb_eval;
-    void * cb_eval_user_data;
-};
-
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
@@ -140,112 +109,6 @@ struct llama_context {
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
 };

-static bool llama_kv_cache_init(
-        struct llama_kv_cache & cache,
-           const llama_context * ctx,
-                     ggml_type   type_k,
-                     ggml_type   type_v,
-                      uint32_t   kv_size,
-                          bool   offload) {
-    const llama_model & model = ctx->model;
-    const llama_cparams & cparams = ctx->cparams;
-
-    const struct llama_hparams & hparams = model.hparams;
-
-    const int32_t n_layer = hparams.n_layer;
-
-    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
-
-    cache.has_shift = false;
-
-    cache.recurrent = llama_model_is_recurrent(&model);
-    cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
-
-    cache.head = 0;
-    cache.size = kv_size;
-    cache.used = 0;
-
-    cache.type_k = type_k;
-    cache.type_v = type_v;
-
-    cache.cells.clear();
-    cache.cells.resize(kv_size);
-
-    // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
-    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            struct ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-            ggml_context * ctx = ggml_init(params);
-            if (!ctx) {
-                return nullptr;
-            }
-            ctx_map[buft] = ctx;
-            cache.ctxs.emplace_back(ctx);
-            return ctx;
-        }
-        return it->second;
-    };
-
-    cache.k_l.reserve(n_layer);
-    cache.v_l.reserve(n_layer);
-
-    for (int i = 0; i < n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
-
-        LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);
-
-        ggml_backend_buffer_type_t buft;
-        if (offload) {
-            auto * dev = model.dev_layer.at(i).dev;
-            buft = ggml_backend_dev_buffer_type(dev);
-        } else {
-            buft = ggml_backend_cpu_buffer_type();
-        }
-        ggml_context * ctx = ctx_for_buft(buft);
-
-        if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
-            return false;
-        }
-
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
-        ggml_format_name(k, "cache_k_l%d", i);
-        ggml_format_name(v, "cache_v_l%d", i);
-        cache.k_l.push_back(k);
-        cache.v_l.push_back(v);
-    }
-
-    // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-        if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
-            return false;
-        }
-        ggml_backend_buffer_clear(buf, 0);
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        cache.bufs.emplace_back(buf);
-    }
-
-    return true;
-}
-
-static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
-
 // Make sure enough space is available for outputs.
 // Returns max number of outputs for which space was reserved.
 static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
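
As a side note on the helpers being moved out of llama-context.h: a small sketch (not from this commit) of how llama_kv_cache_get_padding is typically consumed when sizing the cache. GGML_PAD is the rounding macro from ggml.h; the header the helper lands in after this refactor (assumed here to be llama-kv-cache.h) and the exact call site are assumptions.

    #include "ggml.h"
    #include "llama-cparams.h"
    #include "llama-kv-cache.h"   // assumed new home of llama_kv_cache_get_padding

    // round the requested context size up to the KV-cache padding:
    // 256 cells when flash attention is enabled, 32 otherwise
    static uint32_t padded_kv_size(const llama_cparams & cparams) {
        const uint32_t padding = llama_kv_cache_get_padding(cparams);
        return GGML_PAD(cparams.n_ctx, padding);
    }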

src/llama-cparams.cpp

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+#include "llama-cparams.h"

src/llama-cparams.h

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+#pragma once
+
+#include "llama.h"
+
+#include <cstdint>
+
+struct llama_cparams {
+    uint32_t n_ctx;           // context size used during inference
+    uint32_t n_batch;
+    uint32_t n_ubatch;
+    uint32_t n_seq_max;
+    int      n_threads;       // number of threads to use for generation
+    int      n_threads_batch; // number of threads to use for batch processing
+
+    float rope_freq_base;
+    float rope_freq_scale;
+
+    uint32_t n_ctx_orig_yarn;
+    // These hyperparameters are not exposed in GGUF, because all
+    // existing YaRN models use the same values for them.
+    float yarn_ext_factor;
+    float yarn_attn_factor;
+    float yarn_beta_fast;
+    float yarn_beta_slow;
+    float defrag_thold;
+
+    bool embeddings;
+    bool causal_attn;
+    bool offload_kqv;
+    bool flash_attn;
+    bool no_perf;
+
+    enum llama_pooling_type pooling_type;
+
+    ggml_backend_sched_eval_callback cb_eval;
+    void * cb_eval_user_data;
+};
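
With llama_cparams now in its own header, other translation units can construct it without pulling in llama-context.h. Below is a partial sketch (not from this commit) of filling it from the public llama_context_params declared in llama.h; the helper name is hypothetical, and derived fields (n_ctx_orig_yarn, causal_attn, the YaRN parameters, defrag_thold) are left out because in the real context constructor they also depend on model hyperparameters.

    #include "llama-cparams.h"

    // hypothetical helper: copy the directly-mapped fields from the public
    // llama_context_params into the internal llama_cparams
    static llama_cparams make_cparams(const llama_context_params & params) {
        llama_cparams cparams = {};

        cparams.n_ctx           = params.n_ctx;
        cparams.n_batch         = params.n_batch;
        cparams.n_ubatch        = params.n_ubatch;
        cparams.n_seq_max       = params.n_seq_max;
        cparams.n_threads       = params.n_threads;
        cparams.n_threads_batch = params.n_threads_batch;

        cparams.rope_freq_base  = params.rope_freq_base;
        cparams.rope_freq_scale = params.rope_freq_scale;

        cparams.embeddings      = params.embeddings;
        cparams.offload_kqv     = params.offload_kqv;
        cparams.flash_attn      = params.flash_attn;
        cparams.no_perf         = params.no_perf;

        cparams.pooling_type    = params.pooling_type;

        cparams.cb_eval           = params.cb_eval;
        cparams.cb_eval_user_data = params.cb_eval_user_data;

        return cparams;
    }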
