
Commit eae74d1

kv-cache : init -> contructor + add llama_memory_params
ggml-ci
1 parent b018bb9 commit eae74d1

6 files changed: +130 additions, -153 deletions

src/llama-context.cpp

Lines changed: 22 additions & 14 deletions
@@ -184,39 +184,47 @@ llama_context::llama_context(
     ggml_type type_v = params.type_v;

     if (!llama_model_is_recurrent(&model)) {
-        //kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
-        auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory());
-
         LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

-        cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv->get_padding(cparams));
+        cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams));

         LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

         kv_size = cparams.n_ctx;
         type_k = params.type_k;
         type_v = params.type_v;

-        kv_self.reset(kv);
-    } else {
-        auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory());
+        llama_memory_params params_mem = {
+            /*.type_k      =*/ type_k,
+            /*.type_v      =*/ type_v,
+            /*.v_trans     =*/ !cparams.flash_attn,
+            /*.offload_kqv =*/ cparams.offload_kqv,
+            /*.kv_size     =*/ kv_size,
+        };

-        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+        auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory(params_mem));

+        kv_self.reset(kv);
+    } else {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
         kv_size = std::max((uint32_t) 1, params.n_seq_max);
         // it's probably best to keep as much precision as possible for the states
         type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
         type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states

-        kv_self.reset(kv);
-    }
+        llama_memory_params params_mem = {
+            /*.type_k      =*/ type_k,
+            /*.type_v      =*/ type_v,
+            /*.v_trans     =*/ false, // unused
+            /*.offload_kqv =*/ params.offload_kqv,
+            /*.kv_size     =*/ kv_size,
+        };

-    GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
-    GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
+        auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory(params_mem));

-    if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
-        throw std::runtime_error("failed to initialize self-attention cache");
+        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+
+        kv_self.reset(kv);
     }

     {
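
Note: the constructor above no longer calls init() on the cache it receives; it fills a llama_memory_params and passes it to model.create_memory(...), and the returned cache is already fully constructed. The model-side factory lives in src/llama-model.cpp, which is not included in this excerpt. The sketch below is only an illustration of what such a factory could look like after this change, with the per-layer buffer type wired from offload_kqv; the return type, the const qualifier and the body are assumptions, not the actual implementation (get_rope_factors is omitted for brevity).

    // sketch only -- NOT the src/llama-model.cpp code from this commit
    llama_memory_i * llama_model::create_memory(const llama_memory_params & params) const {
        llama_kv_cache::callbacks cbs;

        // the per-layer offload decision that used to live inside the caches
        cbs.get_buft = [this, offload = params.offload_kqv](int il) {
            return offload ? ggml_backend_dev_buffer_type(dev_layer(il))
                           : ggml_backend_cpu_buffer_type();
        };

        if (llama_model_is_recurrent(this)) {
            return new llama_kv_cache_recurrent(hparams, std::move(cbs),
                    params.type_k, params.type_v, params.kv_size);
        }

        return new llama_kv_cache_unified(hparams, std::move(cbs),
                params.type_k, params.type_v, params.v_trans, params.kv_size);
    }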

src/llama-kv-cache.cpp

Lines changed: 24 additions & 76 deletions
@@ -15,27 +15,22 @@
 // llama_kv_cache_unified
 //

-llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
-}
+llama_kv_cache_unified::llama_kv_cache_unified(
+        const llama_hparams & hparams,
+        callbacks cbs,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size) : hparams(hparams), cbs(std::move(cbs)), v_trans(v_trans) {

-bool llama_kv_cache_unified::init(
-        const llama_model & model,
-        const llama_cparams & cparams,
-        ggml_type type_k,
-        ggml_type type_v,
-        uint32_t kv_size,
-        bool offload) {
     const int32_t n_layer = hparams.n_layer;

     has_shift = false;

-    GGML_ASSERT(!llama_model_is_recurrent(&model));
-
-    v_trans = !cparams.flash_attn;
     can_shift = true;

-    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
-            __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
+    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
+            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);

     head = 0;
     size = kv_size;

@@ -79,25 +74,11 @@ bool llama_kv_cache_unified::init(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

-        const char * dev_name = "CPU";
-
-        ggml_backend_buffer_type_t buft;
-        if (offload) {
-            auto * dev = model.dev_layer(i);
-            buft = ggml_backend_dev_buffer_type(dev);
-
-            dev_name = ggml_backend_dev_name(dev);
-        } else {
-            buft = ggml_backend_cpu_buffer_type();
-        }
-
-        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
-                i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+        ggml_backend_buffer_type_t buft = cbs.get_buft(i);

         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
-            return false;
+            throw std::runtime_error("failed to create ggml context for kv cache");
         }

         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);

@@ -115,15 +96,12 @@ bool llama_kv_cache_unified::init(

         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
-            return false;
+            throw std::runtime_error("failed to allocate buffer for kv cache");
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
-
-    return true;
 }

 int32_t llama_kv_cache_unified::get_n_tokens() const {

@@ -480,7 +458,7 @@ bool llama_kv_cache_unified::find_slot(
     return true;
 }

-uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
+uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
     // the FA kernels require padding to avoid extra runtime boundary checks
     return cparams.flash_attn ? 256u : 32u;
 }

@@ -1021,24 +999,16 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 // llama_kv_cache_recurrent
 //

-llama_kv_cache_recurrent::llama_kv_cache_recurrent(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
-}
-
-bool llama_kv_cache_recurrent::init(
-        const llama_model & model,
-        const llama_cparams & cparams,
-        ggml_type type_k,
-        ggml_type type_v,
-        uint32_t kv_size,
-        bool offload) {
-    GGML_UNUSED(cparams);
-
+llama_kv_cache_recurrent::llama_kv_cache_recurrent(
+        const llama_hparams & hparams,
+        callbacks cbs,
+        ggml_type type_k,
+        ggml_type type_v,
+        uint32_t kv_size) : hparams(hparams), cbs(std::move(cbs)) {
     const int32_t n_layer = hparams.n_layer;

-    GGML_ASSERT(llama_model_is_recurrent(&model));
-
-    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
-            __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
+    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
+            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);

     head = 0;
     size = kv_size;

@@ -1082,25 +1052,11 @@ bool llama_kv_cache_recurrent::init(
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

-        const char * dev_name = "CPU";
-
-        ggml_backend_buffer_type_t buft;
-        if (offload) {
-            auto * dev = model.dev_layer(i);
-            buft = ggml_backend_dev_buffer_type(dev);
-
-            dev_name = ggml_backend_dev_name(dev);
-        } else {
-            buft = ggml_backend_cpu_buffer_type();
-        }
-
-        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
-                i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+        ggml_backend_buffer_type_t buft = cbs.get_buft(i);

         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
-            return false;
+            throw std::runtime_error("failed to create ggml context for kv cache");
         }

         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);

@@ -1118,15 +1074,12 @@ bool llama_kv_cache_recurrent::init(

         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
-            return false;
+            throw std::runtime_error("failed to allocate buffer for kv cache");
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
-
-    return true;
 }

 int32_t llama_kv_cache_recurrent::get_n_tokens() const {

@@ -1558,11 +1511,6 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }

-uint32_t llama_kv_cache_recurrent::get_padding(const llama_cparams & cparams) const {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
-
 uint32_t llama_kv_cache_recurrent::cell_max() const {
     for (uint32_t i = size; i > 0; --i) {
         const llama_kv_cell & cell = cells[i - 1];
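
Note: with the bool-returning init() gone, failing to create a ggml context or to allocate a backend buffer now surfaces from the constructors as std::runtime_error instead of a false return value. In-tree the exception propagates up through the llama_context constructor; the minimal sketch below only illustrates the shift for a hypothetical caller constructing a cache directly, with placeholder variables (hparams, cbs, type_k, type_v, kv_size) assumed to be set up as in the code above.

    // sketch only: constructor failures are exceptions now, not a bool to check
    try {
        auto kv = std::make_unique<llama_kv_cache_unified>(
                hparams, std::move(cbs), type_k, type_v, /*v_trans =*/ true, kv_size);
        // ... use the cache ...
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: failed to create KV cache: %s\n", __func__, err.what());
    }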

src/llama-kv-cache.h

Lines changed: 18 additions & 44 deletions
@@ -15,19 +15,18 @@ struct llama_hparams;
 struct llama_ubatch;

 struct llama_kv_cache : public llama_memory_i {
+    // can be used to query data from the model if needed
+    struct callbacks {
+        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
+
+        // get the buffer type of layer il, can be used to offload KV cache layers to a different device
+        std::function<ggml_backend_buffer_type_t (int il)> get_buft;
+    };
+
     virtual ~llama_kv_cache() = default;

     using llama_memory_i::llama_memory_i;

-    // TODO: become constructor
-    virtual bool init(
-            const llama_model & model, // TODO: do not reference the model
-            const llama_cparams & cparams,
-            ggml_type type_k,
-            ggml_type type_v,
-            uint32_t kv_size,
-            bool offload) = 0;
-
     virtual void restore() = 0; // call if batch processing fails - restores the cache state
     virtual void commit() = 0; // call after successful batch processing - clears any pending state

@@ -96,23 +95,13 @@ struct llama_kv_cell {
 // TODO: add notion of max sequences
 class llama_kv_cache_unified : public llama_kv_cache {
 public:
-    // can be used to query data from the model if needed
-    struct callbacks {
-        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
-    };
-
-    // TODO: become constructor
-    bool init(
-            const llama_model & model, // TODO: do not reference the model
-            const llama_cparams & cparams,
-            ggml_type type_k,
-            ggml_type type_v,
-            uint32_t kv_size,
-            bool offload) override;
-
     llama_kv_cache_unified(
             const llama_hparams & hparams,
-            callbacks cbs);
+            callbacks cbs,
+            ggml_type type_k,
+            ggml_type type_v,
+            bool v_trans,
+            uint32_t kv_size);

     ~llama_kv_cache_unified() = default;

@@ -149,8 +138,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // to the first cell of the slot.
     bool find_slot(const llama_ubatch & batch) override;

-    // TODO: maybe not needed
-    uint32_t get_padding(const llama_cparams & cparams) const;
+    static uint32_t get_padding(const llama_cparams & cparams);

     // find how many cells are currently in use
     uint32_t cell_max() const;

@@ -229,26 +217,15 @@ class llama_kv_cache_unified : public llama_kv_cache {

 class llama_kv_cache_recurrent : public llama_kv_cache {
 public:
-    // can be used to query data from the model if needed
-    struct callbacks {
-        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
-    };
-
     llama_kv_cache_recurrent(
             const llama_hparams & hparams,
-            callbacks cbs);
+            callbacks cbs,
+            ggml_type type_k,
+            ggml_type type_v,
+            uint32_t kv_size);

     ~llama_kv_cache_recurrent() = default;

-    // TODO: become constructor
-    bool init(
-            const llama_model & model, // TODO: do not reference the model
-            const llama_cparams & cparams,
-            ggml_type type_k,
-            ggml_type type_v,
-            uint32_t kv_size,
-            bool offload) override;
-
     int32_t get_n_tokens() const override;
     int32_t get_used_cells() const override;

@@ -282,9 +259,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     // to the first cell of the slot.
     bool find_slot(const llama_ubatch & batch) override;

-    // TODO: maybe not needed
-    uint32_t get_padding(const llama_cparams & cparams) const;
-
     // find how many cells are currently in use
     uint32_t cell_max() const;

src/llama-memory.h

Lines changed: 14 additions & 0 deletions
@@ -2,6 +2,20 @@

 #include "llama.h"

+struct llama_memory_params {
+    // kv cache
+    ggml_type type_k;
+    ggml_type type_v;
+
+    bool v_trans;
+    bool offload_kqv;
+
+    uint32_t kv_size;
+
+    // other types of memory
+    // ...
+};
+
 // general concept of LLM memory
 // the KV cache is a type of LLM memory, but there can be other types
 class llama_memory_i {
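
Note: llama_memory_params is a plain aggregate, so the /*.field =*/ comments at the call sites in llama-context.cpp above simply label positional aggregate initialization. For illustration only, the equivalent with C++20 designated initializers (not the convention llama.cpp itself uses) and arbitrary placeholder values:

    llama_memory_params params_mem = {
        .type_k      = GGML_TYPE_F16,
        .type_v      = GGML_TYPE_F16,
        .v_trans     = true,   // !flash_attn on the unified-cache path
        .offload_kqv = true,
        .kv_size     = 4096,
    };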
