Commit e262947

common : add command-line arg to disable KV cache offloading

1 parent c80b8a2

File tree: 4 files changed, 68 additions (+) and 51 deletions (−)

common/common.cpp

Lines changed: 5 additions & 0 deletions
@@ -498,6 +498,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.infill = true;
         } else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
             params.dump_kv_cache = true;
+        } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
+            params.no_kv_offload = true;
         } else if (arg == "--multiline-input") {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
@@ -840,6 +842,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --verbose-prompt      print prompt before generation\n");
     printf("  -dkvc, --dump-kv-cache\n");
     printf("                        verbose print of the KV cache\n");
+    printf("  -nkvo, --no-kv-offload\n");
+    printf("                        disable KV offload\n");
     printf("  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     printf("  --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@@ -924,6 +928,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_fast = params.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.yarn_orig_ctx  = params.yarn_orig_ctx;
+    cparams.offload_kqv    = !params.no_kv_offload;

     return cparams;
 }
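
With this in place, the flag is available to every example that parses arguments through common. A typical invocation might look like the following (the model path and prompt are illustrative):

    ./main -m models/7B/ggml-model-q4_0.gguf -p "Once upon a time" -ngl 99 -nkvo

Passing -nkvo sets params.no_kv_offload, which llama_context_params_from_gpt_params then inverts into cparams.offload_kqv = false.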

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ struct gpt_params {
     bool verbose_prompt = false; // print prompt tokens before generation
     bool infill         = false; // use infill mode
     bool dump_kv_cache  = false; // dump the KV cache contents for debugging purposes
+    bool no_kv_offload  = false; // disable KV offloading

     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector

llama.cpp

Lines changed: 57 additions & 45 deletions
@@ -1245,8 +1245,7 @@ struct llama_cparams {
     float yarn_beta_slow;

     bool mul_mat_q;
-    bool offload_k;
-    bool offload_v;
+    bool offload_kqv;

 };

@@ -1526,8 +1525,7 @@ static bool llama_kv_cache_init(
         ggml_type wtype,
         uint32_t n_ctx,
         int n_gpu_layers,
-        bool offload_k,
-        bool offload_v) {
+        bool offload) {
     const uint32_t n_embd  = hparams.n_embd_gqa();
     const uint32_t n_layer = hparams.n_layer;

@@ -1574,11 +1572,9 @@ static bool llama_kv_cache_init(
         cache.v_l.push_back(v);
 #ifdef GGML_USE_CUBLAS
         if (i >= i_gpu_start) {
-            if (offload_k) {
+            if (offload) {
                 ggml_cuda_assign_buffers_no_scratch(k);
                 vram_kv_cache += ggml_nbytes(k);
-            }
-            if (offload_v) {
                 ggml_cuda_assign_buffers_no_scratch(v);
                 vram_kv_cache += ggml_nbytes(v);
             }
@@ -5101,6 +5097,7 @@ enum llm_offload_func_e {
     OFFLOAD_FUNC_NOP,
     OFFLOAD_FUNC,
     OFFLOAD_FUNC_FRC, // force offload
+    OFFLOAD_FUNC_KQV,
     OFFLOAD_FUNC_NR,
     OFFLOAD_FUNC_EMB,
     OFFLOAD_FUNC_OUT,
@@ -5204,38 +5201,38 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "attn_norm",          OFFLOAD_FUNC },
     { "attn_norm_2",        OFFLOAD_FUNC },

-    { "wqkv",               OFFLOAD_FUNC },
-    { "bqkv",               OFFLOAD_FUNC },
-    { "wqkv_clamped",       OFFLOAD_FUNC },
-
-    { "tmpk",               OFFLOAD_FUNC },
-    { "tmpq",               OFFLOAD_FUNC },
-    { "tmpv",               OFFLOAD_FUNC },
-    { "Kcur",               OFFLOAD_FUNC },
-    { "Qcur",               OFFLOAD_FUNC },
-    { "Vcur",               OFFLOAD_FUNC },
-
-    { "krot",               OFFLOAD_FUNC },
-    { "qrot",               OFFLOAD_FUNC },
-    { "kpass",              OFFLOAD_FUNC },
-    { "qpass",              OFFLOAD_FUNC },
-    { "krotated",           OFFLOAD_FUNC },
-    { "qrotated",           OFFLOAD_FUNC },
-
-    { "q",                  OFFLOAD_FUNC },
-    { "k",                  OFFLOAD_FUNC },
-    { "kq",                 OFFLOAD_FUNC },
-    { "kq_scaled",          OFFLOAD_FUNC },
-    { "kq_scaled_alibi",    OFFLOAD_FUNC },
-    { "kq_masked",          OFFLOAD_FUNC },
-    { "kq_soft_max",        OFFLOAD_FUNC },
-    { "kq_soft_max_ext",    OFFLOAD_FUNC },
-    { "v",                  OFFLOAD_FUNC },
-    { "kqv",                OFFLOAD_FUNC },
-    { "kqv_merged",         OFFLOAD_FUNC },
-    { "kqv_merged_cont",    OFFLOAD_FUNC },
-    { "kqv_wo",             OFFLOAD_FUNC },
-    { "kqv_out",            OFFLOAD_FUNC },
+    { "wqkv",               OFFLOAD_FUNC_KQV },
+    { "bqkv",               OFFLOAD_FUNC_KQV },
+    { "wqkv_clamped",       OFFLOAD_FUNC_KQV },
+
+    { "tmpk",               OFFLOAD_FUNC_KQV },
+    { "tmpq",               OFFLOAD_FUNC_KQV },
+    { "tmpv",               OFFLOAD_FUNC_KQV },
+    { "Kcur",               OFFLOAD_FUNC_KQV },
+    { "Qcur",               OFFLOAD_FUNC_KQV },
+    { "Vcur",               OFFLOAD_FUNC_KQV },
+
+    { "krot",               OFFLOAD_FUNC_KQV },
+    { "qrot",               OFFLOAD_FUNC_KQV },
+    { "kpass",              OFFLOAD_FUNC_KQV },
+    { "qpass",              OFFLOAD_FUNC_KQV },
+    { "krotated",           OFFLOAD_FUNC_KQV },
+    { "qrotated",           OFFLOAD_FUNC_KQV },
+
+    { "q",                  OFFLOAD_FUNC_KQV },
+    { "k",                  OFFLOAD_FUNC_KQV },
+    { "kq",                 OFFLOAD_FUNC_KQV },
+    { "kq_scaled",          OFFLOAD_FUNC_KQV },
+    { "kq_scaled_alibi",    OFFLOAD_FUNC_KQV },
+    { "kq_masked",          OFFLOAD_FUNC_KQV },
+    { "kq_soft_max",        OFFLOAD_FUNC_KQV },
+    { "kq_soft_max_ext",    OFFLOAD_FUNC_KQV },
+    { "v",                  OFFLOAD_FUNC_KQV },
+    { "kqv",                OFFLOAD_FUNC_KQV },
+    { "kqv_merged",         OFFLOAD_FUNC_KQV },
+    { "kqv_merged_cont",    OFFLOAD_FUNC_KQV },
+    { "kqv_wo",             OFFLOAD_FUNC_KQV },
+    { "kqv_out",            OFFLOAD_FUNC_KQV },

     { "ffn_inp",            OFFLOAD_FUNC },
     { "ffn_norm",           OFFLOAD_FUNC },
@@ -5429,11 +5426,13 @@ static struct ggml_cgraph * llama_build_graph(
 #ifdef GGML_USE_CUBLAS
         { OFFLOAD_FUNC,     "GPU (CUDA)" },
         { OFFLOAD_FUNC_FRC, "GPU (CUDA) FRC" },
+        { OFFLOAD_FUNC_KQV, "GPU (CUDA) KQV" },
         { OFFLOAD_FUNC_NR,  "GPU (CUDA) NR" },
         { OFFLOAD_FUNC_EMB, "GPU (CUDA) EMB" },
 #else
         { OFFLOAD_FUNC,     "CPU" },
         { OFFLOAD_FUNC_FRC, "CPU" },
+        { OFFLOAD_FUNC_KQV, "CPU" },
         { OFFLOAD_FUNC_NR,  "CPU" },
         { OFFLOAD_FUNC_EMB, "CPU" },
 #endif // GGML_USE_CUBLAS
@@ -5458,7 +5457,6 @@ static struct ggml_cgraph * llama_build_graph(
         switch (func_e) {
             case OFFLOAD_FUNC_NOP:
             case OFFLOAD_FUNC_OUT:
-            case OFFLOAD_FUNC_FRC:
                 break;
             case OFFLOAD_FUNC:
                 if (n_gpu_layers < n_layer) {
@@ -5467,6 +5465,21 @@ static struct ggml_cgraph * llama_build_graph(
                     }
                 }
                 break;
+            case OFFLOAD_FUNC_FRC:
+                if (!lctx.cparams.offload_kqv) {
+                    func_e = OFFLOAD_FUNC_NOP;
+                } break;
+            case OFFLOAD_FUNC_KQV:
+                if (!lctx.cparams.offload_kqv) {
+                    func_e = OFFLOAD_FUNC_NOP;
+                } else {
+                    if (n_gpu_layers < n_layer) {
+                        if (il < i_gpu_start) {
+                            func_e = OFFLOAD_FUNC_NOP;
+                        }
+                    }
+                }
+                break;
             case OFFLOAD_FUNC_NR:
                 if (n_gpu_layers <= n_layer + 0) {
                     func_e = OFFLOAD_FUNC_NOP;
@@ -5493,6 +5506,7 @@ static struct ggml_cgraph * llama_build_graph(
             case OFFLOAD_FUNC_NOP:
             case OFFLOAD_FUNC_OUT: func = ggml_offload_nop; break;
             case OFFLOAD_FUNC:
+            case OFFLOAD_FUNC_KQV:
             case OFFLOAD_FUNC_FRC:
             case OFFLOAD_FUNC_NR:
             case OFFLOAD_FUNC_EMB: func = ggml_offload_gpu; break;
@@ -8567,8 +8581,7 @@ struct llama_context_params llama_context_default_params() {
         /*.f16_kv      =*/ true,
         /*.logits_all  =*/ false,
         /*.embedding   =*/ false,
-        /*.offload_k   =*/ true,
-        /*.offload_q   =*/ true,
+        /*.offload_kqv =*/ true,
     };

     return result;
@@ -8685,8 +8698,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_fast = params.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.mul_mat_q      = params.mul_mat_q;
-    cparams.offload_k      = params.offload_k;
-    cparams.offload_v      = params.offload_v;
+    cparams.offload_kqv    = params.offload_kqv;

     cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -8724,7 +8736,7 @@ struct llama_context * llama_new_context_with_model(

     // reserve memory for context buffers
     if (!hparams.vocab_only) {
-        if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers, cparams.offload_k, cparams.offload_v)) {
+        if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
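
Taken together, the new dispatch logic for KQV-class tensors in llama_build_graph reduces to one predicate. The helper below is not part of the commit, only a condensed restatement of the OFFLOAD_FUNC_KQV case above:

    // Sketch: decide whether a KQV-class tensor in layer `il` runs on the GPU.
    // Mirrors the OFFLOAD_FUNC_KQV case in llama_build_graph.
    static bool should_offload_kqv(bool offload_kqv, int il, int i_gpu_start,
                                   int n_gpu_layers, int n_layer) {
        if (!offload_kqv) {
            return false; // KV offload disabled via -nkvo
        }
        if (n_gpu_layers < n_layer && il < i_gpu_start) {
            return false; // this layer was never offloaded in the first place
        }
        return true;
    }

OFFLOAD_FUNC_FRC tensors apply only the first check: they are still forced onto the GPU unless offload_kqv is false.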

llama.h

Lines changed: 5 additions & 6 deletions
@@ -192,12 +192,11 @@ extern "C" {
         uint32_t yarn_orig_ctx; // YaRN original context size

         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
-        bool f16_kv;     // use fp16 for KV cache, fp32 otherwise
-        bool logits_all; // the llama_eval() call computes all logits, not just the last one
-        bool embedding;  // embedding mode only
-        bool offload_k;
-        bool offload_v;
+        bool mul_mat_q;   // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
+        bool f16_kv;      // use fp16 for KV cache, fp32 otherwise
+        bool logits_all;  // the llama_eval() call computes all logits, not just the last one
+        bool embedding;   // embedding mode only
+        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
     };

     // model quantization parameters
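
For API consumers, the new flag is a single field on llama_context_params. A minimal sketch against the public API at this commit (the model path, layer count, and the omitted inference loop are illustrative):

    #include "llama.h"

    int main(void) {
        llama_backend_init(false /*numa*/);

        struct llama_model_params mparams = llama_model_default_params();
        mparams.n_gpu_layers = 99; // offload all repeating layers

        struct llama_model * model = llama_load_model_from_file("model.gguf", mparams);
        if (model == NULL) {
            return 1;
        }

        struct llama_context_params cparams = llama_context_default_params();
        cparams.offload_kqv = false; // keep the KV cache and KQV ops on the host

        struct llama_context * ctx = llama_new_context_with_model(model, cparams);

        // ... tokenize, llama_decode(), sample ...

        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }

This is the programmatic equivalent of passing -nkvo on the command line.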
