@@ -2204,11 +2204,7 @@ kernel void kernel_flash_attn_ext_f16(
2204
2204
// pointer to the mask
2205
2205
device const half * mp = (device const half *) (mask + iq1*nb31);
2206
2206
2207
- // prepare diagonal scale matrix
2208
- simdgroup_float8x8 mscale (scale);
2209
-
2210
- // prepare diagonal slope matrix
2211
- simdgroup_float8x8 mslope (1 .0f );
2207
+ float slope = 1 .0f ;
2212
2208
2213
2209
// ALiBi
2214
2210
if (max_bias > 0 .0f ) {
@@ -2217,7 +2213,7 @@ kernel void kernel_flash_attn_ext_f16(
2217
2213
const float base = h < n_head_log2 ? m0 : m1;
2218
2214
const int exph = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
2219
2215
2220
- mslope = simdgroup_float8x8 ( pow (base, exph) );
2216
+ slope = pow (base, exph);
2221
2217
}
2222
2218
2223
2219
// loop over the KV cache
@@ -2242,18 +2238,20 @@ kernel void kernel_flash_attn_ext_f16(
2242
2238
simdgroup_multiply_accumulate (mqk, mq[i], mk, mqk);
2243
2239
}
2244
2240
2241
+ simdgroup_store (mqk, ss + 8 *cc, TF, 0 , false );
2242
+
2243
+ const short tx = tiisg%4 ;
2244
+ const short ty = tiisg/4 ;
2245
+
2245
2246
if (mask != q) {
2246
2247
// ss = ss*scale + mask*slope (per-thread update of the mqk tile stored in ss; 2 elements per thread)
2247
- simdgroup_half8x8 mm;
2248
- simdgroup_load (mm, mp + ic + 8 *cc, nb31/sizeof (half), 0 , false );
2249
- simdgroup_multiply (mm, mslope, mm);
2250
- simdgroup_multiply_accumulate (mqk, mqk, mscale, mm);
2248
+ ss[8 *cc + ty*TF + 2 *tx + 0 ] = scale*ss[8 *cc + ty*TF + 2 *tx + 0 ] + slope*mp[ic + 8 *cc + ty*nb31/sizeof (half) + 2 *tx + 0 ];
2249
+ ss[8 *cc + ty*TF + 2 *tx + 1 ] = scale*ss[8 *cc + ty*TF + 2 *tx + 1 ] + slope*mp[ic + 8 *cc + ty*nb31/sizeof (half) + 2 *tx + 1 ];
2251
2250
} else {
2252
2251
// ss = ss*scale (per-thread scaling of the mqk tile stored in ss; 2 elements per thread)
2253
- simdgroup_multiply (mqk, mscale, mqk);
2252
+ ss[8 *cc + ty*TF + 2 *tx + 0 ] *= scale;
2253
+ ss[8 *cc + ty*TF + 2 *tx + 1 ] *= scale;
2254
2254
}
2255
-
2256
- simdgroup_store (mqk, ss + 8 *cc, TF, 0 , false );
2257
2255
}
2258
2256
}
2259
2257
@@ -2816,8 +2814,7 @@ kernel void kernel_cpy_f32_f16(
2816
2814
for (int64_t i00 = tpitg.x ; i00 < ne00; i00 += ntg.x ) {
2817
2815
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2818
2816
2819
- // TODO: is there a better way to handle -INFINITY?
2820
- dst_data[i00] = src[0 ] == -INFINITY ? -MAXHALF : src[0 ];
2817
+ dst_data[i00] = src[0 ];
2821
2818
}
2822
2819
}
2823
2820
0 commit comments