
Commit 614d3b9

llama : less KV padding when FA is off (#7257)
ggml-ci
1 parent 30e7033 commit 614d3b9

File tree: 1 file changed (+13 −7)

llama.cpp

Lines changed: 13 additions & 7 deletions
@@ -2805,6 +2805,11 @@ static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
     cache.do_defrag = true;
 }
 
+static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
+
 //
 // model loading and saving
 //
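
For context, GGML_PAD rounds its first argument up to a multiple of the second, so the new helper controls the granularity at which the KV cache is viewed. A minimal standalone sketch of the effect, assuming the macro body matches the one in ggml.h; kv_padding is a hypothetical stand-in for the new llama_kv_cache_get_padding(cparams):

#include <cstdint>

// assumed to mirror GGML_PAD from ggml.h: round x up to the next multiple of n
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

// hypothetical stand-in for llama_kv_cache_get_padding(cparams)
constexpr uint32_t kv_padding(bool flash_attn) {
    // the FA kernels need the coarser 256 padding; otherwise 32 is enough
    return flash_attn ? 256u : 32u;
}

static_assert(GGML_PAD(100, kv_padding(false)) == 128, "FA off: ~100 used cells round up to 128");
static_assert(GGML_PAD(100, kv_padding(true))  == 256, "FA on: the same cache rounds up to a full 256");
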
@@ -11510,7 +11515,8 @@ static int llama_decode_internal(
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
-                kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
+                const uint32_t pad = llama_kv_cache_get_padding(cparams);
+                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
                 //kv_self.n = llama_kv_cache_cell_max(kv_self);
             }
         }
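
With the same numbers as the sketch above, a cache holding roughly 100 used cells now yields kv_self.n = 128 when flash_attn is off (padded to a multiple of 32) instead of the fixed 256 used before this change; with flash_attn on the result is still 256.
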
@@ -15511,6 +15517,11 @@ struct llama_context * llama_new_context_with_model(
         return nullptr;
     }
 
+    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
     llama_context * ctx = new llama_context(*model);
 
     const auto & hparams = model->hparams;
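
The Grok incompatibility check added above is the same one removed in the last hunk; moving it ahead of the cparams setup means params.flash_attn is already corrected before the context parameters, and thus the padding chosen from them, are derived.
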
@@ -15534,7 +15545,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
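
As a worked example, a requested n_ctx of 2000 is now padded to 2016 when flash_attn is off (next multiple of 32), whereas previously it was always rounded up to 2048 (next multiple of 256).
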
@@ -15579,11 +15590,6 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
-    if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
