Skip to content

Commit bf9c101

Browse files
authored
metal : use F32 prec for K*Q in vec FA (#9595)
ggml-ci
1 parent e62e978 commit bf9c101

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

ggml/src/ggml-metal.metal

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
26312631
const short iv3 = iq3 / rv3;
26322632

26332633
// load the queries from shared memory into local memory
2634-
half4 mq[D4];
2634+
float4 mq[D4];
26352635

26362636
for (short ii = 0; ii < D4; ii += NW) {
26372637
short i = ii + tiisg;
2638-
mq[i] = sq4[i];
2638+
mq[i] = (float4) sq4[i];
26392639
}
26402640

26412641
// pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
26612661
for (short ii = 0; ii < D4; ii += NW) {
26622662
const short i = ii + tiisg;
26632663

2664-
half4x4 mk;
2665-
mk[0] = pk4[i + 0*(nb11/8)];
2666-
mk[1] = pk4[i + 1*(nb11/8)];
2667-
mk[2] = pk4[i + 2*(nb11/8)];
2668-
mk[3] = pk4[i + 3*(nb11/8)];
2664+
float4x4 mk;
2665+
mk[0] = (float4) pk4[i + 0*(nb11/8)];
2666+
mk[1] = (float4) pk4[i + 1*(nb11/8)];
2667+
mk[2] = (float4) pk4[i + 2*(nb11/8)];
2668+
mk[3] = (float4) pk4[i + 3*(nb11/8)];
26692669

26702670
mqk += (float4) (mq[i] * mk);
26712671
}

0 commit comments

Comments
 (0)