Skip to content

Commit f6c7bf0

Browse files
ggerganovteleprint-me
authored andcommitted
metal : handle F16 inf values, fix FA partial offload (ggml-org#7434)
ggml-ci
1 parent 2200f3e commit f6c7bf0

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

ggml-metal.metal

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2204,11 +2204,7 @@ kernel void kernel_flash_attn_ext_f16(
22042204
// pointer to the mask
22052205
device const half * mp = (device const half *) (mask + iq1*nb31);
22062206

2207-
// prepare diagonal scale matrix
2208-
simdgroup_float8x8 mscale(scale);
2209-
2210-
// prepare diagonal slope matrix
2211-
simdgroup_float8x8 mslope(1.0f);
2207+
float slope = 1.0f;
22122208

22132209
// ALiBi
22142210
if (max_bias > 0.0f) {
@@ -2217,7 +2213,7 @@ kernel void kernel_flash_attn_ext_f16(
22172213
const float base = h < n_head_log2 ? m0 : m1;
22182214
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
22192215

2220-
mslope = simdgroup_float8x8(pow(base, exph));
2216+
slope = pow(base, exph);
22212217
}
22222218

22232219
// loop over the KV cache
@@ -2242,18 +2238,20 @@ kernel void kernel_flash_attn_ext_f16(
22422238
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
22432239
}
22442240

2241+
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2242+
2243+
const short tx = tiisg%4;
2244+
const short ty = tiisg/4;
2245+
22452246
if (mask != q) {
22462247
// mqk = mqk*scale + mask*slope
2247-
simdgroup_half8x8 mm;
2248-
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2249-
simdgroup_multiply(mm, mslope, mm);
2250-
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2248+
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2249+
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
22512250
} else {
22522251
// mqk = mqk*scale
2253-
simdgroup_multiply(mqk, mscale, mqk);
2252+
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2253+
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
22542254
}
2255-
2256-
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
22572255
}
22582256
}
22592257

@@ -2816,8 +2814,7 @@ kernel void kernel_cpy_f32_f16(
28162814
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
28172815
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
28182816

2819-
// TODO: is there a better way to handle -INFINITY?
2820-
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
2817+
dst_data[i00] = src[0];
28212818
}
28222819
}
28232820

0 commit comments

Comments
 (0)