
Commit d9b8dd6

Author: Joan Martinez
Parent: f8d1709

    fix: add some changes as per review

File tree: 2 files changed, +11 -9 lines (ggml.c, llama.cpp)


ggml.c (6 additions, 7 deletions)
@@ -5478,9 +5478,9 @@ static struct ggml_tensor * ggml_soft_max_impl(
         GGML_ASSERT(pos->type == mask->type);
     }

-    /*if (max_bias > 0.0f) {
+    if (max_bias > 0.0f) {
         GGML_ASSERT(pos);
-    }*/
+    }

     bool is_node = false;

@@ -12401,7 +12401,6 @@ static void ggml_compute_forward_soft_max_f32(
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;

     // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
-    //float * pos = src2 ? (float *) src2->data : NULL;
     ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
     float * pos_f32 = src2 ? (float *) src2->data : src0->data;

@@ -12436,13 +12435,13 @@ static void ggml_compute_forward_soft_max_f32(

         if (use_f16) {
             for (int i = 0; i < nc; ++i) {
-                wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
-                //wp[i] = wp[i] - slope*abs(i1%nc - i);
+                //wp[i] -= slope*GGML_FP16_TO_FP32(pos_f16[i]);
+                wp[i] -= slope*abs(i1%nc - i);
             }
         } else {
             for (int i = 0; i < nc; ++i) {
-                wp[i] += slope*pos_f32[i];
-                //wp[i] = wp[i] - slope*abs(i1%nc - i);
+                //wp[i] -= slope*pos_f32[i];
+                wp[i] -= slope*abs(i1%nc - i);
             }
         }
     }
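For context, here is a minimal, self-contained sketch of what the new bias term does. It is an illustration only (the helper name softmax_row_alibi and the example values are hypothetical, not part of ggml): instead of adding a bias read from the pos tensor, each logit in row i1 is penalized by slope times the absolute distance between the row's own position (i1 % nc) and the column index i, i.e. a symmetric, index-derived ALiBi-style bias, and the row is then soft-maxed as usual.

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

/* Illustrative sketch (not the ggml implementation): apply a distance bias
 * mirroring the new line  wp[i] -= slope*abs(i1%nc - i);  and then compute a
 * numerically stable softmax over one row of attention scores. */
static void softmax_row_alibi(float * wp, int nc, int i1, float slope) {
    const int q_pos = i1 % nc; /* the row's own position within the sequence */
    for (int i = 0; i < nc; ++i) {
        wp[i] -= slope * (float) abs(q_pos - i);
    }

    float max = -INFINITY;
    for (int i = 0; i < nc; ++i) if (wp[i] > max) max = wp[i];

    float sum = 0.0f;
    for (int i = 0; i < nc; ++i) { wp[i] = expf(wp[i] - max); sum += wp[i]; }
    for (int i = 0; i < nc; ++i) wp[i] /= sum;
}

int main(void) {
    float row[4] = { 0.5f, 0.5f, 0.5f, 0.5f };            /* equal raw scores */
    softmax_row_alibi(row, 4, 1 /* i1 */, 0.25f /* slope, example value */);
    for (int i = 0; i < 4; ++i) printf("%.3f ", row[i]);  /* nearer keys win  */
    printf("\n");
    return 0;
}

With equal raw scores, the probabilities peak at the key closest to the row's own position, which is the intended effect of the bias and needs no positions tensor on this path.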

llama.cpp (5 additions, 2 deletions)
@@ -8254,6 +8254,9 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask(false);

+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = build_inp_KQ_pos(false);
+
         // iterate layers
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * cur = inpL;
@@ -8322,7 +8325,7 @@ struct llm_build_context {
            struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
            cb(kq, "kq", il);

-            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, KQ_pos, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
            cb(kq, "kq_soft_max_ext", il);

            struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
@@ -11523,7 +11526,7 @@ static int llama_decode_internal(
    }

    // non-causal masks do not use the KV cache
-    if (hparams.causal_attn) {
+    if (hparams.causal_attn || model.arch == LLM_ARCH_JINA_BERT_V2) {
        llama_kv_cache_update(&lctx);

        // if we have enough unused cells before the current head ->
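In the builder, KQ_pos (the positions of the tokens in the KV cache) is now constructed and passed to ggml_soft_max_ext together with hparams.f_max_alibi_bias, so the softmax can turn that maximum bias into a per-head slope. As a rough sketch of the head-slope schedule that a max_alibi_bias value commonly parameterizes in ALiBi (the constants and the helper name below are assumptions for illustration, not code from this commit):

#include <math.h>
#include <stdio.h>

/* Common ALiBi schedule (illustrative only): heads get geometrically
 * decreasing slopes derived from max_bias; heads beyond the nearest power of
 * two use an interleaved half-step sequence. */
static float alibi_slope(float max_bias, int n_head, int h) {
    const int   n_head_log2 = 1 << (int) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -max_bias        / (float) n_head_log2);
    const float m1 = powf(2.0f, -max_bias / 2.0f / (float) n_head_log2);
    return h < n_head_log2 ? powf(m0, (float) (h + 1))
                           : powf(m1, (float) (2*(h - n_head_log2) + 1));
}

int main(void) {
    const float max_bias = 8.0f; /* example value for an ALiBi model */
    const int   n_head   = 8;
    for (int h = 0; h < n_head; ++h) {
        printf("head %d: slope %.5f\n", h, alibi_slope(max_bias, n_head, h));
    }
    return 0;
}

The last hunk additionally routes LLM_ARCH_JINA_BERT_V2 through the KV-cache update path even though its attention mask is non-causal, which appears consistent with the builder now requesting the positions of the tokens in the KV cache.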
