3 files changed, +10 -1 lines changed

@@ -64,6 +64,10 @@ llama_pos llama_context::pos_max() const {
     return kv_self.pos_max();
 }
 
+uint32_t llama_context::get_ctx_padding(const llama_cparams & cparams) const {
+    return kv_self.get_padding(cparams);
+}
+
 // TODO: improve
 void llama_context::reset() {
     inp_tokens = nullptr;
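Note: the new method is a one-line forwarder. Below is a minimal standalone sketch of the pattern it introduces, so the intent is clear without opening the KV-cache code: the context exposes the padding through its own accessor instead of letting callers reach into kv_self. All type names in the sketch are simplified stand-ins for illustration, not the real llama.cpp structs.

// Standalone sketch of the accessor pattern above (stand-in types, not llama.cpp code).
#include <cstdint>

struct sketch_cparams {
    uint32_t n_ctx;
};

struct sketch_kv_cache {
    // assumed constant padding, purely for illustration
    uint32_t get_padding(const sketch_cparams &) const { return 32; }
};

struct sketch_context {
    sketch_kv_cache kv_self;

    // same one-line forwarding as the diff: callers never touch kv_self directly
    uint32_t get_ctx_padding(const sketch_cparams & cparams) const {
        return kv_self.get_padding(cparams);
    }
};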
@@ -84,8 +84,11 @@ struct llama_context {
             ggml_cgraph * graph,
                    bool batched);
 
+    // max token position across all sequences in the current context
     llama_pos pos_max() const;
 
+    uint32_t get_ctx_padding(const llama_cparams & cparams) const;
+
     void reset();
 
     void prepare_k_shift();
@@ -7820,6 +7820,7 @@ static int llama_decode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
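To make the new TODO concrete: pos_max() returns a single maximum position taken over every sequence in the context, so auto-assigning positions from it is only safe when the batch holds one sequence (or all sequences have equal length). The sketch below contrasts that global maximum with the per-sequence maximum a multi-sequence-aware fix would need; the cell layout and names are simplified assumptions, not the real KV-cache structures.

// Standalone sketch (stand-in types): global vs per-sequence maximum position.
#include <algorithm>
#include <cstdint>
#include <map>
#include <vector>

struct sketch_cell {
    int32_t seq_id;
    int32_t pos;
};

// what pos_max() effectively gives: one maximum across all sequences
int32_t pos_max_global(const std::vector<sketch_cell> & cells) {
    int32_t max_pos = -1;
    for (const auto & c : cells) {
        max_pos = std::max(max_pos, c.pos);
    }
    return max_pos;
}

// what assigning positions per sequence would need: a maximum per seq_id
std::map<int32_t, int32_t> pos_max_per_seq(const std::vector<sketch_cell> & cells) {
    std::map<int32_t, int32_t> out;
    for (const auto & c : cells) {
        auto it = out.find(c.seq_id);
        if (it == out.end() || c.pos > it->second) {
            out[c.seq_id] = c.pos;
        }
    }
    return out;
}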
@@ -8154,6 +8155,7 @@ static int llama_encode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
@@ -8619,7 +8621,7 @@ struct llama_context * llama_init_from_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->kv_self.get_padding(cparams));
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, ctx->get_ctx_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
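For context on the padding itself: this call site rounds n_ctx up to a multiple of the padding reported through the new accessor (that is what GGML_PAD is used for here). Below is a small self-contained sketch of that round-up with an assumed padding of 32; the helper name and example values are illustrative, not taken from ggml.

// Standalone sketch of rounding a context size up to a multiple of the KV padding.
#include <cstdint>
#include <cstdio>

static uint32_t pad_to_multiple(uint32_t x, uint32_t pad) {
    return ((x + pad - 1) / pad) * pad; // smallest multiple of pad that is >= x
}

int main() {
    // assumed example: a requested n_ctx of 1000 with a padding of 32 becomes 1024
    std::printf("%u\n", pad_to_multiple(1000, 32)); // prints 1024
    return 0;
}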