
Commit f4ba506

context : add get_ctx_padding()
ggml-ci
1 parent 70efeb7 commit f4ba506

3 files changed: +10 -1 lines changed


src/llama-context.cpp

Lines changed: 4 additions & 0 deletions
@@ -64,6 +64,10 @@ llama_pos llama_context::pos_max() const {
     return kv_self.pos_max();
 }
 
+uint32_t llama_context::get_ctx_padding(const llama_cparams & cparams) const {
+    return kv_self.get_padding(cparams);
+}
+
 // TODO: improve
 void llama_context::reset() {
     inp_tokens = nullptr;
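For orientation, here is a minimal standalone sketch of the delegation this commit introduces: callers ask the context for its padding via get_ctx_padding() instead of reaching into kv_self directly. The simplified llama_cparams and llama_kv_cache below are placeholders for illustration, and the padding values are hypothetical, not the real llama.cpp definitions.

#include <cstdint>

// Simplified stand-ins for the real llama.cpp types (illustration only).
struct llama_cparams {
    uint32_t n_ctx;
    bool     flash_attn;
};

struct llama_kv_cache {
    // Hypothetical padding rule; the real get_padding() depends on the cache configuration.
    uint32_t get_padding(const llama_cparams & cparams) const {
        return cparams.flash_attn ? 256u : 32u;
    }
};

struct llama_context {
    llama_kv_cache kv_self;

    // New accessor: forwards to the KV cache so callers never need to touch kv_self.
    uint32_t get_ctx_padding(const llama_cparams & cparams) const {
        return kv_self.get_padding(cparams);
    }
};

With this member in place, llama_init_from_model (see the src/llama.cpp hunk below) can query the padding through the context itself rather than through the cache.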

src/llama-context.h

Lines changed: 3 additions & 0 deletions
@@ -84,8 +84,11 @@ struct llama_context {
             ggml_cgraph * graph,
             bool batched);
 
+    // max token position across all sequences in the current context
     llama_pos pos_max() const;
 
+    uint32_t get_ctx_padding(const llama_cparams & cparams) const;
+
     void reset();
 
     void prepare_k_shift();

src/llama.cpp

Lines changed: 3 additions & 1 deletion
@@ -7820,6 +7820,7 @@ static int llama_decode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
@@ -8154,6 +8155,7 @@ static int llama_encode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
@@ -8619,7 +8621,7 @@ struct llama_context * llama_init_from_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->kv_self.get_padding(cparams));
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->get_ctx_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
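The GGML_PAD call above rounds the requested n_ctx up to a multiple of the KV-cache padding returned by get_ctx_padding(). As a rough, self-contained illustration (assuming GGML_PAD rounds its first argument up to the next multiple of the second; pad_up, the padding value 256, and the sample sizes are made up for this sketch):

#include <cstdint>
#include <cstdio>

// Round x up to the next multiple of n (assumes n is a power of two),
// mirroring what GGML_PAD is used for here.
static uint32_t pad_up(uint32_t x, uint32_t n) {
    return (x + n - 1) & ~(n - 1);
}

int main() {
    const uint32_t padding         = 256;   // hypothetical ctx->get_ctx_padding(cparams)
    const uint32_t n_ctx_requested = 4097;  // user-requested context size
    const uint32_t n_ctx = pad_up(n_ctx_requested, padding);
    std::printf("n_ctx: %u -> %u\n", (unsigned) n_ctx_requested, (unsigned) n_ctx); // n_ctx: 4097 -> 4352
    return 0;
}

The padded n_ctx is then what bounds n_batch when causal attention is enabled, as the last line of the hunk shows.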
