Commit 8c3a6a2

context : move memory creation logic to model
ggml-ci
1 parent 3e8eceb

4 files changed: +23 −38 lines

src/llama-context.cpp (5 additions, 27 deletions)

@@ -178,34 +178,12 @@ llama_context::llama_context(
 
     // init the memory module
     if (!hparams.vocab_only) {
-        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
-        if (!llama_model_is_recurrent(&model)) {
-            cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams));
-
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-            llama_memory_params params_mem = {
-                /*.type_k      =*/ params.type_k,
-                /*.type_v      =*/ params.type_v,
-                /*.v_trans     =*/ !cparams.flash_attn,
-                /*.offload_kqv =*/ cparams.offload_kqv,
-                /*.kv_size     =*/ cparams.n_ctx,
-            };
-
-            memory.reset(model.create_memory(params_mem));
-        } else {
-            llama_memory_params params_mem = {
-                /*.type_k      =*/ GGML_TYPE_F32, // required by ggml_ssm_conv for Mamba's conv_states
-                /*.type_v      =*/ GGML_TYPE_F32, // required by ggml_ssm_scan for Mamba's ssm_states
-                /*.v_trans     =*/ false, // unused
-                /*.offload_kqv =*/ cparams.offload_kqv,
-                /*.kv_size     =*/ std::max((uint32_t) 1, params.n_seq_max), // Mamba needs at least as many KV cells as there are sequences kept at any time
-            };
-
-            memory.reset(model.create_memory(params_mem));
-        }
+        llama_memory_params params_mem = {
+            /*.type_k =*/ params.type_k,
+            /*.type_v =*/ params.type_v,
+        };
 
+        memory.reset(model.create_memory(cparams, params_mem));
     }
 
     // init backends
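
The key detail in this hunk is that cparams is now passed to model.create_memory() by non-const reference: the factory may pad cparams.n_ctx in place, and the context then keeps using the padded value. Below is a minimal, self-contained sketch of that in/out pattern, with toy types and an arbitrary padding granularity of 256 (hypothetical names, not the llama.cpp API):

    #include <cstdint>
    #include <cstdio>
    #include <memory>

    // Toy stand-ins for llama_cparams / llama_memory_params (hypothetical).
    struct toy_cparams       { uint32_t n_ctx; bool flash_attn; };
    struct toy_memory_params { int type_k; int type_v; };

    struct toy_memory {
        uint32_t kv_size;
        bool     v_trans;
    };

    // Round x up to a multiple of pad, in the spirit of GGML_PAD(x, pad).
    static uint32_t pad_to(uint32_t x, uint32_t pad) {
        return ((x + pad - 1) / pad) * pad;
    }

    // Factory shaped like llama_model::create_memory() after this commit:
    // it derives kv_size / v_trans from cparams and may mutate cparams.n_ctx.
    static toy_memory * create_memory(toy_cparams & cparams, const toy_memory_params & /*params*/) {
        cparams.n_ctx = pad_to(cparams.n_ctx, 256); // padding is now decided by the factory
        return new toy_memory { cparams.n_ctx, !cparams.flash_attn };
    }

    int main() {
        toy_cparams       cparams    = { /*.n_ctx =*/ 1000, /*.flash_attn =*/ false };
        toy_memory_params params_mem = { /*.type_k =*/ 0, /*.type_v =*/ 0 };

        std::unique_ptr<toy_memory> memory(create_memory(cparams, params_mem));

        // The caller observes the padded context size chosen by the factory.
        std::printf("n_ctx = %u, kv_size = %u\n", cparams.n_ctx, memory->kv_size);
        return 0;
    }

Run as-is this prints "n_ctx = 1024, kv_size = 1024": the caller no longer needs to know the cache's padding rules.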

src/llama-memory.h (3 additions, 3 deletions)

@@ -7,10 +7,10 @@ struct llama_memory_params {
     ggml_type type_k;
     ggml_type type_v;
 
-    bool v_trans;
-    bool offload_kqv;
+    //bool v_trans;
+    //bool offload_kqv;
 
-    uint32_t kv_size;
+    //uint32_t kv_size;
 
     // other types of memory
     // ...

src/llama-model.cpp (13 additions, 7 deletions)

@@ -12764,10 +12764,10 @@ struct llm_build_bailingmoe : public llm_graph_context {
     }
 };
 
-llama_memory_i * llama_model::create_memory(const llama_memory_params & params) const {
+llama_memory_i * llama_model::create_memory(llama_cparams & cparams, const llama_memory_params & params) const {
     llama_memory_i * res;
 
-    const bool offload = params.offload_kqv;
+    const bool offload = cparams.offload_kqv;
 
     auto get_buft = [this, offload](int il) {
         const char * dev_name = "CPU";
@@ -12787,6 +12787,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params)
         return buft;
     };
 
+    LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+
     switch (arch) {
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_RWKV6:
@@ -12800,12 +12802,16 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params)
                     /*.get_rope_factors =*/ nullptr,
                     /*.get_buft         =*/ get_buft,
                 },
-                params.type_k,
-                params.type_v,
-                params.kv_size);
+                GGML_TYPE_F32,
+                GGML_TYPE_F32,
+                std::max((uint32_t) 1, cparams.n_seq_max));
             } break;
         default:
             {
+                cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams));
+
+                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
                 res = new llama_kv_cache_unified(
                     hparams,
                     {
@@ -12825,8 +12831,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params)
                 },
                 params.type_k,
                 params.type_v,
-                params.v_trans,
-                params.kv_size);
+                !cparams.flash_attn,
+                cparams.n_ctx);
             }
     }
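
With the creation logic now in the model, the per-architecture decisions happen inside this switch: the recurrent branch (LLM_ARCH_MAMBA, LLM_ARCH_RWKV6, ...) forces F32 state tensors, as required by ggml_ssm_conv / ggml_ssm_scan, and sizes the cache by n_seq_max, while the default branch pads n_ctx via llama_kv_cache_unified::get_padding() and forwards the caller's type_k / type_v. A simplified, self-contained sketch of that dispatch, with toy types and a fixed toy padding in place of the real cache classes (hypothetical names, not the llama.cpp API):

    #include <algorithm>
    #include <cstdint>

    // Toy stand-ins (hypothetical) that mirror the shape of the real dispatch.
    enum toy_arch { ARCH_MAMBA, ARCH_RWKV6, ARCH_TRANSFORMER };
    enum toy_type { TYPE_F32, TYPE_F16 };

    struct toy_cparams       { uint32_t n_ctx; uint32_t n_seq_max; bool flash_attn; };
    struct toy_memory_params { toy_type type_k; toy_type type_v; };

    struct toy_memory {
        toy_type type_k, type_v;
        bool     v_trans;
        uint32_t kv_size;
    };

    // Stand-in for GGML_PAD + llama_kv_cache_unified::get_padding(cparams).
    static uint32_t pad_to(uint32_t x, uint32_t pad) { return ((x + pad - 1) / pad) * pad; }

    // Same shape as llama_model::create_memory() after this commit: recurrent
    // architectures get F32 state tensors sized by the sequence count, the
    // default branch pads n_ctx and uses the caller-provided cache types.
    static toy_memory * create_memory(toy_arch arch, toy_cparams & cparams, const toy_memory_params & params) {
        switch (arch) {
            case ARCH_MAMBA:
            case ARCH_RWKV6:
                return new toy_memory {
                    TYPE_F32, TYPE_F32,                        // recurrent states stay F32
                    false,                                     // v_trans is unused here
                    std::max((uint32_t) 1, cparams.n_seq_max), // at least one cell per sequence
                };
            default:
                cparams.n_ctx = pad_to(cparams.n_ctx, 256);    // toy padding, mutated in place
                return new toy_memory {
                    params.type_k, params.type_v,
                    !cparams.flash_attn,                       // matches !cparams.flash_attn in the hunk above
                    cparams.n_ctx,
                };
        }
    }

    int main() {
        toy_cparams       cparams    = { /*.n_ctx =*/ 1000, /*.n_seq_max =*/ 4, /*.flash_attn =*/ false };
        toy_memory_params params_mem = { TYPE_F16, TYPE_F16 };

        toy_memory * mem = create_memory(ARCH_TRANSFORMER, cparams, params_mem);
        // cparams.n_ctx and mem->kv_size are both padded to 1024 here.
        delete mem;
        return 0;
    }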

src/llama-model.h (2 additions, 1 deletion)

@@ -390,8 +390,9 @@ struct llama_model {
 
     const struct ggml_tensor * get_tensor(const char * name) const;
 
+    // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
-    llama_memory_i * create_memory(const llama_memory_params & params) const;
+    llama_memory_i * create_memory(llama_cparams & cparams, const llama_memory_params & params) const;
 
     // TODO: move this to new llm_arch_model_i interface
     llm_graph_result_ptr build_graph(
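
For reference, the caller side of this new declaration is the llama-context.cpp hunk at the top of this commit: only the cache tensor types are passed explicitly, and everything else is derived from cparams inside the model, which is why the declaration above takes cparams as a mutable reference.

    llama_memory_params params_mem = {
        /*.type_k =*/ params.type_k,
        /*.type_v =*/ params.type_v,
    };

    // cparams may come back with a padded n_ctx for the unified KV cache.
    memory.reset(model.create_memory(cparams, params_mem));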
