
Commit 3e8eceb

context : llama_kv_cache -> llama_memory_i

ggml-ci

1 parent: cfe64ce

File tree: 4 files changed (+61, -51 lines)

  src/llama-context.cpp
  src/llama-context.h
  src/llama-kv-cache.cpp
  src/llama-memory.h
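
In short, this commit generalizes the memory member of llama_context: instead of owning a std::unique_ptr<llama_kv_cache> directly, the context now owns a std::unique_ptr<llama_memory_i> and downcasts to the concrete cache type only at the call sites that need KV-cache-specific behaviour. A minimal sketch of that ownership pattern, using simplified placeholder types rather than the real llama.cpp declarations:

    #include <memory>

    // toy stand-ins for the real llama.cpp types
    struct llama_memory_i {
        virtual ~llama_memory_i() = default; // allows deletion through the base pointer
        virtual void clear() = 0;
    };

    struct llama_kv_cache : llama_memory_i {
        void clear() override {}
        void set_full() {} // cache-specific operation, not part of llama_memory_i
    };

    struct context_sketch {
        // the context now owns generic memory, not a KV cache directly
        std::unique_ptr<llama_memory_i> memory;

        void simulate_full_cache() {
            // downcast at the use site, mirroring the pattern introduced by this commit
            auto * kv_self = static_cast<llama_kv_cache *>(memory.get());
            kv_self->set_full();
        }
    };

    int main() {
        context_sketch ctx;
        ctx.memory = std::make_unique<llama_kv_cache>();
        ctx.simulate_full_cache();
        return 0;
    }

The static_cast is safe here only because the context itself created the object and knows its concrete type, which is exactly the assumption the commit relies on.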

src/llama-context.cpp

Lines changed: 38 additions & 50 deletions
@@ -177,65 +177,35 @@ llama_context::llama_context(
     }
 
     // init the memory module
-    // TODO: for now, always create a unified KV cache
     if (!hparams.vocab_only) {
-        uint32_t kv_size = 0;
-        ggml_type type_k = params.type_k;
-        ggml_type type_v = params.type_v;
+        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
 
         if (!llama_model_is_recurrent(&model)) {
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
             cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams));
 
             LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
-            kv_size = cparams.n_ctx;
-            type_k = params.type_k;
-            type_v = params.type_v;
-
             llama_memory_params params_mem = {
-                /*.type_k      =*/ type_k,
-                /*.type_v      =*/ type_v,
+                /*.type_k      =*/ params.type_k,
+                /*.type_v      =*/ params.type_v,
                 /*.v_trans     =*/ !cparams.flash_attn,
                 /*.offload_kqv =*/ cparams.offload_kqv,
-                /*.kv_size     =*/ kv_size,
+                /*.kv_size     =*/ cparams.n_ctx,
             };
 
-            auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory(params_mem));
-
-            kv_self.reset(kv);
+            memory.reset(model.create_memory(params_mem));
         } else {
-            // Mamba needs at least as many KV cells as there are sequences kept at any time
-            kv_size = std::max((uint32_t) 1, params.n_seq_max);
-            // it's probably best to keep as much precision as possible for the states
-            type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-            type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-
             llama_memory_params params_mem = {
-                /*.type_k      =*/ type_k,
-                /*.type_v      =*/ type_v,
+                /*.type_k      =*/ GGML_TYPE_F32, // required by ggml_ssm_conv for Mamba's conv_states
+                /*.type_v      =*/ GGML_TYPE_F32, // required by ggml_ssm_scan for Mamba's ssm_states
                 /*.v_trans     =*/ false, // unused
-                /*.offload_kqv =*/ params.offload_kqv,
-                /*.kv_size     =*/ kv_size,
+                /*.offload_kqv =*/ cparams.offload_kqv,
+                /*.kv_size     =*/ std::max((uint32_t) 1, params.n_seq_max), // Mamba needs at least as many KV cells as there are sequences kept at any time
             };
 
-            auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory(params_mem));
-
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
-            kv_self.reset(kv);
+            memory.reset(model.create_memory(params_mem));
        }
 
-        {
-            const size_t memory_size_k = kv_self->size_k_bytes();
-            const size_t memory_size_v = kv_self->size_v_bytes();
-
-            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                    (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                    ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                    ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-        }
     }
 
     // init backends
@@ -326,6 +296,8 @@ llama_context::llama_context(
         int n_nodes_tg = -1;
 
         // simulate full KV cache
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
         kv_self->set_full();
 
         cross.v_embd.clear();
@@ -477,11 +449,13 @@ uint32_t llama_context::n_threads_batch() const {
 }
 
 llama_kv_cache * llama_context::get_kv_self() {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
 }
 
 const llama_kv_cache * llama_context::get_kv_self() const {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
 }
 
 ggml_tensor * llama_context::build_rope_shift(
@@ -578,7 +552,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
 
     //GGML_ASSERT(kv_self->size == n_ctx);
 
-    const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());
+    const auto * kv = static_cast<const llama_kv_cache_unified *>(memory.get());
 
     auto inp = std::make_unique<llm_graph_input_k_shift>(kv);
 
@@ -620,7 +594,7 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+    auto * kv = static_cast<llama_kv_cache_unified *>(memory.get());
 
     const auto & hparams = model.hparams;
 
@@ -766,6 +740,8 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
 void llama_context::kv_self_update() {
     bool need_reserve = false;
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     if (kv_self->get_has_shift()) {
         if (!kv_self->get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -791,7 +767,7 @@ void llama_context::kv_self_update() {
     }
 
     {
-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
 
         kv->has_shift = false;
 
@@ -805,7 +781,7 @@
     if (kv_self->get_do_defrag()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
 
         if (kv->defrag_prepare(graph_max_nodes())) {
             ggml_backend_sched_reset(sched.get());
@@ -1054,6 +1030,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         return -1;
     }
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
@@ -1219,6 +1197,8 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
    // temporary allocate memory for the input batch if needed
    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
@@ -1233,7 +1213,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd = hparams.n_embd;
 
-    llama_kv_cache_guard kv_guard(kv_self.get());
+    llama_kv_cache_guard kv_guard(kv_self);
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
@@ -1337,7 +1317,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         if (!is_recurrent) {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
 
             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
@@ -1489,7 +1469,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     // decide if we need to defrag the kv cache
     if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
 
         // - do not defrag small contexts (i.e. < 2048 tokens)
         // - count the padding towards the number of used tokens
@@ -1662,7 +1642,7 @@ llm_graph_result_ptr llama_context::graph_build(
         /*.backend_cpu =*/ backend_cpu,
         /*.cvec        =*/ &cvec,
         /*.loras       =*/ &loras,
-        /*.memory      =*/ kv_self.get(),
+        /*.memory      =*/ memory.get(),
         /*.cross       =*/ &cross,
         /*.n_outputs   =*/ n_outputs,
         /*.cb          =*/ graph_get_cb(),
@@ -2121,6 +2101,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     }
 
     LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_write(io);
 
     return io.n_bytes();
@@ -2205,6 +2187,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     }
 
     LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_read(io);
 
     return io.n_bytes();
@@ -2213,6 +2197,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_write(io, seq_id);
 
     return io.n_bytes();
@@ -2221,6 +2207,8 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_read(io, seq_id);
 
     return io.n_bytes();

src/llama-context.h

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ struct llama_context {
 
     llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
-    std::unique_ptr<llama_kv_cache> kv_self;
+    std::unique_ptr<llama_memory_i> memory;
 
     // TODO: remove
     bool logits_all = false;

src/llama-kv-cache.cpp

Lines changed: 20 additions & 0 deletions
@@ -100,6 +100,16 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
+
+    {
+        const size_t memory_size_k = size_k_bytes();
+        const size_t memory_size_v = size_v_bytes();
+
+        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+    }
 }
 
 int32_t llama_kv_cache_unified::get_n_tokens() const {
@@ -1078,6 +1088,16 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
+
+    {
+        const size_t memory_size_k = size_k_bytes();
+        const size_t memory_size_v = size_v_bytes();
+
+        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+    }
 }
 
 int32_t llama_kv_cache_recurrent::get_n_tokens() const {

src/llama-memory.h

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ struct llama_memory_params {
 // the KV cache is a type of LLM memory, but there can be other types
 class llama_memory_i {
 public:
+    virtual ~llama_memory_i() = default;
+
     virtual void clear() = 0;
     virtual void defrag() = 0;
 
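The two added lines above are what make the ownership change in llama-context.h safe: the context now destroys the concrete cache through a llama_memory_i pointer, and deleting a derived object through a base-class pointer whose destructor is not virtual is undefined behavior. A minimal illustration of the failure mode being avoided, using hypothetical toy types rather than the real ones:

    #include <cstdio>
    #include <memory>

    struct base_plain {
        ~base_plain() { std::puts("base dtor"); } // not virtual
    };
    struct derived_plain : base_plain {
        ~derived_plain() { std::puts("derived dtor"); } // not guaranteed to run below
    };

    struct base_virtual {
        virtual ~base_virtual() { std::puts("base dtor"); } // virtual, like llama_memory_i now
    };
    struct derived_virtual : base_virtual {
        ~derived_virtual() override { std::puts("derived dtor"); } // runs correctly
    };

    int main() {
        // undefined behavior: the derived destructor may never run
        std::unique_ptr<base_plain> bad(new derived_plain());
        bad.reset();

        // well-defined: deletion goes through the virtual destructor
        std::unique_ptr<base_virtual> good(new derived_virtual());
        good.reset();
        return 0;
    }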