Commit b926f40

context : llama_kv_cache -> llama_memory_i
ggml-ci
1 parent 367cf85 commit b926f40

4 files changed: +61 additions, -51 deletions

src/llama-context.cpp

Lines changed: 38 additions & 50 deletions
@@ -177,65 +177,35 @@ llama_context::llama_context(
     }

     // init the memory module
-    // TODO: for now, always create a unified KV cache
     if (!hparams.vocab_only) {
-        uint32_t kv_size = 0;
-        ggml_type type_k = params.type_k;
-        ggml_type type_v = params.type_v;
+        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

         if (!llama_model_is_recurrent(&model)) {
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
             cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams));

             LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

-            kv_size = cparams.n_ctx;
-            type_k = params.type_k;
-            type_v = params.type_v;
-
             llama_memory_params params_mem = {
-                /*.type_k =*/ type_k,
-                /*.type_v =*/ type_v,
+                /*.type_k =*/ params.type_k,
+                /*.type_v =*/ params.type_v,
                 /*.v_trans =*/ !cparams.flash_attn,
                 /*.offload_kqv =*/ cparams.offload_kqv,
-                /*.kv_size =*/ kv_size,
+                /*.kv_size =*/ cparams.n_ctx,
             };

-            auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory(params_mem));
-
-            kv_self.reset(kv);
+            memory.reset(model.create_memory(params_mem));
         } else {
-            // Mamba needs at least as many KV cells as there are sequences kept at any time
-            kv_size = std::max((uint32_t) 1, params.n_seq_max);
-            // it's probably best to keep as much precision as possible for the states
-            type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-            type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-
             llama_memory_params params_mem = {
-                /*.type_k =*/ type_k,
-                /*.type_v =*/ type_v,
+                /*.type_k =*/ GGML_TYPE_F32, // required by ggml_ssm_conv for Mamba's conv_states
+                /*.type_v =*/ GGML_TYPE_F32, // required by ggml_ssm_scan for Mamba's ssm_states
                 /*.v_trans =*/ false, // unused
-                /*.offload_kqv =*/ params.offload_kqv,
-                /*.kv_size =*/ kv_size,
+                /*.offload_kqv =*/ cparams.offload_kqv,
+                /*.kv_size =*/ std::max((uint32_t) 1, params.n_seq_max), // Mamba needs at least as many KV cells as there are sequences kept at any time
             };

-            auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory(params_mem));
-
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
-            kv_self.reset(kv);
+            memory.reset(model.create_memory(params_mem));
         }

-        {
-            const size_t memory_size_k = kv_self->size_k_bytes();
-            const size_t memory_size_v = kv_self->size_v_bytes();
-
-            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                    (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                    ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                    ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-        }
     }

     // init backends
@@ -326,6 +296,8 @@ llama_context::llama_context(
         int n_nodes_tg = -1;

         // simulate full KV cache
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
         kv_self->set_full();

         cross.v_embd.clear();
@@ -477,11 +449,13 @@ uint32_t llama_context::n_threads_batch() const {
 }

 llama_kv_cache * llama_context::get_kv_self() {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
 }

 const llama_kv_cache * llama_context::get_kv_self() const {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
 }

 ggml_tensor * llama_context::build_rope_shift(
@@ -567,7 +541,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(

     //GGML_ASSERT(kv_self->size == n_ctx);

-    const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());
+    const auto * kv = static_cast<const llama_kv_cache_unified *>(memory.get());

     auto inp = std::make_unique<llm_graph_input_k_shift>(kv);

@@ -609,7 +583,7 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

-    auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+    auto * kv = static_cast<llama_kv_cache_unified *>(memory.get());

     const auto & hparams = model.hparams;

@@ -755,6 +729,8 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
 void llama_context::kv_self_update() {
     bool need_reserve = false;

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     if (kv_self->get_has_shift()) {
         if (!kv_self->get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -780,7 +756,7 @@ void llama_context::kv_self_update() {
         }

         {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

             kv->has_shift = false;

@@ -794,7 +770,7 @@ void llama_context::kv_self_update() {
     if (kv_self->get_do_defrag()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

         if (kv->defrag_prepare(graph_max_nodes())) {
             ggml_backend_sched_reset(sched.get());
@@ -1043,6 +1019,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         return -1;
     }

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
@@ -1208,6 +1186,8 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
@@ -1222,7 +1202,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd = hparams.n_embd;

-    llama_kv_cache_guard kv_guard(kv_self.get());
+    llama_kv_cache_guard kv_guard(kv_self);

     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

@@ -1326,7 +1306,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }

         if (!is_recurrent) {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
@@ -1478,7 +1458,7 @@ int llama_context::decode(llama_batch & inp_batch) {

     // decide if we need to defrag the kv cache
     if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

         // - do not defrag small contexts (i.e. < 2048 tokens)
         // - count the padding towards the number of used tokens
@@ -1651,7 +1631,7 @@ llm_graph_result_ptr llama_context::graph_build(
         /*.backend_cpu =*/ backend_cpu,
         /*.cvec =*/ &cvec,
         /*.loras =*/ &loras,
-        /*.memory =*/ kv_self.get(),
+        /*.memory =*/ memory.get(),
         /*.cross =*/ &cross,
         /*.n_outputs =*/ n_outputs,
         /*.cb =*/ graph_get_cb(),
@@ -2110,6 +2090,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     }

     LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_write(io);

     return io.n_bytes();
@@ -2194,6 +2176,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     }

     LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_read(io);

     return io.n_bytes();
@@ -2202,6 +2186,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_write(io, seq_id);

     return io.n_bytes();
@@ -2210,6 +2196,8 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_read(io, seq_id);

     return io.n_bytes();
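
The recurring change in this file is a single access pattern: llama_context now owns its cache through the generic llama_memory_i interface and, wherever KV-specific behavior is still needed, recovers a llama_kv_cache pointer with a static_cast. A minimal, self-contained sketch of that pattern follows; the class bodies are simplified stand-ins for illustration, not the real llama.cpp definitions.

// sketch of the ownership + downcast pattern used above (stand-in types, not the real classes)
#include <cstdio>
#include <memory>

struct llama_memory_i {                         // stand-in for the generic memory interface
    virtual ~llama_memory_i() = default;
    virtual void clear() = 0;
};

struct llama_kv_cache : public llama_memory_i { // stand-in for the KV cache specialization
    void clear() override { n_tokens = 0; }
    void set_full()       { n_tokens = size; }

    int size     = 4096;
    int n_tokens = 0;
};

struct llama_context {
    std::unique_ptr<llama_memory_i> memory;     // owned through the interface, as in llama-context.h

    void kv_self_update() {
        // call sites that need KV-specific behavior downcast, as the diff does;
        // this is only valid while the context is known to have created a KV cache
        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
        kv_self->set_full();
        std::printf("n_tokens = %d\n", kv_self->n_tokens);
    }
};

int main() {
    llama_context ctx;
    ctx.memory = std::make_unique<llama_kv_cache>(); // stand-in for model.create_memory(...)
    ctx.kv_self_update();
    return 0;
}

The static_cast relies on the invariant that, for now, the only llama_memory_i implementations the context creates are KV caches; a different memory type would need a different access path.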

src/llama-context.h

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ struct llama_context {

     llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably

-    std::unique_ptr<llama_kv_cache> kv_self;
+    std::unique_ptr<llama_memory_i> memory;

     // TODO: remove
     bool logits_all = false;

src/llama-kv-cache.cpp

Lines changed: 20 additions & 0 deletions
@@ -100,6 +100,16 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
+
+    {
+        const size_t memory_size_k = size_k_bytes();
+        const size_t memory_size_v = size_v_bytes();
+
+        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+    }
 }

 int32_t llama_kv_cache_unified::get_n_tokens() const {
@@ -1078,6 +1088,16 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
+
+    {
+        const size_t memory_size_k = size_k_bytes();
+        const size_t memory_size_v = size_v_bytes();
+
+        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+    }
 }

 int32_t llama_kv_cache_recurrent::get_n_tokens() const {
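
The logging block added to both constructors reports the cache footprint by converting the byte counts returned by size_k_bytes() and size_v_bytes() into MiB. A small standalone sketch of the same arithmetic, with hypothetical byte counts standing in for the real accessors:

#include <cstdio>
#include <cstddef>

// bytes -> MiB, the conversion used by the added LLAMA_LOG_INFO call
static float to_mib(size_t n_bytes) {
    return (float) n_bytes / (1024.0f * 1024.0f);
}

int main() {
    // hypothetical values in place of size_k_bytes() / size_v_bytes()
    const size_t memory_size_k = 268435456; // 256 MiB
    const size_t memory_size_v = 268435456; // 256 MiB

    std::printf("KV self size = %7.2f MiB, K: %7.2f MiB, V: %7.2f MiB\n",
            to_mib(memory_size_k + memory_size_v),
            to_mib(memory_size_k),
            to_mib(memory_size_v));
    return 0;
}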

src/llama-memory.h

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ struct llama_memory_params {

 // the KV cache is a type of LLM memory, but there can be other types
 class llama_memory_i {
 public:
+    virtual ~llama_memory_i() = default;
+
     virtual void clear() = 0;
     virtual void defrag() = 0;

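
The two added lines are what make the ownership change above safe: llama_context now destroys its cache through a llama_memory_i pointer (std::unique_ptr<llama_memory_i> in llama-context.h), and deleting a derived object through a base-class pointer whose destructor is not virtual is undefined behavior. A minimal sketch of the rule, using stand-in types rather than the real classes:

#include <cstdio>
#include <memory>

struct memory_i {                    // stand-in for llama_memory_i
    virtual ~memory_i() = default;   // the line this commit adds; without it the reset() below would be UB
    virtual void clear() = 0;
};

struct kv_cache : public memory_i {  // stand-in for a cache that owns resources
    ~kv_cache() override { std::printf("kv_cache destroyed, buffers released\n"); }
    void clear() override {}
};

int main() {
    // owning the derived cache through the interface, as llama_context::memory now does
    std::unique_ptr<memory_i> memory = std::make_unique<kv_cache>();
    memory.reset(); // runs ~kv_cache() only because ~memory_i() is virtual
    return 0;
}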