@@ -8254,6 +8254,9 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask(false);
 
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = build_inp_KQ_pos(false);
+
         // iterate layers
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * cur = inpL;
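For context, the `build_inp_*` helpers allocate graph inputs that are filled in before evaluation rather than computing values themselves. A minimal sketch of what `build_inp_KQ_pos(false)` plausibly does, assuming the surrounding `llm_build_context` members (`ctx0`, `lctx`, `n_kv`, `n_tokens`, `cb`); the helper's body is outside this diff, so treat the details as an assumption:

```cpp
// Sketch, not the verbatim helper: a 1D F32 input tensor holding one
// position per entry; with causal == false it covers only the current
// batch of n_tokens instead of the whole KV cache.
struct ggml_tensor * build_inp_KQ_pos(bool causal) {
    lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32,
            causal ? n_kv : n_tokens);
    cb(lctx.inp_KQ_pos, "KQ_pos", -1);
    ggml_set_input(lctx.inp_KQ_pos); // values are written in before eval
    return lctx.inp_KQ_pos;
}
```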
@@ -8322,7 +8325,7 @@ struct llm_build_context {
                 struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
                 cb(kq, "kq", il);
 
-                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, KQ_pos, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
                 cb(kq, "kq_soft_max_ext", il);
 
                 struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
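Passing `KQ_pos` instead of `nullptr` asks the fused `ggml_soft_max_ext` to add an ALiBi bias of `slope(head) * pos[j]` to each scaled logit before the softmax; since softmax is invariant to a per-row constant, this is equivalent to the relative `slope * (pos[j] - pos[i])` bias from the ALiBi paper. A sketch of the per-head slope as ggml derives it from `hparams.f_max_alibi_bias` (the helper name here is illustrative, not an actual ggml symbol):

```cpp
#include <math.h>

// Illustrative helper: per-head ALiBi slope as the fused soft_max
// kernel derives it when max_bias > 0.0f.
static float alibi_slope(int head, int n_head, float max_bias) {
    const int   n_head_log2 = 1 << (int) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -max_bias        / n_head_log2);
    const float m1 = powf(2.0f, -max_bias / 2.0f / n_head_log2);
    return head < n_head_log2
        ? powf(m0, head + 1)                    // first 2^k heads
        : powf(m1, 2*(head - n_head_log2) + 1); // interpolated tail
}
```

With `max_bias == 8.0f` and a power-of-two `n_head`, this collapses to the classic `2^(-8(h+1)/n_head)` slopes from the ALiBi paper.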
@@ -11523,7 +11526,7 @@ static int llama_decode_internal(
         }
 
         // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
+        if (hparams.causal_attn || model.arch == LLM_ARCH_JINA_BERT_V2) {
             llama_kv_cache_update(&lctx);
 
             // if we have enough unused cells before the current head ->
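This decode-side change lets JINA_BERT_V2, which is non-causal but needs per-token positions for its ALiBi bias, still run the KV-cache bookkeeping that the position input depends on. A hedged sketch of how the `KQ_pos` input could then be populated at eval time; the actual fill lives elsewhere in llama.cpp (outside this diff), and the use of `batch.pos` is an assumption for the non-causal BERT path:

```cpp
// Hedged sketch of filling the KQ_pos input before graph evaluation:
// one float per batch token, holding that token's absolute position.
if (lctx.inp_KQ_pos) {
    float * data = (float *) lctx.inp_KQ_pos->data;
    for (int i = 0; i < n_tokens; ++i) {
        data[i] = (float) batch.pos[i]; // position of batch token i
    }
}
```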